diff --git a/Gopkg.lock b/Gopkg.lock index c63559599f7..6c589645817 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -23,6 +23,14 @@ pruneopts = "NUT" revision = "46cf3e2cf1acef7876068f66cf69ec31aad2d0b2" +[[projects]] + branch = "master" + digest = "1:f12358576cd79bba0ae626530d23cde63416744f486c8bc817802c6907eaadd7" + name = "github.com/armon/go-metrics" + packages = ["."] + pruneopts = "NUT" + revision = "f0300d1749da6fa982027e449ec0c7a145510c3c" + [[projects]] branch = "master" digest = "1:707ebe952a8b3d00b343c01536c79c73771d100f63ec6babeaed5c79e2b8a8dd" @@ -31,6 +39,14 @@ pruneopts = "NUT" revision = "3a771d992973f24aa725d07868b467d1ddfceafb" +[[projects]] + digest = "1:a12d94258c5298ead75e142e8001224bf029f302fed9e96cd39c0eaf90f3954d" + name = "github.com/boltdb/bolt" + packages = ["."] + pruneopts = "NUT" + revision = "2f1ce7a837dcb8da3ec595b1dac9d0632f0f99e8" + version = "v1.3.1" + [[projects]] digest = "1:4fb088ed7f384178cfc4552661e280a12ccc93be7f30a1ca994958a61a8e1d13" name = "github.com/bsm/sarama-cluster" @@ -123,7 +139,9 @@ digest = "1:1b3dd24f14a5280710fc7a3aa2480b6e4d20fdfc905841de9a3aa2aa2f1d4ee9" name = "github.com/gogo/protobuf" packages = [ + "gogoproto", "proto", + "protoc-gen-gogo/descriptor", "sortkeys", ] pruneopts = "NUT" @@ -255,6 +273,22 @@ pruneopts = "NUT" revision = "9cad4c3443a7200dd6400aef47183728de563a38" +[[projects]] + digest = "1:1cf16b098a70d6c02899608abbb567296d11c7b830635014dfe6124a02dc1369" + name = "github.com/hashicorp/go-immutable-radix" + packages = ["."] + pruneopts = "NUT" + revision = "27df80928bb34bb1b0d6d0e01b9e679902e7a6b5" + version = "v1.0.0" + +[[projects]] + branch = "master" + digest = "1:caf220d32af01b3899cfa5d37bdc9cf41c424519e74a6c74d9aec00c45f07adc" + name = "github.com/hashicorp/go-msgpack" + packages = ["codec"] + pruneopts = "NUT" + revision = "fa3f63826f7c23912c15263591e65d54d080b458" + [[projects]] branch = "master" digest = "1:13e2fa5735a82a5fb044f290cfd0dba633d1c5e516b27da0509e0dbb3515a18e" @@ -266,6 +300,14 
@@ pruneopts = "NUT" revision = "0fb14efe8c47ae851c0034ed7a448854d3d34cf3" +[[projects]] + digest = "1:544ff85fc54fe5c7ed27b1d292ef82253d4ef7a4b854696774ee61049b0cbcdc" + name = "github.com/hashicorp/raft" + packages = ["."] + pruneopts = "NUT" + revision = "6d14f0c70869faabd9e60ba7ed88a6cbbd6a661f" + version = "v1.0.0" + [[projects]] digest = "1:9a52adf44086cead3b384e5d0dbf7a1c1cce65e67552ee3383a8561c42a18cd3" name = "github.com/imdario/mergo" @@ -400,6 +442,65 @@ revision = "1df9eeb2bb81f327b96228865c5687bc2194af3f" version = "1.0.0" +[[projects]] + digest = "1:66011d8367095e2e5a5f0542369e7613597f4ddf80bb725ea273d4ab38bb8cf4" + name = "github.com/nats-io/gnatsd" + packages = [ + "conf", + "logger", + "server", + "server/pse", + "util", + ] + pruneopts = "NUT" + revision = "eed4fbc1458ce110ad1aa1adf904229bc8fda2a7" + version = "v1.3.0" + +[[projects]] + digest = "1:7b5fe86d83990fd594ee51261e836a7d5f230b9ec4e3d6f7bb6aaff97b72fa75" + name = "github.com/nats-io/go-nats" + packages = [ + ".", + "encoders/builtin", + "util", + ] + pruneopts = "NUT" + revision = "fb0396ee0bdb8018b0fef30d6d1de798ce99cd05" + version = "v1.6.0" + +[[projects]] + digest = "1:2b28a7a408f27fd163e2c661cece969efa3ee8f5331343586fce3d11d2d27d77" + name = "github.com/nats-io/go-nats-streaming" + packages = [ + ".", + "pb", + ] + pruneopts = "NUT" + revision = "e15a53f85e4932540600a16b56f6c4f65f58176f" + version = "v0.4.0" + +[[projects]] + digest = "1:b2c27be5f066a76a3e0d76e1736c125b39bb6bbd2f3b27f2643e37fe080ac089" + name = "github.com/nats-io/nats-streaming-server" + packages = [ + "logger", + "server", + "spb", + "stores", + "util", + ] + pruneopts = "NUT" + revision = "8910c0c347bc51cc87227aeb27bef19d409bf5c2" + version = "v0.11.0" + +[[projects]] + digest = "1:552100985450eaa4a8fa8d18fe8b6b1a04997c4d2ac24f154cc9471430cc4e1c" + name = "github.com/nats-io/nuid" + packages = ["."] + pruneopts = "NUT" + revision = "289cccf02c178dc782430d534e3c1f5b72af807f" + version = "v1.0.0" + [[projects]] 
digest = "1:93b1d84c5fa6d1ea52f4114c37714cddd84d5b78f151b62bb101128dd51399bf" name = "github.com/pborman/uuid" @@ -435,6 +536,14 @@ revision = "6b9367c9ff401dbc54fabce3fb8d972e799b702d" version = "v2.0.2" +[[projects]] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + pruneopts = "NUT" + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + [[projects]] digest = "1:03bca087b180bf24c4f9060775f137775550a0834e18f0bca0520a868679dbd7" name = "github.com/prometheus/client_golang" @@ -551,7 +660,11 @@ branch = "master" digest = "1:3f3a05ae0b95893d90b9b3b5afdb79a9b3d96e4e36e099d841ae602e4aca0da8" name = "golang.org/x/crypto" - packages = ["ssh/terminal"] + packages = [ + "bcrypt", + "blowfish", + "ssh/terminal", + ] pruneopts = "NUT" revision = "5ba7f63082460102a45837dbd1827e10f9479ac0" @@ -603,6 +716,11 @@ packages = [ "unix", "windows", + "windows/registry", + "windows/svc", + "windows/svc/debug", + "windows/svc/eventlog", + "windows/svc/mgr", ] pruneopts = "NUT" revision = "c11f84a56e43e20a78cee75a7c034031ecf57d1f" @@ -1116,6 +1234,8 @@ "github.com/knative/serving/pkg/client/clientset/versioned", "github.com/knative/serving/pkg/client/clientset/versioned/typed/serving/v1alpha1", "github.com/knative/test-infra", + "github.com/nats-io/go-nats-streaming", + "github.com/nats-io/nats-streaming-server/server", "github.com/prometheus/client_golang/prometheus/promhttp", "go.opencensus.io/trace", "go.uber.org/atomic", diff --git a/Gopkg.toml b/Gopkg.toml index a03659e1eb1..314b36e2fa0 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -96,3 +96,15 @@ required = [ name = "sigs.k8s.io/controller-runtime" # HEAD as of 2018-09-19 revision = "5373e8e1f3188ff4266902a6fc86372bc14b3815" + +[[override]] + name = "github.com/nats-io/go-nats" + version = "1.6.0" + +[[override]] + name = "github.com/nats-io/go-nats-streaming" + version = "0.4.0" + +[[override]] + name = 
"github.com/nats-io/nats-streaming-server" + version = "0.11.0" diff --git a/config/buses/natss/100-natss.yaml b/config/buses/natss/100-natss.yaml new file mode 100644 index 00000000000..06e5c995fc8 --- /dev/null +++ b/config/buses/natss/100-natss.yaml @@ -0,0 +1,101 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + labels: + app: nats-streaming + name: nats-streaming + namespace: knative-eventing +data: + gnatsd.conf: | + # configuration file used to override default NATS server settings + stan.conf: | + # content of configuration file used to override default NATS Streaming server settings +--- +apiVersion: v1 +kind: Service +metadata: + name: nats-streaming + namespace: knative-eventing + labels: + app: nats-streaming +spec: + type: ClusterIP + ports: + - name: client + port: 4222 + protocol: TCP + targetPort: client + selector: + app: nats-streaming + sessionAffinity: None +--- +apiVersion: apps/v1beta1 +kind: StatefulSet +metadata: + name: nats-streaming + namespace: knative-eventing + labels: + app: nats-streaming +spec: + serviceName: nats-streaming + replicas: 1 + selector: + matchLabels: + app: nats-streaming + template: + metadata: + annotations: + sidecar.istio.io/inject: "true" + labels: + app: nats-streaming + spec: + containers: + - name: nats-streaming + image: nats-streaming:0.11.0 + imagePullPolicy: IfNotPresent + args: + - -D + - -SD + - --cluster_id=knative-nats-streaming + - --http_port=8222 + - --max_age=24h + - --store=FILE + - --dir=/var/lib/nats-streaming/core-nats-streaming/$(POD_NAME) + - --port=4222 + - --config=/etc/nats-streaming/core-nats-streaming/gnatsd.conf + - --stan_config=/etc/nats-streaming/core-nats-streaming/stan.conf + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + ports: + - containerPort: 4222 + name: client + protocol: TCP + - containerPort: 8222 + name: monitoring + protocol: TCP + volumeMounts: + - mountPath: /var/lib/nats-streaming/core-nats-streaming + name: datadir + - mountPath: 
/etc/nats-streaming/core-nats-streaming + name: config-volume + resources: + requests: + cpu: "100m" + limits: + memory: "32M" + volumes: + - configMap: + name: nats-streaming + name: config-volume + volumeClaimTemplates: + - metadata: + name: datadir + spec: + accessModes: + - "ReadWriteOnce" + resources: + requests: + storage: "1Gi" diff --git a/config/buses/natss/200-natss-bus.yaml b/config/buses/natss/200-natss-bus.yaml new file mode 100644 index 00000000000..9aac210db29 --- /dev/null +++ b/config/buses/natss/200-natss-bus.yaml @@ -0,0 +1,20 @@ +apiVersion: channels.knative.dev/v1alpha1 +kind: Bus +metadata: + name: natss +spec: + provisioner: + name: provisioner + image: github.com/knative/eventing/pkg/buses/natss/provisioner + args: [ + "-logtostderr", + "-stderrthreshold", "INFO", + ] + dispatcher: + name: dispatcher + image: github.com/knative/eventing/pkg/buses/natss/dispatcher + args: [ + "-logtostderr", + "-stderrthreshold", "INFO", + ] + diff --git a/config/buses/natss/README.md b/config/buses/natss/README.md new file mode 100644 index 00000000000..0f19fe94c84 --- /dev/null +++ b/config/buses/natss/README.md @@ -0,0 +1,9 @@ +# NATS Streaming - Knative Bus + +Deployment steps: +1. Setup [Knative Eventing](../../../DEVELOPMENT.md) +1. Apply the 'natss' bus: +```ko apply -f config/buses/natss/``` +1. Create Channels that reference the 'natss' bus + +The NATSS Streaming bus uses NATS Streaming based on a simple setup, see [Natss Streaming](./100-natss.yaml) . diff --git a/pkg/buses/natss/bus.go b/pkg/buses/natss/bus.go new file mode 100644 index 00000000000..6758131174a --- /dev/null +++ b/pkg/buses/natss/bus.go @@ -0,0 +1,178 @@ +/* + * Copyright 2018 The Knative Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package natss + +import ( + "go.uber.org/zap" + "time" + + "github.com/knative/eventing/pkg/buses" + stanutil "github.com/knative/eventing/pkg/buses/natss/stanutil" + stan "github.com/nats-io/go-nats-streaming" +) + +// BusType is the type of the stub bus +const ( + BusType = "natss" + NatssUrl = "nats://nats-streaming.knative-eventing.svc.cluster.local:4222" +) + +type NatssBus struct { + natssUrl string + subscribers map[string]*stan.Subscription + + ref buses.BusReference + dispatcher buses.BusDispatcher + dispEventHandler buses.EventHandlerFuncs + provisioner buses.BusProvisioner + provEventHandler buses.EventHandlerFuncs + + logger *zap.SugaredLogger +} + +type SetupNatssBus func(*NatssBus) error + +var ( + natsConn *stan.Conn +) + +func NewNatssBusProvisioner(ref buses.BusReference, setup SetupNatssBus) (*NatssBus, error) { + bus := &NatssBus{ + ref: ref, + } + bus.provEventHandler = buses.EventHandlerFuncs{ + ProvisionFunc: func(channel buses.ChannelReference, parameters buses.ResolvedParameters) error { + bus.logger.Infof("Provision channel %q", channel.Name) + bus.logger.Infof("channel=%+v; parameters=%+v", channel, parameters) + return nil + }, + UnprovisionFunc: func(channel buses.ChannelReference) error { + bus.logger.Infof("Unprovision channel %q", channel.Name) + bus.logger.Infof("channel=%+v", channel) + return nil + }, + } + setup(bus) + return bus, nil +} + +func NewNatssBusDispatcher(ref buses.BusReference, setup SetupNatssBus) (*NatssBus, error) { + bus := &NatssBus{ + ref: ref, + subscribers: 
make(map[string]*stan.Subscription), + } + bus.dispEventHandler = buses.EventHandlerFuncs{ + ReceiveMessageFunc: func(channel buses.ChannelReference, message *buses.Message) error { + bus.logger.Infof("Recieved message from %q channel", channel.String()) + if err := stanutil.Publish(natsConn, channel.Name, &message.Payload, bus.logger); err != nil { + bus.logger.Errorf("Error during publish: %+v", err) + return err + } + bus.logger.Infof("Published [%s] : '%s'", channel.String(), message) + return nil + }, + SubscribeFunc: bus.subscribe, + UnsubscribeFunc: bus.unsubscribe, + } + setup(bus) + return bus, nil +} + +func (b *NatssBus) Run(threadness int, stopCh <-chan struct{}, clientId string) { + b.logger.Infof("try to connect to NATSS from %q", clientId) + var err error + for i := 0; i < 60; i++ { + if natsConn, err = stanutil.Connect("knative-nats-streaming", clientId, b.natssUrl, b.logger); err != nil { + b.logger.Errorf(" Create new connection failed: %+v", err) + time.Sleep(1 * time.Second) + } else { + break + } + } + if err != nil { + b.logger.Errorf(" Create new connection failed: %+v", err) + return + } + b.logger.Info("connection to NATSS established, natsConn=%+v", natsConn) + + if b.dispatcher != nil { + b.dispatcher.Run(threadness, stopCh) + } + if b.provisioner != nil { + b.provisioner.Run(threadness, stopCh) + } +} + +func SetNewBusProvisioner(opts *buses.BusOpts) SetupNatssBus { + return func(b *NatssBus) error { + b.natssUrl = NatssUrl + b.provisioner = buses.NewBusProvisioner(b.ref, b.provEventHandler, opts) + b.logger = opts.Logger + return nil + } +} + +func SetNewBusDispatcher(opts *buses.BusOpts) SetupNatssBus { + return func(b *NatssBus) error { + b.natssUrl = NatssUrl + b.dispatcher = buses.NewBusDispatcher(b.ref, b.dispEventHandler, opts) + b.logger = opts.Logger + return nil + } +} + +func (bus *NatssBus) subscribe(channel buses.ChannelReference, subscription buses.SubscriptionReference, parameters buses.ResolvedParameters) error { + 
bus.logger.Infof("Subscribe %q to %q channel", subscription.Name, channel.Name) + + mcb := func(msg *stan.Msg) { + bus.logger.Infof("NATSS message received: %+v", msg) + message := buses.Message{ + Headers: map[string]string{}, + Payload: []byte(msg.Data), + } + if err := bus.dispatcher.DispatchMessage(subscription, &message); err != nil { + bus.logger.Warnf("Failed to dispatch message: %v", err) + return + } + if err := msg.Ack(); err != nil { + bus.logger.Warnf("Failed to acknowledge message: %v", err) + } + } + // subscribe to a NATSS subject + if natsStreamingSub, err := (*natsConn).Subscribe(channel.Name, mcb, stan.DurableName(subscription.Name), stan.SetManualAckMode(), stan.AckWait(1*time.Minute)); err != nil { + bus.logger.Errorf(" Create new NATSS Subscription failed: %+v", err) + return err + } else { + bus.logger.Infof("NATSS Subscription created: %+v", natsStreamingSub) + bus.subscribers[subscription.Name] = &natsStreamingSub + } + return nil +} + +func (bus *NatssBus) unsubscribe(channel buses.ChannelReference, subscription buses.SubscriptionReference) error { + bus.logger.Infof("Unsubscribe %q from %q channel", subscription.Name, channel.Name) + + // unsubscribe from a NATSS subject + if natsStreamingSub, ok := bus.subscribers[subscription.Name]; ok { + if err := (*natsStreamingSub).Unsubscribe(); err != nil { + bus.logger.Errorf(" Unsubscribe() failed: %+v", err) + return err + } + delete(bus.subscribers, subscription.Name) + } + return nil +} diff --git a/pkg/buses/natss/bus_test.go b/pkg/buses/natss/bus_test.go new file mode 100644 index 00000000000..372e6011429 --- /dev/null +++ b/pkg/buses/natss/bus_test.go @@ -0,0 +1,177 @@ +/* +Copyright 2018 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package natss + +import ( + "github.com/google/go-cmp/cmp" + "github.com/knative/eventing/pkg/buses" + "github.com/nats-io/nats-streaming-server/server" + "go.uber.org/zap" + "os" + "testing" + "time" +) + +const ( + clusterId = "knative-nats-streaming" + natssTestUrl = "nats://localhost:4222" + messagePayload = "test-message-1" +) + +var ( + logger *zap.SugaredLogger + done = make(chan string) +) + +func TestMain(m *testing.M) { + logger = buses.NewBusLoggerFromConfig(buses.NewLoggingConfig()) + defer logger.Sync() + + stanServer, err := startNatss() + if err != nil { + panic(err) + } + defer stopNatss(stanServer) + + retCode := m.Run() + + os.Exit(retCode) +} + +func TestNatss(t *testing.T) { + ref := buses.NewBusReferenceFromNames("NATSS_TEST_NAME", "NATSS_TEST_NAMESPACE") + opts := &buses.BusOpts{ + Logger: logger, + KubeConfig: "", + MasterURL: "", + } + + busProv, err := NewNatssBusProvisioner(ref, setNewTestBusProvisioner(t, opts)) + if err != nil { + t.Fatalf("Unexpected error from NewNatssBusProvisioner: %v", err) + } + + busDisp, err := NewNatssBusDispatcher(ref, setNewTestBusDispatcher(t, opts)) + if err != nil { + t.Fatalf("Unexpected error from NewNatssBusDispatcher: %v", err) + } + + stopCh := make(chan struct{}) + defer close(stopCh) + busProv.Run(1, stopCh, "test-provisioner-natss") + busDisp.Run(1, stopCh, "test-dispatcher-natss") + + // create a dummy channel + channel := buses.ChannelReference{"default", "test-channel"} + busProv.provEventHandler.ProvisionFunc(channel, make(buses.ResolvedParameters)) + + // create a dummy subscription + 
subscription := buses.SubscriptionReference{"default", "test-subscription"} + busDisp.dispEventHandler.SubscribeFunc(channel, subscription, make(buses.ResolvedParameters)) + + // send a message .... + message := buses.Message{make(map[string]string), []byte(messagePayload)} + busDisp.dispEventHandler.ReceiveMessageFunc(channel, &message) + + // wait for subscriber to respond + select { + case payload := <-done: + logger.Info("Subscriber finished") + if diff := cmp.Diff(messagePayload, payload); diff != "" { + t.Errorf("Unexpected message payload (-want, +got): %v", diff) + } + case <-time.After(5 * time.Second): + t.Error("Subscriber timeout") + } + + // unsubscribe + busDisp.dispEventHandler.UnsubscribeFunc(channel, subscription) +} + +func startNatss() (*server.StanServer, error) { + var err error + var stanServer *server.StanServer + for i := 0; i < 10; i++ { + if stanServer, err = server.RunServer(clusterId); err != nil { + logger.Errorf("Start NATSS failed: %+v", err) + time.Sleep(1 * time.Second) + } else { + break + } + } + if err != nil { + return nil, err + } + return stanServer, nil +} + +func stopNatss(server *server.StanServer) { + server.Shutdown() +} + +// set dummy provisioner +func setNewTestBusProvisioner(t *testing.T, opts *buses.BusOpts) SetupNatssBus { + return func(b *NatssBus) error { + b.natssUrl = natssTestUrl + b.provisioner = newDummyProvisioner(opts.Logger) + b.logger = opts.Logger + return nil + } +} + +func newDummyProvisioner(logger *zap.SugaredLogger) buses.BusProvisioner { + var b = &dummyBusProvisioner{ + logger: logger, + } + return b +} + +type dummyBusProvisioner struct { + logger *zap.SugaredLogger +} + +func (m *dummyBusProvisioner) Run(threadiness int, stopCh <-chan struct{}) {} + +//set dummy dispatcher +func setNewTestBusDispatcher(t *testing.T, opts *buses.BusOpts) SetupNatssBus { + return func(b *NatssBus) error { + b.natssUrl = natssTestUrl + b.dispatcher = newDummyDispatcher(opts.Logger) + b.logger = opts.Logger + return 
nil + } +} + +func newDummyDispatcher(logger *zap.SugaredLogger) buses.BusDispatcher { + var b = &dummyBusDispatcher{ + logger: logger, + } + return b +} + +type dummyBusDispatcher struct { + logger *zap.SugaredLogger +} + +func (m *dummyBusDispatcher) Run(threadiness int, stopCh <-chan struct{}) {} + +func (m *dummyBusDispatcher) DispatchMessage(subscription buses.SubscriptionReference, message *buses.Message) error { + payload := string(message.Payload) + m.logger.Infof("dummyBusDispatcher() - received message: %s", payload) + done <- payload + return nil +} diff --git a/pkg/buses/natss/dispatcher/main.go b/pkg/buses/natss/dispatcher/main.go new file mode 100644 index 00000000000..129ba1ff24e --- /dev/null +++ b/pkg/buses/natss/dispatcher/main.go @@ -0,0 +1,64 @@ +/* + * Copyright 2018 The Knative Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package main + +import ( + "flag" + "os" + + "github.com/knative/eventing/pkg/buses" + "github.com/knative/eventing/pkg/buses/natss" + "github.com/knative/pkg/signals" + "go.uber.org/zap" +) + +const ( + threadsPerReconciler = 1 +) + +func main() { + ref := buses.NewBusReferenceFromNames( + os.Getenv("BUS_NAME"), + os.Getenv("BUS_NAMESPACE"), + ) + + config := buses.NewLoggingConfig() + logger := buses.NewBusLoggerFromConfig(config) + defer logger.Sync() + logger = logger.With( + zap.String("channels.knative.dev/bus", ref.String()), + zap.String("channels.knative.dev/busType", natss.BusType), + zap.String("channels.knative.dev/busComponent", buses.Dispatcher), + ) + + opts := &buses.BusOpts{ + Logger: logger, + } + + flag.StringVar(&opts.KubeConfig, "kubeconfig", "", "Path to a kubeconfig. Only required if out-of-cluster.") + flag.StringVar(&opts.MasterURL, "master", "", "The address of the Kubernetes API server. Overrides any value in kubeconfig. Only required if out-of-cluster.") + flag.Parse() + + bus, err := natss.NewNatssBusDispatcher(ref, natss.SetNewBusDispatcher(opts)) + if err != nil { + logger.Fatalf("Error starting natss bus dispatcher: %v", err) + } + + // set up signals so we handle the first shutdown signal gracefully + stopCh := signals.SetupSignalHandler() + bus.Run(threadsPerReconciler, stopCh, "knative-dispatcher-natss") +} diff --git a/pkg/buses/natss/provisioner/main.go b/pkg/buses/natss/provisioner/main.go new file mode 100644 index 00000000000..3ca78949e17 --- /dev/null +++ b/pkg/buses/natss/provisioner/main.go @@ -0,0 +1,64 @@ +/* + * Copyright 2018 The Knative Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "flag" + "os" + + "github.com/knative/eventing/pkg/buses" + "github.com/knative/eventing/pkg/buses/natss" + "github.com/knative/pkg/signals" + "go.uber.org/zap" +) + +const ( + threadsPerReconciler = 1 +) + +func main() { + ref := buses.NewBusReferenceFromNames( + os.Getenv("BUS_NAME"), + os.Getenv("BUS_NAMESPACE"), + ) + + config := buses.NewLoggingConfig() + logger := buses.NewBusLoggerFromConfig(config) + defer logger.Sync() + logger = logger.With( + zap.String("channels.knative.dev/bus", ref.String()), + zap.String("channels.knative.dev/busType", natss.BusType), + zap.String("channels.knative.dev/busComponent", buses.Provisioner), + ) + + opts := &buses.BusOpts{ + Logger: logger, + } + + flag.StringVar(&opts.KubeConfig, "kubeconfig", "", "Path to a kubeconfig. Only required if out-of-cluster.") + flag.StringVar(&opts.MasterURL, "master", "", "The address of the Kubernetes API server. Overrides any value in kubeconfig. 
Only required if out-of-cluster.") + flag.Parse() + + bus, err := natss.NewNatssBusProvisioner(ref, natss.SetNewBusProvisioner(opts)) + if err != nil { + logger.Fatalf("Error starting natss bus provisioner: %v", err) + } + + // set up signals so we handle the first shutdown signal gracefully + stopCh := signals.SetupSignalHandler() + bus.Run(threadsPerReconciler, stopCh, "knative-provisioner-natss") +} diff --git a/pkg/buses/natss/stanutil/stanutil.go b/pkg/buses/natss/stanutil/stanutil.go new file mode 100644 index 00000000000..de05fb68269 --- /dev/null +++ b/pkg/buses/natss/stanutil/stanutil.go @@ -0,0 +1,74 @@ +/* + * Copyright 2018 The Knative Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package stanutil + +import ( + "errors" + "fmt" + "go.uber.org/zap" + + stan "github.com/nats-io/go-nats-streaming" +) + +// Connect creates a new NATS-Streaming connection +func Connect(clusterID string, clientID string, natsURL string, logger *zap.SugaredLogger) (*stan.Conn, error) { + sc, err := stan.Connect(clusterID, clientID, stan.NatsURL(natsURL)) + if err != nil { + logger.Errorf("Can't connect to: %s ; error: %v; NATS URL: %s", clusterID, err, natsURL) + } + return &sc, err +} + +// Close must be the last call to close the connection +func Close(sc *stan.Conn, logger *zap.SugaredLogger) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered from: %v", r) + logger.Errorf("Close(): %v", err.Error()) + } + }() + + if sc == nil { + err = errors.New("can't close empty stan connection") + return + } + err = (*sc).Close() + if err != nil { + logger.Errorf("Can't close connection: %+v", err) + } + return +} + +// Publish a message to a subject +func Publish(sc *stan.Conn, subj string, msg *[]byte, logger *zap.SugaredLogger) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered from: %v", r) + logger.Errorf("Publish(): %v", err.Error()) + } + }() + + if sc == nil { + err = errors.New("cant'publish on empty stan connection") + return + } + err = (*sc).Publish(subj, *msg) + if err != nil { + logger.Errorf("Error during publish: %v\n", err) + } + return +} diff --git a/pkg/buses/natss/stanutil/stanutil_test.go b/pkg/buses/natss/stanutil/stanutil_test.go new file mode 100644 index 00000000000..510a13c318f --- /dev/null +++ b/pkg/buses/natss/stanutil/stanutil_test.go @@ -0,0 +1,73 @@ +package stanutil + +import ( + "github.com/knative/eventing/pkg/buses" + "github.com/nats-io/nats-streaming-server/server" + "go.uber.org/zap" + "os" + "testing" + "time" +) + +const ( + clusterId = "knative-eventing" + clientId = "testClient" + natssUrl = "nats://localhost:4222" +) + +var ( + logger 
*zap.SugaredLogger +) + +func TestMain(m *testing.M) { + logger = buses.NewBusLoggerFromConfig(buses.NewLoggingConfig()) + defer logger.Sync() + + stanServer, err := startNatss() + if err != nil { + panic(err) + } + defer stopNatss(stanServer) + + retCode := m.Run() + + os.Exit(retCode) +} + +func TestConnectPublishClose(t *testing.T) { + // connect + natssConn, err := Connect(clusterId, clientId, natssUrl, logger) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer Close(natssConn, logger) + logger.Infof("natssConn: %v", natssConn) + + //publish + msg := []byte("testMessage") + err = Publish(natssConn, "testTopic", &msg, logger) + if err != nil { + t.Errorf("Publish failed: %v", err) + } +} + +func startNatss() (*server.StanServer, error) { + var err error + var stanServer *server.StanServer + for i := 0; i < 10; i++ { + if stanServer, err = server.RunServer(clusterId); err != nil { + logger.Errorf("Start NATSS failed: %+v", err) + time.Sleep(1 * time.Second) + } else { + break + } + } + if err != nil { + return nil, err + } + return stanServer, nil +} + +func stopNatss(server *server.StanServer) { + server.Shutdown() +} diff --git a/vendor/github.com/armon/go-metrics/LICENSE b/vendor/github.com/armon/go-metrics/LICENSE new file mode 100644 index 00000000000..106569e542b --- /dev/null +++ b/vendor/github.com/armon/go-metrics/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013 Armon Dadgar + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial 
portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/armon/go-metrics/const_unix.go b/vendor/github.com/armon/go-metrics/const_unix.go new file mode 100644 index 00000000000..31098dd57e5 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/const_unix.go @@ -0,0 +1,12 @@ +// +build !windows + +package metrics + +import ( + "syscall" +) + +const ( + // DefaultSignal is used with DefaultInmemSignal + DefaultSignal = syscall.SIGUSR1 +) diff --git a/vendor/github.com/armon/go-metrics/const_windows.go b/vendor/github.com/armon/go-metrics/const_windows.go new file mode 100644 index 00000000000..38136af3e42 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/const_windows.go @@ -0,0 +1,13 @@ +// +build windows + +package metrics + +import ( + "syscall" +) + +const ( + // DefaultSignal is used with DefaultInmemSignal + // Windows has no SIGUSR1, use SIGBREAK + DefaultSignal = syscall.Signal(21) +) diff --git a/vendor/github.com/armon/go-metrics/inmem.go b/vendor/github.com/armon/go-metrics/inmem.go new file mode 100644 index 00000000000..4e2d6a709e2 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/inmem.go @@ -0,0 +1,348 @@ +package metrics + +import ( + "bytes" + "fmt" + "math" + "net/url" + "strings" + "sync" + "time" +) + +// InmemSink provides a MetricSink that does in-memory aggregation +// without sending metrics over a network. It can be embedded within +// an application to provide profiling information. 
+type InmemSink struct { + // How long is each aggregation interval + interval time.Duration + + // Retain controls how many metrics interval we keep + retain time.Duration + + // maxIntervals is the maximum length of intervals. + // It is retain / interval. + maxIntervals int + + // intervals is a slice of the retained intervals + intervals []*IntervalMetrics + intervalLock sync.RWMutex + + rateDenom float64 +} + +// IntervalMetrics stores the aggregated metrics +// for a specific interval +type IntervalMetrics struct { + sync.RWMutex + + // The start time of the interval + Interval time.Time + + // Gauges maps the key to the last set value + Gauges map[string]GaugeValue + + // Points maps the string to the list of emitted values + // from EmitKey + Points map[string][]float32 + + // Counters maps the string key to a sum of the counter + // values + Counters map[string]SampledValue + + // Samples maps the key to an AggregateSample, + // which has the rolled up view of a sample + Samples map[string]SampledValue +} + +// NewIntervalMetrics creates a new IntervalMetrics for a given interval +func NewIntervalMetrics(intv time.Time) *IntervalMetrics { + return &IntervalMetrics{ + Interval: intv, + Gauges: make(map[string]GaugeValue), + Points: make(map[string][]float32), + Counters: make(map[string]SampledValue), + Samples: make(map[string]SampledValue), + } +} + +// AggregateSample is used to hold aggregate metrics +// about a sample +type AggregateSample struct { + Count int // The count of emitted pairs + Rate float64 // The values rate per time unit (usually 1 second) + Sum float64 // The sum of values + SumSq float64 `json:"-"` // The sum of squared values + Min float64 // Minimum value + Max float64 // Maximum value + LastUpdated time.Time `json:"-"` // When value was last updated +} + +// Computes a Stddev of the values +func (a *AggregateSample) Stddev() float64 { + num := (float64(a.Count) * a.SumSq) - math.Pow(a.Sum, 2) + div := float64(a.Count * (a.Count - 
1)) + if div == 0 { + return 0 + } + return math.Sqrt(num / div) +} + +// Computes a mean of the values +func (a *AggregateSample) Mean() float64 { + if a.Count == 0 { + return 0 + } + return a.Sum / float64(a.Count) +} + +// Ingest is used to update a sample +func (a *AggregateSample) Ingest(v float64, rateDenom float64) { + a.Count++ + a.Sum += v + a.SumSq += (v * v) + if v < a.Min || a.Count == 1 { + a.Min = v + } + if v > a.Max || a.Count == 1 { + a.Max = v + } + a.Rate = float64(a.Sum) / rateDenom + a.LastUpdated = time.Now() +} + +func (a *AggregateSample) String() string { + if a.Count == 0 { + return "Count: 0" + } else if a.Stddev() == 0 { + return fmt.Sprintf("Count: %d Sum: %0.3f LastUpdated: %s", a.Count, a.Sum, a.LastUpdated) + } else { + return fmt.Sprintf("Count: %d Min: %0.3f Mean: %0.3f Max: %0.3f Stddev: %0.3f Sum: %0.3f LastUpdated: %s", + a.Count, a.Min, a.Mean(), a.Max, a.Stddev(), a.Sum, a.LastUpdated) + } +} + +// NewInmemSinkFromURL creates an InmemSink from a URL. It is used +// (and tested) from NewMetricSinkFromURL. +func NewInmemSinkFromURL(u *url.URL) (MetricSink, error) { + params := u.Query() + + interval, err := time.ParseDuration(params.Get("interval")) + if err != nil { + return nil, fmt.Errorf("Bad 'interval' param: %s", err) + } + + retain, err := time.ParseDuration(params.Get("retain")) + if err != nil { + return nil, fmt.Errorf("Bad 'retain' param: %s", err) + } + + return NewInmemSink(interval, retain), nil +} + +// NewInmemSink is used to construct a new in-memory sink. +// Uses an aggregation interval and maximum retention period. 
+func NewInmemSink(interval, retain time.Duration) *InmemSink { + rateTimeUnit := time.Second + i := &InmemSink{ + interval: interval, + retain: retain, + maxIntervals: int(retain / interval), + rateDenom: float64(interval.Nanoseconds()) / float64(rateTimeUnit.Nanoseconds()), + } + i.intervals = make([]*IntervalMetrics, 0, i.maxIntervals) + return i +} + +func (i *InmemSink) SetGauge(key []string, val float32) { + i.SetGaugeWithLabels(key, val, nil) +} + +func (i *InmemSink) SetGaugeWithLabels(key []string, val float32, labels []Label) { + k, name := i.flattenKeyLabels(key, labels) + intv := i.getInterval() + + intv.Lock() + defer intv.Unlock() + intv.Gauges[k] = GaugeValue{Name: name, Value: val, Labels: labels} +} + +func (i *InmemSink) EmitKey(key []string, val float32) { + k := i.flattenKey(key) + intv := i.getInterval() + + intv.Lock() + defer intv.Unlock() + vals := intv.Points[k] + intv.Points[k] = append(vals, val) +} + +func (i *InmemSink) IncrCounter(key []string, val float32) { + i.IncrCounterWithLabels(key, val, nil) +} + +func (i *InmemSink) IncrCounterWithLabels(key []string, val float32, labels []Label) { + k, name := i.flattenKeyLabels(key, labels) + intv := i.getInterval() + + intv.Lock() + defer intv.Unlock() + + agg, ok := intv.Counters[k] + if !ok { + agg = SampledValue{ + Name: name, + AggregateSample: &AggregateSample{}, + Labels: labels, + } + intv.Counters[k] = agg + } + agg.Ingest(float64(val), i.rateDenom) +} + +func (i *InmemSink) AddSample(key []string, val float32) { + i.AddSampleWithLabels(key, val, nil) +} + +func (i *InmemSink) AddSampleWithLabels(key []string, val float32, labels []Label) { + k, name := i.flattenKeyLabels(key, labels) + intv := i.getInterval() + + intv.Lock() + defer intv.Unlock() + + agg, ok := intv.Samples[k] + if !ok { + agg = SampledValue{ + Name: name, + AggregateSample: &AggregateSample{}, + Labels: labels, + } + intv.Samples[k] = agg + } + agg.Ingest(float64(val), i.rateDenom) +} + +// Data is used to 
retrieve all the aggregated metrics +// Intervals may be in use, and a read lock should be acquired +func (i *InmemSink) Data() []*IntervalMetrics { + // Get the current interval, forces creation + i.getInterval() + + i.intervalLock.RLock() + defer i.intervalLock.RUnlock() + + n := len(i.intervals) + intervals := make([]*IntervalMetrics, n) + + copy(intervals[:n-1], i.intervals[:n-1]) + current := i.intervals[n-1] + + // make its own copy for current interval + intervals[n-1] = &IntervalMetrics{} + copyCurrent := intervals[n-1] + current.RLock() + *copyCurrent = *current + + copyCurrent.Gauges = make(map[string]GaugeValue, len(current.Gauges)) + for k, v := range current.Gauges { + copyCurrent.Gauges[k] = v + } + // saved values will be not change, just copy its link + copyCurrent.Points = make(map[string][]float32, len(current.Points)) + for k, v := range current.Points { + copyCurrent.Points[k] = v + } + copyCurrent.Counters = make(map[string]SampledValue, len(current.Counters)) + for k, v := range current.Counters { + copyCurrent.Counters[k] = v + } + copyCurrent.Samples = make(map[string]SampledValue, len(current.Samples)) + for k, v := range current.Samples { + copyCurrent.Samples[k] = v + } + current.RUnlock() + + return intervals +} + +func (i *InmemSink) getExistingInterval(intv time.Time) *IntervalMetrics { + i.intervalLock.RLock() + defer i.intervalLock.RUnlock() + + n := len(i.intervals) + if n > 0 && i.intervals[n-1].Interval == intv { + return i.intervals[n-1] + } + return nil +} + +func (i *InmemSink) createInterval(intv time.Time) *IntervalMetrics { + i.intervalLock.Lock() + defer i.intervalLock.Unlock() + + // Check for an existing interval + n := len(i.intervals) + if n > 0 && i.intervals[n-1].Interval == intv { + return i.intervals[n-1] + } + + // Add the current interval + current := NewIntervalMetrics(intv) + i.intervals = append(i.intervals, current) + n++ + + // Truncate the intervals if they are too long + if n >= i.maxIntervals { + 
copy(i.intervals[0:], i.intervals[n-i.maxIntervals:]) + i.intervals = i.intervals[:i.maxIntervals] + } + return current +} + +// getInterval returns the current interval to write to +func (i *InmemSink) getInterval() *IntervalMetrics { + intv := time.Now().Truncate(i.interval) + if m := i.getExistingInterval(intv); m != nil { + return m + } + return i.createInterval(intv) +} + +// Flattens the key for formatting, removes spaces +func (i *InmemSink) flattenKey(parts []string) string { + buf := &bytes.Buffer{} + replacer := strings.NewReplacer(" ", "_") + + if len(parts) > 0 { + replacer.WriteString(buf, parts[0]) + } + for _, part := range parts[1:] { + replacer.WriteString(buf, ".") + replacer.WriteString(buf, part) + } + + return buf.String() +} + +// Flattens the key for formatting along with its labels, removes spaces +func (i *InmemSink) flattenKeyLabels(parts []string, labels []Label) (string, string) { + buf := &bytes.Buffer{} + replacer := strings.NewReplacer(" ", "_") + + if len(parts) > 0 { + replacer.WriteString(buf, parts[0]) + } + for _, part := range parts[1:] { + replacer.WriteString(buf, ".") + replacer.WriteString(buf, part) + } + + key := buf.String() + + for _, label := range labels { + replacer.WriteString(buf, fmt.Sprintf(";%s=%s", label.Name, label.Value)) + } + + return buf.String(), key +} diff --git a/vendor/github.com/armon/go-metrics/inmem_endpoint.go b/vendor/github.com/armon/go-metrics/inmem_endpoint.go new file mode 100644 index 00000000000..504f1b37485 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/inmem_endpoint.go @@ -0,0 +1,118 @@ +package metrics + +import ( + "fmt" + "net/http" + "sort" + "time" +) + +// MetricsSummary holds a roll-up of metrics info for a given interval +type MetricsSummary struct { + Timestamp string + Gauges []GaugeValue + Points []PointValue + Counters []SampledValue + Samples []SampledValue +} + +type GaugeValue struct { + Name string + Hash string `json:"-"` + Value float32 + + Labels []Label 
`json:"-"` + DisplayLabels map[string]string `json:"Labels"` +} + +type PointValue struct { + Name string + Points []float32 +} + +type SampledValue struct { + Name string + Hash string `json:"-"` + *AggregateSample + Mean float64 + Stddev float64 + + Labels []Label `json:"-"` + DisplayLabels map[string]string `json:"Labels"` +} + +// DisplayMetrics returns a summary of the metrics from the most recent finished interval. +func (i *InmemSink) DisplayMetrics(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + data := i.Data() + + var interval *IntervalMetrics + n := len(data) + switch { + case n == 0: + return nil, fmt.Errorf("no metric intervals have been initialized yet") + case n == 1: + // Show the current interval if it's all we have + interval = i.intervals[0] + default: + // Show the most recent finished interval if we have one + interval = i.intervals[n-2] + } + + summary := MetricsSummary{ + Timestamp: interval.Interval.Round(time.Second).UTC().String(), + Gauges: make([]GaugeValue, 0, len(interval.Gauges)), + Points: make([]PointValue, 0, len(interval.Points)), + } + + // Format and sort the output of each metric type, so it gets displayed in a + // deterministic order. 
+ for name, points := range interval.Points { + summary.Points = append(summary.Points, PointValue{name, points}) + } + sort.Slice(summary.Points, func(i, j int) bool { + return summary.Points[i].Name < summary.Points[j].Name + }) + + for hash, value := range interval.Gauges { + value.Hash = hash + value.DisplayLabels = make(map[string]string) + for _, label := range value.Labels { + value.DisplayLabels[label.Name] = label.Value + } + value.Labels = nil + + summary.Gauges = append(summary.Gauges, value) + } + sort.Slice(summary.Gauges, func(i, j int) bool { + return summary.Gauges[i].Hash < summary.Gauges[j].Hash + }) + + summary.Counters = formatSamples(interval.Counters) + summary.Samples = formatSamples(interval.Samples) + + return summary, nil +} + +func formatSamples(source map[string]SampledValue) []SampledValue { + output := make([]SampledValue, 0, len(source)) + for hash, sample := range source { + displayLabels := make(map[string]string) + for _, label := range sample.Labels { + displayLabels[label.Name] = label.Value + } + + output = append(output, SampledValue{ + Name: sample.Name, + Hash: hash, + AggregateSample: sample.AggregateSample, + Mean: sample.AggregateSample.Mean(), + Stddev: sample.AggregateSample.Stddev(), + DisplayLabels: displayLabels, + }) + } + sort.Slice(output, func(i, j int) bool { + return output[i].Hash < output[j].Hash + }) + + return output +} diff --git a/vendor/github.com/armon/go-metrics/inmem_signal.go b/vendor/github.com/armon/go-metrics/inmem_signal.go new file mode 100644 index 00000000000..0937f4aedf7 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/inmem_signal.go @@ -0,0 +1,117 @@ +package metrics + +import ( + "bytes" + "fmt" + "io" + "os" + "os/signal" + "strings" + "sync" + "syscall" +) + +// InmemSignal is used to listen for a given signal, and when received, +// to dump the current metrics from the InmemSink to an io.Writer +type InmemSignal struct { + signal syscall.Signal + inm *InmemSink + w io.Writer + 
sigCh chan os.Signal + + stop bool + stopCh chan struct{} + stopLock sync.Mutex +} + +// NewInmemSignal creates a new InmemSignal which listens for a given signal, +// and dumps the current metrics out to a writer +func NewInmemSignal(inmem *InmemSink, sig syscall.Signal, w io.Writer) *InmemSignal { + i := &InmemSignal{ + signal: sig, + inm: inmem, + w: w, + sigCh: make(chan os.Signal, 1), + stopCh: make(chan struct{}), + } + signal.Notify(i.sigCh, sig) + go i.run() + return i +} + +// DefaultInmemSignal returns a new InmemSignal that responds to SIGUSR1 +// and writes output to stderr. Windows uses SIGBREAK +func DefaultInmemSignal(inmem *InmemSink) *InmemSignal { + return NewInmemSignal(inmem, DefaultSignal, os.Stderr) +} + +// Stop is used to stop the InmemSignal from listening +func (i *InmemSignal) Stop() { + i.stopLock.Lock() + defer i.stopLock.Unlock() + + if i.stop { + return + } + i.stop = true + close(i.stopCh) + signal.Stop(i.sigCh) +} + +// run is a long running routine that handles signals +func (i *InmemSignal) run() { + for { + select { + case <-i.sigCh: + i.dumpStats() + case <-i.stopCh: + return + } + } +} + +// dumpStats is used to dump the data to output writer +func (i *InmemSignal) dumpStats() { + buf := bytes.NewBuffer(nil) + + data := i.inm.Data() + // Skip the last period which is still being aggregated + for j := 0; j < len(data)-1; j++ { + intv := data[j] + intv.RLock() + for _, val := range intv.Gauges { + name := i.flattenLabels(val.Name, val.Labels) + fmt.Fprintf(buf, "[%v][G] '%s': %0.3f\n", intv.Interval, name, val.Value) + } + for name, vals := range intv.Points { + for _, val := range vals { + fmt.Fprintf(buf, "[%v][P] '%s': %0.3f\n", intv.Interval, name, val) + } + } + for _, agg := range intv.Counters { + name := i.flattenLabels(agg.Name, agg.Labels) + fmt.Fprintf(buf, "[%v][C] '%s': %s\n", intv.Interval, name, agg.AggregateSample) + } + for _, agg := range intv.Samples { + name := i.flattenLabels(agg.Name, agg.Labels) + 
fmt.Fprintf(buf, "[%v][S] '%s': %s\n", intv.Interval, name, agg.AggregateSample) + } + intv.RUnlock() + } + + // Write out the bytes + i.w.Write(buf.Bytes()) +} + +// Flattens the key for formatting along with its labels, removes spaces +func (i *InmemSignal) flattenLabels(name string, labels []Label) string { + buf := bytes.NewBufferString(name) + replacer := strings.NewReplacer(" ", "_", ":", "_") + + for _, label := range labels { + replacer.WriteString(buf, ".") + replacer.WriteString(buf, label.Value) + } + + return buf.String() +} diff --git a/vendor/github.com/armon/go-metrics/metrics.go b/vendor/github.com/armon/go-metrics/metrics.go new file mode 100644 index 00000000000..cf9def748e2 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/metrics.go @@ -0,0 +1,278 @@ +package metrics + +import ( + "runtime" + "strings" + "time" + + "github.com/hashicorp/go-immutable-radix" +) + +type Label struct { + Name string + Value string +} + +func (m *Metrics) SetGauge(key []string, val float32) { + m.SetGaugeWithLabels(key, val, nil) +} + +func (m *Metrics) SetGaugeWithLabels(key []string, val float32, labels []Label) { + if m.HostName != "" { + if m.EnableHostnameLabel { + labels = append(labels, Label{"host", m.HostName}) + } else if m.EnableHostname { + key = insert(0, m.HostName, key) + } + } + if m.EnableTypePrefix { + key = insert(0, "gauge", key) + } + if m.ServiceName != "" { + if m.EnableServiceLabel { + labels = append(labels, Label{"service", m.ServiceName}) + } else { + key = insert(0, m.ServiceName, key) + } + } + allowed, labelsFiltered := m.allowMetric(key, labels) + if !allowed { + return + } + m.sink.SetGaugeWithLabels(key, val, labelsFiltered) +} + +func (m *Metrics) EmitKey(key []string, val float32) { + if m.EnableTypePrefix { + key = insert(0, "kv", key) + } + if m.ServiceName != "" { + key = insert(0, m.ServiceName, key) + } + allowed, _ := m.allowMetric(key, nil) + if !allowed { + return + } + m.sink.EmitKey(key, val) +} + +func (m *Metrics) 
IncrCounter(key []string, val float32) { + m.IncrCounterWithLabels(key, val, nil) +} + +func (m *Metrics) IncrCounterWithLabels(key []string, val float32, labels []Label) { + if m.HostName != "" && m.EnableHostnameLabel { + labels = append(labels, Label{"host", m.HostName}) + } + if m.EnableTypePrefix { + key = insert(0, "counter", key) + } + if m.ServiceName != "" { + if m.EnableServiceLabel { + labels = append(labels, Label{"service", m.ServiceName}) + } else { + key = insert(0, m.ServiceName, key) + } + } + allowed, labelsFiltered := m.allowMetric(key, labels) + if !allowed { + return + } + m.sink.IncrCounterWithLabels(key, val, labelsFiltered) +} + +func (m *Metrics) AddSample(key []string, val float32) { + m.AddSampleWithLabels(key, val, nil) +} + +func (m *Metrics) AddSampleWithLabels(key []string, val float32, labels []Label) { + if m.HostName != "" && m.EnableHostnameLabel { + labels = append(labels, Label{"host", m.HostName}) + } + if m.EnableTypePrefix { + key = insert(0, "sample", key) + } + if m.ServiceName != "" { + if m.EnableServiceLabel { + labels = append(labels, Label{"service", m.ServiceName}) + } else { + key = insert(0, m.ServiceName, key) + } + } + allowed, labelsFiltered := m.allowMetric(key, labels) + if !allowed { + return + } + m.sink.AddSampleWithLabels(key, val, labelsFiltered) +} + +func (m *Metrics) MeasureSince(key []string, start time.Time) { + m.MeasureSinceWithLabels(key, start, nil) +} + +func (m *Metrics) MeasureSinceWithLabels(key []string, start time.Time, labels []Label) { + if m.HostName != "" && m.EnableHostnameLabel { + labels = append(labels, Label{"host", m.HostName}) + } + if m.EnableTypePrefix { + key = insert(0, "timer", key) + } + if m.ServiceName != "" { + if m.EnableServiceLabel { + labels = append(labels, Label{"service", m.ServiceName}) + } else { + key = insert(0, m.ServiceName, key) + } + } + allowed, labelsFiltered := m.allowMetric(key, labels) + if !allowed { + return + } + now := time.Now() + elapsed := 
now.Sub(start) + msec := float32(elapsed.Nanoseconds()) / float32(m.TimerGranularity) + m.sink.AddSampleWithLabels(key, msec, labelsFiltered) +} + +// UpdateFilter overwrites the existing filter with the given rules. +func (m *Metrics) UpdateFilter(allow, block []string) { + m.UpdateFilterAndLabels(allow, block, m.AllowedLabels, m.BlockedLabels) +} + +// UpdateFilterAndLabels overwrites the existing filter with the given rules. +func (m *Metrics) UpdateFilterAndLabels(allow, block, allowedLabels, blockedLabels []string) { + m.filterLock.Lock() + defer m.filterLock.Unlock() + + m.AllowedPrefixes = allow + m.BlockedPrefixes = block + + if allowedLabels == nil { + // Having a white list means we take only elements from it + m.allowedLabels = nil + } else { + m.allowedLabels = make(map[string]bool) + for _, v := range allowedLabels { + m.allowedLabels[v] = true + } + } + m.blockedLabels = make(map[string]bool) + for _, v := range blockedLabels { + m.blockedLabels[v] = true + } + m.AllowedLabels = allowedLabels + m.BlockedLabels = blockedLabels + + m.filter = iradix.New() + for _, prefix := range m.AllowedPrefixes { + m.filter, _, _ = m.filter.Insert([]byte(prefix), true) + } + for _, prefix := range m.BlockedPrefixes { + m.filter, _, _ = m.filter.Insert([]byte(prefix), false) + } +} + +// labelIsAllowed return true if a should be included in metric +// the caller should lock m.filterLock while calling this method +func (m *Metrics) labelIsAllowed(label *Label) bool { + labelName := (*label).Name + if m.blockedLabels != nil { + _, ok := m.blockedLabels[labelName] + if ok { + // If present, let's remove this label + return false + } + } + if m.allowedLabels != nil { + _, ok := m.allowedLabels[labelName] + return ok + } + // Allow by default + return true +} + +// filterLabels return only allowed labels +// the caller should lock m.filterLock while calling this method +func (m *Metrics) filterLabels(labels []Label) []Label { + if labels == nil { + return nil + } + 
toReturn := labels[:0] + for _, label := range labels { + if m.labelIsAllowed(&label) { + toReturn = append(toReturn, label) + } + } + return toReturn +} + +// Returns whether the metric should be allowed based on configured prefix filters +// Also return the applicable labels +func (m *Metrics) allowMetric(key []string, labels []Label) (bool, []Label) { + m.filterLock.RLock() + defer m.filterLock.RUnlock() + + if m.filter == nil || m.filter.Len() == 0 { + return m.Config.FilterDefault, m.filterLabels(labels) + } + + _, allowed, ok := m.filter.Root().LongestPrefix([]byte(strings.Join(key, "."))) + if !ok { + return m.Config.FilterDefault, m.filterLabels(labels) + } + + return allowed.(bool), m.filterLabels(labels) +} + +// Periodically collects runtime stats to publish +func (m *Metrics) collectStats() { + for { + time.Sleep(m.ProfileInterval) + m.emitRuntimeStats() + } +} + +// Emits various runtime statsitics +func (m *Metrics) emitRuntimeStats() { + // Export number of Goroutines + numRoutines := runtime.NumGoroutine() + m.SetGauge([]string{"runtime", "num_goroutines"}, float32(numRoutines)) + + // Export memory stats + var stats runtime.MemStats + runtime.ReadMemStats(&stats) + m.SetGauge([]string{"runtime", "alloc_bytes"}, float32(stats.Alloc)) + m.SetGauge([]string{"runtime", "sys_bytes"}, float32(stats.Sys)) + m.SetGauge([]string{"runtime", "malloc_count"}, float32(stats.Mallocs)) + m.SetGauge([]string{"runtime", "free_count"}, float32(stats.Frees)) + m.SetGauge([]string{"runtime", "heap_objects"}, float32(stats.HeapObjects)) + m.SetGauge([]string{"runtime", "total_gc_pause_ns"}, float32(stats.PauseTotalNs)) + m.SetGauge([]string{"runtime", "total_gc_runs"}, float32(stats.NumGC)) + + // Export info about the last few GC runs + num := stats.NumGC + + // Handle wrap around + if num < m.lastNumGC { + m.lastNumGC = 0 + } + + // Ensure we don't scan more than 256 + if num-m.lastNumGC >= 256 { + m.lastNumGC = num - 255 + } + + for i := m.lastNumGC; i < num; i++ { 
+ pause := stats.PauseNs[i%256] + m.AddSample([]string{"runtime", "gc_pause_ns"}, float32(pause)) + } + m.lastNumGC = num +} + +// Inserts a string value at an index into the slice +func insert(i int, v string, s []string) []string { + s = append(s, "") + copy(s[i+1:], s[i:]) + s[i] = v + return s +} diff --git a/vendor/github.com/armon/go-metrics/sink.go b/vendor/github.com/armon/go-metrics/sink.go new file mode 100644 index 00000000000..0b7d6e4be43 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/sink.go @@ -0,0 +1,115 @@ +package metrics + +import ( + "fmt" + "net/url" +) + +// The MetricSink interface is used to transmit metrics information +// to an external system +type MetricSink interface { + // A Gauge should retain the last value it is set to + SetGauge(key []string, val float32) + SetGaugeWithLabels(key []string, val float32, labels []Label) + + // Should emit a Key/Value pair for each call + EmitKey(key []string, val float32) + + // Counters should accumulate values + IncrCounter(key []string, val float32) + IncrCounterWithLabels(key []string, val float32, labels []Label) + + // Samples are for timing information, where quantiles are used + AddSample(key []string, val float32) + AddSampleWithLabels(key []string, val float32, labels []Label) +} + +// BlackholeSink is used to just blackhole messages +type BlackholeSink struct{} + +func (*BlackholeSink) SetGauge(key []string, val float32) {} +func (*BlackholeSink) SetGaugeWithLabels(key []string, val float32, labels []Label) {} +func (*BlackholeSink) EmitKey(key []string, val float32) {} +func (*BlackholeSink) IncrCounter(key []string, val float32) {} +func (*BlackholeSink) IncrCounterWithLabels(key []string, val float32, labels []Label) {} +func (*BlackholeSink) AddSample(key []string, val float32) {} +func (*BlackholeSink) AddSampleWithLabels(key []string, val float32, labels []Label) {} + +// FanoutSink is used to sink to fanout values to multiple sinks +type FanoutSink []MetricSink + +func (fh 
FanoutSink) SetGauge(key []string, val float32) { + fh.SetGaugeWithLabels(key, val, nil) +} + +func (fh FanoutSink) SetGaugeWithLabels(key []string, val float32, labels []Label) { + for _, s := range fh { + s.SetGaugeWithLabels(key, val, labels) + } +} + +func (fh FanoutSink) EmitKey(key []string, val float32) { + for _, s := range fh { + s.EmitKey(key, val) + } +} + +func (fh FanoutSink) IncrCounter(key []string, val float32) { + fh.IncrCounterWithLabels(key, val, nil) +} + +func (fh FanoutSink) IncrCounterWithLabels(key []string, val float32, labels []Label) { + for _, s := range fh { + s.IncrCounterWithLabels(key, val, labels) + } +} + +func (fh FanoutSink) AddSample(key []string, val float32) { + fh.AddSampleWithLabels(key, val, nil) +} + +func (fh FanoutSink) AddSampleWithLabels(key []string, val float32, labels []Label) { + for _, s := range fh { + s.AddSampleWithLabels(key, val, labels) + } +} + +// sinkURLFactoryFunc is an generic interface around the *SinkFromURL() function provided +// by each sink type +type sinkURLFactoryFunc func(*url.URL) (MetricSink, error) + +// sinkRegistry supports the generic NewMetricSink function by mapping URL +// schemes to metric sink factory functions +var sinkRegistry = map[string]sinkURLFactoryFunc{ + "statsd": NewStatsdSinkFromURL, + "statsite": NewStatsiteSinkFromURL, + "inmem": NewInmemSinkFromURL, +} + +// NewMetricSinkFromURL allows a generic URL input to configure any of the +// supported sinks. The scheme of the URL identifies the type of the sink, the +// and query parameters are used to set options. +// +// "statsd://" - Initializes a StatsdSink. The host and port are passed through +// as the "addr" of the sink +// +// "statsite://" - Initializes a StatsiteSink. The host and port become the +// "addr" of the sink +// +// "inmem://" - Initializes an InmemSink. The host and port are ignored. The +// "interval" and "duration" query parameters must be specified with valid +// durations, see NewInmemSink for details. 
+func NewMetricSinkFromURL(urlStr string) (MetricSink, error) { + u, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + + sinkURLFactoryFunc := sinkRegistry[u.Scheme] + if sinkURLFactoryFunc == nil { + return nil, fmt.Errorf( + "cannot create metric sink, unrecognized sink name: %q", u.Scheme) + } + + return sinkURLFactoryFunc(u) +} diff --git a/vendor/github.com/armon/go-metrics/start.go b/vendor/github.com/armon/go-metrics/start.go new file mode 100644 index 00000000000..32a28c48378 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/start.go @@ -0,0 +1,141 @@ +package metrics + +import ( + "os" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/go-immutable-radix" +) + +// Config is used to configure metrics settings +type Config struct { + ServiceName string // Prefixed with keys to separate services + HostName string // Hostname to use. If not provided and EnableHostname, it will be os.Hostname + EnableHostname bool // Enable prefixing gauge values with hostname + EnableHostnameLabel bool // Enable adding hostname to labels + EnableServiceLabel bool // Enable adding service to labels + EnableRuntimeMetrics bool // Enables profiling of runtime metrics (GC, Goroutines, Memory) + EnableTypePrefix bool // Prefixes key with a type ("counter", "gauge", "timer") + TimerGranularity time.Duration // Granularity of timers. + ProfileInterval time.Duration // Interval to profile runtime metrics + + AllowedPrefixes []string // A list of metric prefixes to allow, with '.' as the separator + BlockedPrefixes []string // A list of metric prefixes to block, with '.' as the separator + AllowedLabels []string // A list of metric labels to allow, with '.' as the separator + BlockedLabels []string // A list of metric labels to block, with '.' 
as the separator + FilterDefault bool // Whether to allow metrics by default +} + +// Metrics represents an instance of a metrics sink that can +// be used to emit +type Metrics struct { + Config + lastNumGC uint32 + sink MetricSink + filter *iradix.Tree + allowedLabels map[string]bool + blockedLabels map[string]bool + filterLock sync.RWMutex // Lock filters and allowedLabels/blockedLabels access +} + +// Shared global metrics instance +var globalMetrics atomic.Value // *Metrics + +func init() { + // Initialize to a blackhole sink to avoid errors + globalMetrics.Store(&Metrics{sink: &BlackholeSink{}}) +} + +// DefaultConfig provides a sane default configuration +func DefaultConfig(serviceName string) *Config { + c := &Config{ + ServiceName: serviceName, // Use client provided service + HostName: "", + EnableHostname: true, // Enable hostname prefix + EnableRuntimeMetrics: true, // Enable runtime profiling + EnableTypePrefix: false, // Disable type prefix + TimerGranularity: time.Millisecond, // Timers are in milliseconds + ProfileInterval: time.Second, // Poll runtime every second + FilterDefault: true, // Don't filter metrics by default + } + + // Try to get the hostname + name, _ := os.Hostname() + c.HostName = name + return c +} + +// New is used to create a new instance of Metrics +func New(conf *Config, sink MetricSink) (*Metrics, error) { + met := &Metrics{} + met.Config = *conf + met.sink = sink + met.UpdateFilterAndLabels(conf.AllowedPrefixes, conf.BlockedPrefixes, conf.AllowedLabels, conf.BlockedLabels) + + // Start the runtime collector + if conf.EnableRuntimeMetrics { + go met.collectStats() + } + return met, nil +} + +// NewGlobal is the same as New, but it assigns the metrics object to be +// used globally as well as returning it. 
+func NewGlobal(conf *Config, sink MetricSink) (*Metrics, error) { + metrics, err := New(conf, sink) + if err == nil { + globalMetrics.Store(metrics) + } + return metrics, err +} + +// Proxy all the methods to the globalMetrics instance +func SetGauge(key []string, val float32) { + globalMetrics.Load().(*Metrics).SetGauge(key, val) +} + +func SetGaugeWithLabels(key []string, val float32, labels []Label) { + globalMetrics.Load().(*Metrics).SetGaugeWithLabels(key, val, labels) +} + +func EmitKey(key []string, val float32) { + globalMetrics.Load().(*Metrics).EmitKey(key, val) +} + +func IncrCounter(key []string, val float32) { + globalMetrics.Load().(*Metrics).IncrCounter(key, val) +} + +func IncrCounterWithLabels(key []string, val float32, labels []Label) { + globalMetrics.Load().(*Metrics).IncrCounterWithLabels(key, val, labels) +} + +func AddSample(key []string, val float32) { + globalMetrics.Load().(*Metrics).AddSample(key, val) +} + +func AddSampleWithLabels(key []string, val float32, labels []Label) { + globalMetrics.Load().(*Metrics).AddSampleWithLabels(key, val, labels) +} + +func MeasureSince(key []string, start time.Time) { + globalMetrics.Load().(*Metrics).MeasureSince(key, start) +} + +func MeasureSinceWithLabels(key []string, start time.Time, labels []Label) { + globalMetrics.Load().(*Metrics).MeasureSinceWithLabels(key, start, labels) +} + +func UpdateFilter(allow, block []string) { + globalMetrics.Load().(*Metrics).UpdateFilter(allow, block) +} + +// UpdateFilterAndLabels set allow/block prefixes of metrics while allowedLabels +// and blockedLabels - when not nil - allow filtering of labels in order to +// block/allow globally labels (especially useful when having large number of +// values for a given label). See README.md for more information about usage. 
+func UpdateFilterAndLabels(allow, block, allowedLabels, blockedLabels []string) { + globalMetrics.Load().(*Metrics).UpdateFilterAndLabels(allow, block, allowedLabels, blockedLabels) +} diff --git a/vendor/github.com/armon/go-metrics/statsd.go b/vendor/github.com/armon/go-metrics/statsd.go new file mode 100644 index 00000000000..1bfffce46e2 --- /dev/null +++ b/vendor/github.com/armon/go-metrics/statsd.go @@ -0,0 +1,184 @@ +package metrics + +import ( + "bytes" + "fmt" + "log" + "net" + "net/url" + "strings" + "time" +) + +const ( + // statsdMaxLen is the maximum size of a packet + // to send to statsd + statsdMaxLen = 1400 +) + +// StatsdSink provides a MetricSink that can be used +// with a statsite or statsd metrics server. It uses +// only UDP packets, while StatsiteSink uses TCP. +type StatsdSink struct { + addr string + metricQueue chan string +} + +// NewStatsdSinkFromURL creates an StatsdSink from a URL. It is used +// (and tested) from NewMetricSinkFromURL. +func NewStatsdSinkFromURL(u *url.URL) (MetricSink, error) { + return NewStatsdSink(u.Host) +} + +// NewStatsdSink is used to create a new StatsdSink +func NewStatsdSink(addr string) (*StatsdSink, error) { + s := &StatsdSink{ + addr: addr, + metricQueue: make(chan string, 4096), + } + go s.flushMetrics() + return s, nil +} + +// Close is used to stop flushing to statsd +func (s *StatsdSink) Shutdown() { + close(s.metricQueue) +} + +func (s *StatsdSink) SetGauge(key []string, val float32) { + flatKey := s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val)) +} + +func (s *StatsdSink) SetGaugeWithLabels(key []string, val float32, labels []Label) { + flatKey := s.flattenKeyLabels(key, labels) + s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val)) +} + +func (s *StatsdSink) EmitKey(key []string, val float32) { + flatKey := s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|kv\n", flatKey, val)) +} + +func (s *StatsdSink) IncrCounter(key []string, val float32) { + flatKey := 
s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val)) +} + +func (s *StatsdSink) IncrCounterWithLabels(key []string, val float32, labels []Label) { + flatKey := s.flattenKeyLabels(key, labels) + s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val)) +} + +func (s *StatsdSink) AddSample(key []string, val float32) { + flatKey := s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val)) +} + +func (s *StatsdSink) AddSampleWithLabels(key []string, val float32, labels []Label) { + flatKey := s.flattenKeyLabels(key, labels) + s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val)) +} + +// Flattens the key for formatting, removes spaces +func (s *StatsdSink) flattenKey(parts []string) string { + joined := strings.Join(parts, ".") + return strings.Map(func(r rune) rune { + switch r { + case ':': + fallthrough + case ' ': + return '_' + default: + return r + } + }, joined) +} + +// Flattens the key along with labels for formatting, removes spaces +func (s *StatsdSink) flattenKeyLabels(parts []string, labels []Label) string { + for _, label := range labels { + parts = append(parts, label.Value) + } + return s.flattenKey(parts) +} + +// Does a non-blocking push to the metrics queue +func (s *StatsdSink) pushMetric(m string) { + select { + case s.metricQueue <- m: + default: + } +} + +// Flushes metrics +func (s *StatsdSink) flushMetrics() { + var sock net.Conn + var err error + var wait <-chan time.Time + ticker := time.NewTicker(flushInterval) + defer ticker.Stop() + +CONNECT: + // Create a buffer + buf := bytes.NewBuffer(nil) + + // Attempt to connect + sock, err = net.Dial("udp", s.addr) + if err != nil { + log.Printf("[ERR] Error connecting to statsd! 
Err: %s", err) + goto WAIT + } + + for { + select { + case metric, ok := <-s.metricQueue: + // Get a metric from the queue + if !ok { + goto QUIT + } + + // Check if this would overflow the packet size + if len(metric)+buf.Len() > statsdMaxLen { + _, err := sock.Write(buf.Bytes()) + buf.Reset() + if err != nil { + log.Printf("[ERR] Error writing to statsd! Err: %s", err) + goto WAIT + } + } + + // Append to the buffer + buf.WriteString(metric) + + case <-ticker.C: + if buf.Len() == 0 { + continue + } + + _, err := sock.Write(buf.Bytes()) + buf.Reset() + if err != nil { + log.Printf("[ERR] Error flushing to statsd! Err: %s", err) + goto WAIT + } + } + } + +WAIT: + // Wait for a while + wait = time.After(time.Duration(5) * time.Second) + for { + select { + // Dequeue the messages to avoid backlog + case _, ok := <-s.metricQueue: + if !ok { + goto QUIT + } + case <-wait: + goto CONNECT + } + } +QUIT: + s.metricQueue = nil +} diff --git a/vendor/github.com/armon/go-metrics/statsite.go b/vendor/github.com/armon/go-metrics/statsite.go new file mode 100644 index 00000000000..6c0d284d2dd --- /dev/null +++ b/vendor/github.com/armon/go-metrics/statsite.go @@ -0,0 +1,172 @@ +package metrics + +import ( + "bufio" + "fmt" + "log" + "net" + "net/url" + "strings" + "time" +) + +const ( + // We force flush the statsite metrics after this period of + // inactivity. Prevents stats from getting stuck in a buffer + // forever. + flushInterval = 100 * time.Millisecond +) + +// NewStatsiteSinkFromURL creates an StatsiteSink from a URL. It is used +// (and tested) from NewMetricSinkFromURL. 
+func NewStatsiteSinkFromURL(u *url.URL) (MetricSink, error) { + return NewStatsiteSink(u.Host) +} + +// StatsiteSink provides a MetricSink that can be used with a +// statsite metrics server +type StatsiteSink struct { + addr string + metricQueue chan string +} + +// NewStatsiteSink is used to create a new StatsiteSink +func NewStatsiteSink(addr string) (*StatsiteSink, error) { + s := &StatsiteSink{ + addr: addr, + metricQueue: make(chan string, 4096), + } + go s.flushMetrics() + return s, nil +} + +// Close is used to stop flushing to statsite +func (s *StatsiteSink) Shutdown() { + close(s.metricQueue) +} + +func (s *StatsiteSink) SetGauge(key []string, val float32) { + flatKey := s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val)) +} + +func (s *StatsiteSink) SetGaugeWithLabels(key []string, val float32, labels []Label) { + flatKey := s.flattenKeyLabels(key, labels) + s.pushMetric(fmt.Sprintf("%s:%f|g\n", flatKey, val)) +} + +func (s *StatsiteSink) EmitKey(key []string, val float32) { + flatKey := s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|kv\n", flatKey, val)) +} + +func (s *StatsiteSink) IncrCounter(key []string, val float32) { + flatKey := s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val)) +} + +func (s *StatsiteSink) IncrCounterWithLabels(key []string, val float32, labels []Label) { + flatKey := s.flattenKeyLabels(key, labels) + s.pushMetric(fmt.Sprintf("%s:%f|c\n", flatKey, val)) +} + +func (s *StatsiteSink) AddSample(key []string, val float32) { + flatKey := s.flattenKey(key) + s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val)) +} + +func (s *StatsiteSink) AddSampleWithLabels(key []string, val float32, labels []Label) { + flatKey := s.flattenKeyLabels(key, labels) + s.pushMetric(fmt.Sprintf("%s:%f|ms\n", flatKey, val)) +} + +// Flattens the key for formatting, removes spaces +func (s *StatsiteSink) flattenKey(parts []string) string { + joined := strings.Join(parts, ".") + return strings.Map(func(r 
rune) rune { + switch r { + case ':': + fallthrough + case ' ': + return '_' + default: + return r + } + }, joined) +} + +// Flattens the key along with labels for formatting, removes spaces +func (s *StatsiteSink) flattenKeyLabels(parts []string, labels []Label) string { + for _, label := range labels { + parts = append(parts, label.Value) + } + return s.flattenKey(parts) +} + +// Does a non-blocking push to the metrics queue +func (s *StatsiteSink) pushMetric(m string) { + select { + case s.metricQueue <- m: + default: + } +} + +// Flushes metrics +func (s *StatsiteSink) flushMetrics() { + var sock net.Conn + var err error + var wait <-chan time.Time + var buffered *bufio.Writer + ticker := time.NewTicker(flushInterval) + defer ticker.Stop() + +CONNECT: + // Attempt to connect + sock, err = net.Dial("tcp", s.addr) + if err != nil { + log.Printf("[ERR] Error connecting to statsite! Err: %s", err) + goto WAIT + } + + // Create a buffered writer + buffered = bufio.NewWriter(sock) + + for { + select { + case metric, ok := <-s.metricQueue: + // Get a metric from the queue + if !ok { + goto QUIT + } + + // Try to send to statsite + _, err := buffered.Write([]byte(metric)) + if err != nil { + log.Printf("[ERR] Error writing to statsite! Err: %s", err) + goto WAIT + } + case <-ticker.C: + if err := buffered.Flush(); err != nil { + log.Printf("[ERR] Error flushing to statsite! 
Err: %s", err) + goto WAIT + } + } + } + +WAIT: + // Wait for a while + wait = time.After(time.Duration(5) * time.Second) + for { + select { + // Dequeue the messages to avoid backlog + case _, ok := <-s.metricQueue: + if !ok { + goto QUIT + } + case <-wait: + goto CONNECT + } + } +QUIT: + s.metricQueue = nil +} diff --git a/vendor/github.com/boltdb/bolt/LICENSE b/vendor/github.com/boltdb/bolt/LICENSE new file mode 100644 index 00000000000..004e77fe5d2 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013 Ben Johnson + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/boltdb/bolt/bolt_386.go b/vendor/github.com/boltdb/bolt/bolt_386.go new file mode 100644 index 00000000000..820d533c15f --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_386.go @@ -0,0 +1,10 @@ +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. 
+const maxMapSize = 0x7FFFFFFF // 2GB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0xFFFFFFF + +// Are unaligned load/stores broken on this arch? +var brokenUnaligned = false diff --git a/vendor/github.com/boltdb/bolt/bolt_amd64.go b/vendor/github.com/boltdb/bolt/bolt_amd64.go new file mode 100644 index 00000000000..98fafdb47d8 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_amd64.go @@ -0,0 +1,10 @@ +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF + +// Are unaligned load/stores broken on this arch? +var brokenUnaligned = false diff --git a/vendor/github.com/boltdb/bolt/bolt_arm.go b/vendor/github.com/boltdb/bolt/bolt_arm.go new file mode 100644 index 00000000000..7e5cb4b9412 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_arm.go @@ -0,0 +1,28 @@ +package bolt + +import "unsafe" + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0x7FFFFFFF // 2GB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0xFFFFFFF + +// Are unaligned load/stores broken on this arch? +var brokenUnaligned bool + +func init() { + // Simple check to see whether this arch handles unaligned load/stores + // correctly. + + // ARM9 and older devices require load/stores to be from/to aligned + // addresses. If not, the lower 2 bits are cleared and that address is + // read in a jumbled up order. 
+ + // See http://infocenter.arm.com/help/index.jsp?topic=/com.arm.doc.faqs/ka15414.html + + raw := [6]byte{0xfe, 0xef, 0x11, 0x22, 0x22, 0x11} + val := *(*uint32)(unsafe.Pointer(uintptr(unsafe.Pointer(&raw)) + 2)) + + brokenUnaligned = val != 0x11222211 +} diff --git a/vendor/github.com/boltdb/bolt/bolt_arm64.go b/vendor/github.com/boltdb/bolt/bolt_arm64.go new file mode 100644 index 00000000000..b26d84f91ba --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_arm64.go @@ -0,0 +1,12 @@ +// +build arm64 + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF + +// Are unaligned load/stores broken on this arch? +var brokenUnaligned = false diff --git a/vendor/github.com/boltdb/bolt/bolt_linux.go b/vendor/github.com/boltdb/bolt/bolt_linux.go new file mode 100644 index 00000000000..2b676661409 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_linux.go @@ -0,0 +1,10 @@ +package bolt + +import ( + "syscall" +) + +// fdatasync flushes written data to a file descriptor. 
+func fdatasync(db *DB) error { + return syscall.Fdatasync(int(db.file.Fd())) +} diff --git a/vendor/github.com/boltdb/bolt/bolt_openbsd.go b/vendor/github.com/boltdb/bolt/bolt_openbsd.go new file mode 100644 index 00000000000..7058c3d734e --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_openbsd.go @@ -0,0 +1,27 @@ +package bolt + +import ( + "syscall" + "unsafe" +) + +const ( + msAsync = 1 << iota // perform asynchronous writes + msSync // perform synchronous writes + msInvalidate // invalidate cached data +) + +func msync(db *DB) error { + _, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(unsafe.Pointer(db.data)), uintptr(db.datasz), msInvalidate) + if errno != 0 { + return errno + } + return nil +} + +func fdatasync(db *DB) error { + if db.data != nil { + return msync(db) + } + return db.file.Sync() +} diff --git a/vendor/github.com/boltdb/bolt/bolt_ppc.go b/vendor/github.com/boltdb/bolt/bolt_ppc.go new file mode 100644 index 00000000000..645ddc3edc2 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_ppc.go @@ -0,0 +1,9 @@ +// +build ppc + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0x7FFFFFFF // 2GB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0xFFFFFFF diff --git a/vendor/github.com/boltdb/bolt/bolt_ppc64.go b/vendor/github.com/boltdb/bolt/bolt_ppc64.go new file mode 100644 index 00000000000..9331d9771eb --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_ppc64.go @@ -0,0 +1,12 @@ +// +build ppc64 + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF + +// Are unaligned load/stores broken on this arch? 
+var brokenUnaligned = false diff --git a/vendor/github.com/boltdb/bolt/bolt_ppc64le.go b/vendor/github.com/boltdb/bolt/bolt_ppc64le.go new file mode 100644 index 00000000000..8c143bc5d19 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_ppc64le.go @@ -0,0 +1,12 @@ +// +build ppc64le + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF + +// Are unaligned load/stores broken on this arch? +var brokenUnaligned = false diff --git a/vendor/github.com/boltdb/bolt/bolt_s390x.go b/vendor/github.com/boltdb/bolt/bolt_s390x.go new file mode 100644 index 00000000000..d7c39af9253 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_s390x.go @@ -0,0 +1,12 @@ +// +build s390x + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF + +// Are unaligned load/stores broken on this arch? +var brokenUnaligned = false diff --git a/vendor/github.com/boltdb/bolt/bolt_unix.go b/vendor/github.com/boltdb/bolt/bolt_unix.go new file mode 100644 index 00000000000..cad62dda1e3 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_unix.go @@ -0,0 +1,89 @@ +// +build !windows,!plan9,!solaris + +package bolt + +import ( + "fmt" + "os" + "syscall" + "time" + "unsafe" +) + +// flock acquires an advisory lock on a file descriptor. +func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error { + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. 
+ if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + flag := syscall.LOCK_SH + if exclusive { + flag = syscall.LOCK_EX + } + + // Otherwise attempt to obtain an exclusive lock. + err := syscall.Flock(int(db.file.Fd()), flag|syscall.LOCK_NB) + if err == nil { + return nil + } else if err != syscall.EWOULDBLOCK { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } +} + +// funlock releases an advisory lock on a file descriptor. +func funlock(db *DB) error { + return syscall.Flock(int(db.file.Fd()), syscall.LOCK_UN) +} + +// mmap memory maps a DB's data file. +func mmap(db *DB, sz int) error { + // Map the data file to memory. + b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags) + if err != nil { + return err + } + + // Advise the kernel that the mmap is accessed randomly. + if err := madvise(b, syscall.MADV_RANDOM); err != nil { + return fmt.Errorf("madvise: %s", err) + } + + // Save the original byte slice and convert to a byte array pointer. + db.dataref = b + db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0])) + db.datasz = sz + return nil +} + +// munmap unmaps a DB's data file from memory. +func munmap(db *DB) error { + // Ignore the unmap if we have no mapped data. + if db.dataref == nil { + return nil + } + + // Unmap using the original byte slice. + err := syscall.Munmap(db.dataref) + db.dataref = nil + db.data = nil + db.datasz = 0 + return err +} + +// NOTE: This function is copied from stdlib because it is not available on darwin. 
+func madvise(b []byte, advice int) (err error) { + _, _, e1 := syscall.Syscall(syscall.SYS_MADVISE, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(advice)) + if e1 != 0 { + err = e1 + } + return +} diff --git a/vendor/github.com/boltdb/bolt/bolt_unix_solaris.go b/vendor/github.com/boltdb/bolt/bolt_unix_solaris.go new file mode 100644 index 00000000000..307bf2b3ee9 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_unix_solaris.go @@ -0,0 +1,90 @@ +package bolt + +import ( + "fmt" + "os" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/unix" +) + +// flock acquires an advisory lock on a file descriptor. +func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error { + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. + if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + var lock syscall.Flock_t + lock.Start = 0 + lock.Len = 0 + lock.Pid = 0 + lock.Whence = 0 + lock.Pid = 0 + if exclusive { + lock.Type = syscall.F_WRLCK + } else { + lock.Type = syscall.F_RDLCK + } + err := syscall.FcntlFlock(db.file.Fd(), syscall.F_SETLK, &lock) + if err == nil { + return nil + } else if err != syscall.EAGAIN { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } +} + +// funlock releases an advisory lock on a file descriptor. +func funlock(db *DB) error { + var lock syscall.Flock_t + lock.Start = 0 + lock.Len = 0 + lock.Type = syscall.F_UNLCK + lock.Whence = 0 + return syscall.FcntlFlock(uintptr(db.file.Fd()), syscall.F_SETLK, &lock) +} + +// mmap memory maps a DB's data file. +func mmap(db *DB, sz int) error { + // Map the data file to memory. + b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags) + if err != nil { + return err + } + + // Advise the kernel that the mmap is accessed randomly. 
+ if err := unix.Madvise(b, syscall.MADV_RANDOM); err != nil { + return fmt.Errorf("madvise: %s", err) + } + + // Save the original byte slice and convert to a byte array pointer. + db.dataref = b + db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0])) + db.datasz = sz + return nil +} + +// munmap unmaps a DB's data file from memory. +func munmap(db *DB) error { + // Ignore the unmap if we have no mapped data. + if db.dataref == nil { + return nil + } + + // Unmap using the original byte slice. + err := unix.Munmap(db.dataref) + db.dataref = nil + db.data = nil + db.datasz = 0 + return err +} diff --git a/vendor/github.com/boltdb/bolt/bolt_windows.go b/vendor/github.com/boltdb/bolt/bolt_windows.go new file mode 100644 index 00000000000..b00fb0720a4 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bolt_windows.go @@ -0,0 +1,144 @@ +package bolt + +import ( + "fmt" + "os" + "syscall" + "time" + "unsafe" +) + +// LockFileEx code derived from golang build filemutex_windows.go @ v1.5.1 +var ( + modkernel32 = syscall.NewLazyDLL("kernel32.dll") + procLockFileEx = modkernel32.NewProc("LockFileEx") + procUnlockFileEx = modkernel32.NewProc("UnlockFileEx") +) + +const ( + lockExt = ".lock" + + // see https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx + flagLockExclusive = 2 + flagLockFailImmediately = 1 + + // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx + errLockViolation syscall.Errno = 0x21 +) + +func lockFileEx(h syscall.Handle, flags, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { + r, _, err := procLockFileEx.Call(uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol))) + if r == 0 { + return err + } + return nil +} + +func unlockFileEx(h syscall.Handle, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { + r, _, err := procUnlockFileEx.Call(uintptr(h), uintptr(reserved), uintptr(locklow), 
uintptr(lockhigh), uintptr(unsafe.Pointer(ol)), 0) + if r == 0 { + return err + } + return nil +} + +// fdatasync flushes written data to a file descriptor. +func fdatasync(db *DB) error { + return db.file.Sync() +} + +// flock acquires an advisory lock on a file descriptor. +func flock(db *DB, mode os.FileMode, exclusive bool, timeout time.Duration) error { + // Create a separate lock file on windows because a process + // cannot share an exclusive lock on the same file. This is + // needed during Tx.WriteTo(). + f, err := os.OpenFile(db.path+lockExt, os.O_CREATE, mode) + if err != nil { + return err + } + db.lockfile = f + + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. + if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + + var flag uint32 = flagLockFailImmediately + if exclusive { + flag |= flagLockExclusive + } + + err := lockFileEx(syscall.Handle(db.lockfile.Fd()), flag, 0, 1, 0, &syscall.Overlapped{}) + if err == nil { + return nil + } else if err != errLockViolation { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } +} + +// funlock releases an advisory lock on a file descriptor. +func funlock(db *DB) error { + err := unlockFileEx(syscall.Handle(db.lockfile.Fd()), 0, 1, 0, &syscall.Overlapped{}) + db.lockfile.Close() + os.Remove(db.path + lockExt) + return err +} + +// mmap memory maps a DB's data file. +// Based on: https://github.com/edsrzf/mmap-go +func mmap(db *DB, sz int) error { + if !db.readOnly { + // Truncate the database to the size of the mmap. + if err := db.file.Truncate(int64(sz)); err != nil { + return fmt.Errorf("truncate: %s", err) + } + } + + // Open a file mapping handle. 
+ sizelo := uint32(sz >> 32) + sizehi := uint32(sz) & 0xffffffff + h, errno := syscall.CreateFileMapping(syscall.Handle(db.file.Fd()), nil, syscall.PAGE_READONLY, sizelo, sizehi, nil) + if h == 0 { + return os.NewSyscallError("CreateFileMapping", errno) + } + + // Create the memory map. + addr, errno := syscall.MapViewOfFile(h, syscall.FILE_MAP_READ, 0, 0, uintptr(sz)) + if addr == 0 { + return os.NewSyscallError("MapViewOfFile", errno) + } + + // Close mapping handle. + if err := syscall.CloseHandle(syscall.Handle(h)); err != nil { + return os.NewSyscallError("CloseHandle", err) + } + + // Convert to a byte array. + db.data = ((*[maxMapSize]byte)(unsafe.Pointer(addr))) + db.datasz = sz + + return nil +} + +// munmap unmaps a pointer from a file. +// Based on: https://github.com/edsrzf/mmap-go +func munmap(db *DB) error { + if db.data == nil { + return nil + } + + addr := (uintptr)(unsafe.Pointer(&db.data[0])) + if err := syscall.UnmapViewOfFile(addr); err != nil { + return os.NewSyscallError("UnmapViewOfFile", err) + } + return nil +} diff --git a/vendor/github.com/boltdb/bolt/boltsync_unix.go b/vendor/github.com/boltdb/bolt/boltsync_unix.go new file mode 100644 index 00000000000..f50442523c3 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/boltsync_unix.go @@ -0,0 +1,8 @@ +// +build !windows,!plan9,!linux,!openbsd + +package bolt + +// fdatasync flushes written data to a file descriptor. +func fdatasync(db *DB) error { + return db.file.Sync() +} diff --git a/vendor/github.com/boltdb/bolt/bucket.go b/vendor/github.com/boltdb/bolt/bucket.go new file mode 100644 index 00000000000..0c5bf27463e --- /dev/null +++ b/vendor/github.com/boltdb/bolt/bucket.go @@ -0,0 +1,777 @@ +package bolt + +import ( + "bytes" + "fmt" + "unsafe" +) + +const ( + // MaxKeySize is the maximum length of a key, in bytes. + MaxKeySize = 32768 + + // MaxValueSize is the maximum length of a value, in bytes. 
+ MaxValueSize = (1 << 31) - 2 +) + +const ( + maxUint = ^uint(0) + minUint = 0 + maxInt = int(^uint(0) >> 1) + minInt = -maxInt - 1 +) + +const bucketHeaderSize = int(unsafe.Sizeof(bucket{})) + +const ( + minFillPercent = 0.1 + maxFillPercent = 1.0 +) + +// DefaultFillPercent is the percentage that split pages are filled. +// This value can be changed by setting Bucket.FillPercent. +const DefaultFillPercent = 0.5 + +// Bucket represents a collection of key/value pairs inside the database. +type Bucket struct { + *bucket + tx *Tx // the associated transaction + buckets map[string]*Bucket // subbucket cache + page *page // inline page reference + rootNode *node // materialized node for the root page. + nodes map[pgid]*node // node cache + + // Sets the threshold for filling nodes when they split. By default, + // the bucket will fill to 50% but it can be useful to increase this + // amount if you know that your write workloads are mostly append-only. + // + // This is non-persisted across transactions so it must be set in every Tx. + FillPercent float64 +} + +// bucket represents the on-file representation of a bucket. +// This is stored as the "value" of a bucket key. If the bucket is small enough, +// then its root page can be stored inline in the "value", after the bucket +// header. In the case of inline buckets, the "root" will be 0. +type bucket struct { + root pgid // page id of the bucket's root-level page + sequence uint64 // monotonically incrementing, used by NextSequence() +} + +// newBucket returns a new bucket associated with a transaction. +func newBucket(tx *Tx) Bucket { + var b = Bucket{tx: tx, FillPercent: DefaultFillPercent} + if tx.writable { + b.buckets = make(map[string]*Bucket) + b.nodes = make(map[pgid]*node) + } + return b +} + +// Tx returns the tx of the bucket. +func (b *Bucket) Tx() *Tx { + return b.tx +} + +// Root returns the root of the bucket. 
+func (b *Bucket) Root() pgid { + return b.root +} + +// Writable returns whether the bucket is writable. +func (b *Bucket) Writable() bool { + return b.tx.writable +} + +// Cursor creates a cursor associated with the bucket. +// The cursor is only valid as long as the transaction is open. +// Do not use a cursor after the transaction is closed. +func (b *Bucket) Cursor() *Cursor { + // Update transaction statistics. + b.tx.stats.CursorCount++ + + // Allocate and return a cursor. + return &Cursor{ + bucket: b, + stack: make([]elemRef, 0), + } +} + +// Bucket retrieves a nested bucket by name. +// Returns nil if the bucket does not exist. +// The bucket instance is only valid for the lifetime of the transaction. +func (b *Bucket) Bucket(name []byte) *Bucket { + if b.buckets != nil { + if child := b.buckets[string(name)]; child != nil { + return child + } + } + + // Move cursor to key. + c := b.Cursor() + k, v, flags := c.seek(name) + + // Return nil if the key doesn't exist or it is not a bucket. + if !bytes.Equal(name, k) || (flags&bucketLeafFlag) == 0 { + return nil + } + + // Otherwise create a bucket and cache it. + var child = b.openBucket(v) + if b.buckets != nil { + b.buckets[string(name)] = child + } + + return child +} + +// Helper method that re-interprets a sub-bucket value +// from a parent into a Bucket +func (b *Bucket) openBucket(value []byte) *Bucket { + var child = newBucket(b.tx) + + // If unaligned load/stores are broken on this arch and value is + // unaligned simply clone to an aligned byte array. + unaligned := brokenUnaligned && uintptr(unsafe.Pointer(&value[0]))&3 != 0 + + if unaligned { + value = cloneBytes(value) + } + + // If this is a writable transaction then we need to copy the bucket entry. + // Read-only transactions can point directly at the mmap entry. 
+ if b.tx.writable && !unaligned { + child.bucket = &bucket{} + *child.bucket = *(*bucket)(unsafe.Pointer(&value[0])) + } else { + child.bucket = (*bucket)(unsafe.Pointer(&value[0])) + } + + // Save a reference to the inline page if the bucket is inline. + if child.root == 0 { + child.page = (*page)(unsafe.Pointer(&value[bucketHeaderSize])) + } + + return &child +} + +// CreateBucket creates a new bucket at the given key and returns the new bucket. +// Returns an error if the key already exists, if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. +func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) { + if b.tx.db == nil { + return nil, ErrTxClosed + } else if !b.tx.writable { + return nil, ErrTxNotWritable + } else if len(key) == 0 { + return nil, ErrBucketNameRequired + } + + // Move cursor to correct position. + c := b.Cursor() + k, _, flags := c.seek(key) + + // Return an error if there is an existing key. + if bytes.Equal(key, k) { + if (flags & bucketLeafFlag) != 0 { + return nil, ErrBucketExists + } + return nil, ErrIncompatibleValue + } + + // Create empty, inline bucket. + var bucket = Bucket{ + bucket: &bucket{}, + rootNode: &node{isLeaf: true}, + FillPercent: DefaultFillPercent, + } + var value = bucket.write() + + // Insert into node. + key = cloneBytes(key) + c.node().put(key, key, value, 0, bucketLeafFlag) + + // Since subbuckets are not allowed on inline buckets, we need to + // dereference the inline page, if it exists. This will cause the bucket + // to be treated as a regular, non-inline bucket for the rest of the tx. + b.page = nil + + return b.Bucket(key), nil +} + +// CreateBucketIfNotExists creates a new bucket if it doesn't already exist and returns a reference to it. +// Returns an error if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. 
+func (b *Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error) { + child, err := b.CreateBucket(key) + if err == ErrBucketExists { + return b.Bucket(key), nil + } else if err != nil { + return nil, err + } + return child, nil +} + +// DeleteBucket deletes a bucket at the given key. +// Returns an error if the bucket does not exists, or if the key represents a non-bucket value. +func (b *Bucket) DeleteBucket(key []byte) error { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { + return ErrTxNotWritable + } + + // Move cursor to correct position. + c := b.Cursor() + k, _, flags := c.seek(key) + + // Return an error if bucket doesn't exist or is not a bucket. + if !bytes.Equal(key, k) { + return ErrBucketNotFound + } else if (flags & bucketLeafFlag) == 0 { + return ErrIncompatibleValue + } + + // Recursively delete all child buckets. + child := b.Bucket(key) + err := child.ForEach(func(k, v []byte) error { + if v == nil { + if err := child.DeleteBucket(k); err != nil { + return fmt.Errorf("delete bucket: %s", err) + } + } + return nil + }) + if err != nil { + return err + } + + // Remove cached copy. + delete(b.buckets, string(key)) + + // Release all bucket pages to freelist. + child.nodes = nil + child.rootNode = nil + child.free() + + // Delete the node if we have a matching key. + c.node().del(key) + + return nil +} + +// Get retrieves the value for a key in the bucket. +// Returns a nil value if the key does not exist or if the key is a nested bucket. +// The returned value is only valid for the life of the transaction. +func (b *Bucket) Get(key []byte) []byte { + k, v, flags := b.Cursor().seek(key) + + // Return nil if this is a bucket. + if (flags & bucketLeafFlag) != 0 { + return nil + } + + // If our target node isn't the same key as what's passed in then return nil. + if !bytes.Equal(key, k) { + return nil + } + return v +} + +// Put sets the value for a key in the bucket. 
+// If the key exist then its previous value will be overwritten. +// Supplied value must remain valid for the life of the transaction. +// Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large. +func (b *Bucket) Put(key []byte, value []byte) error { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { + return ErrTxNotWritable + } else if len(key) == 0 { + return ErrKeyRequired + } else if len(key) > MaxKeySize { + return ErrKeyTooLarge + } else if int64(len(value)) > MaxValueSize { + return ErrValueTooLarge + } + + // Move cursor to correct position. + c := b.Cursor() + k, _, flags := c.seek(key) + + // Return an error if there is an existing key with a bucket value. + if bytes.Equal(key, k) && (flags&bucketLeafFlag) != 0 { + return ErrIncompatibleValue + } + + // Insert into node. + key = cloneBytes(key) + c.node().put(key, key, value, 0, 0) + + return nil +} + +// Delete removes a key from the bucket. +// If the key does not exist then nothing is done and a nil error is returned. +// Returns an error if the bucket was created from a read-only transaction. +func (b *Bucket) Delete(key []byte) error { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { + return ErrTxNotWritable + } + + // Move cursor to correct position. + c := b.Cursor() + _, _, flags := c.seek(key) + + // Return an error if there is already existing bucket value. + if (flags & bucketLeafFlag) != 0 { + return ErrIncompatibleValue + } + + // Delete the node if we have a matching key. + c.node().del(key) + + return nil +} + +// Sequence returns the current integer for the bucket without incrementing it. +func (b *Bucket) Sequence() uint64 { return b.bucket.sequence } + +// SetSequence updates the sequence number for the bucket. 
+func (b *Bucket) SetSequence(v uint64) error { + if b.tx.db == nil { + return ErrTxClosed + } else if !b.Writable() { + return ErrTxNotWritable + } + + // Materialize the root node if it hasn't been already so that the + // bucket will be saved during commit. + if b.rootNode == nil { + _ = b.node(b.root, nil) + } + + // Increment and return the sequence. + b.bucket.sequence = v + return nil +} + +// NextSequence returns an autoincrementing integer for the bucket. +func (b *Bucket) NextSequence() (uint64, error) { + if b.tx.db == nil { + return 0, ErrTxClosed + } else if !b.Writable() { + return 0, ErrTxNotWritable + } + + // Materialize the root node if it hasn't been already so that the + // bucket will be saved during commit. + if b.rootNode == nil { + _ = b.node(b.root, nil) + } + + // Increment and return the sequence. + b.bucket.sequence++ + return b.bucket.sequence, nil +} + +// ForEach executes a function for each key/value pair in a bucket. +// If the provided function returns an error then the iteration is stopped and +// the error is returned to the caller. The provided function must not modify +// the bucket; this will result in undefined behavior. +func (b *Bucket) ForEach(fn func(k, v []byte) error) error { + if b.tx.db == nil { + return ErrTxClosed + } + c := b.Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if err := fn(k, v); err != nil { + return err + } + } + return nil +} + +// Stat returns stats on a bucket. +func (b *Bucket) Stats() BucketStats { + var s, subStats BucketStats + pageSize := b.tx.db.pageSize + s.BucketN += 1 + if b.root == 0 { + s.InlineBucketN += 1 + } + b.forEachPage(func(p *page, depth int) { + if (p.flags & leafPageFlag) != 0 { + s.KeyN += int(p.count) + + // used totals the used bytes for the page + used := pageHeaderSize + + if p.count != 0 { + // If page has any elements, add all element headers. + used += leafPageElementSize * int(p.count-1) + + // Add all element key, value sizes. 
+ // The computation takes advantage of the fact that the position + // of the last element's key/value equals to the total of the sizes + // of all previous elements' keys and values. + // It also includes the last element's header. + lastElement := p.leafPageElement(p.count - 1) + used += int(lastElement.pos + lastElement.ksize + lastElement.vsize) + } + + if b.root == 0 { + // For inlined bucket just update the inline stats + s.InlineBucketInuse += used + } else { + // For non-inlined bucket update all the leaf stats + s.LeafPageN++ + s.LeafInuse += used + s.LeafOverflowN += int(p.overflow) + + // Collect stats from sub-buckets. + // Do that by iterating over all element headers + // looking for the ones with the bucketLeafFlag. + for i := uint16(0); i < p.count; i++ { + e := p.leafPageElement(i) + if (e.flags & bucketLeafFlag) != 0 { + // For any bucket element, open the element value + // and recursively call Stats on the contained bucket. + subStats.Add(b.openBucket(e.value()).Stats()) + } + } + } + } else if (p.flags & branchPageFlag) != 0 { + s.BranchPageN++ + lastElement := p.branchPageElement(p.count - 1) + + // used totals the used bytes for the page + // Add header and all element headers. + used := pageHeaderSize + (branchPageElementSize * int(p.count-1)) + + // Add size of all keys and values. + // Again, use the fact that last element's position equals to + // the total of key, value sizes of all previous elements. + used += int(lastElement.pos + lastElement.ksize) + s.BranchInuse += used + s.BranchOverflowN += int(p.overflow) + } + + // Keep track of maximum page depth. + if depth+1 > s.Depth { + s.Depth = (depth + 1) + } + }) + + // Alloc stats can be computed from page counts and pageSize. + s.BranchAlloc = (s.BranchPageN + s.BranchOverflowN) * pageSize + s.LeafAlloc = (s.LeafPageN + s.LeafOverflowN) * pageSize + + // Add the max depth of sub-buckets to get total nested depth. 
+ s.Depth += subStats.Depth + // Add the stats for all sub-buckets + s.Add(subStats) + return s +} + +// forEachPage iterates over every page in a bucket, including inline pages. +func (b *Bucket) forEachPage(fn func(*page, int)) { + // If we have an inline page then just use that. + if b.page != nil { + fn(b.page, 0) + return + } + + // Otherwise traverse the page hierarchy. + b.tx.forEachPage(b.root, 0, fn) +} + +// forEachPageNode iterates over every page (or node) in a bucket. +// This also includes inline pages. +func (b *Bucket) forEachPageNode(fn func(*page, *node, int)) { + // If we have an inline page or root node then just use that. + if b.page != nil { + fn(b.page, nil, 0) + return + } + b._forEachPageNode(b.root, 0, fn) +} + +func (b *Bucket) _forEachPageNode(pgid pgid, depth int, fn func(*page, *node, int)) { + var p, n = b.pageNode(pgid) + + // Execute function. + fn(p, n, depth) + + // Recursively loop over children. + if p != nil { + if (p.flags & branchPageFlag) != 0 { + for i := 0; i < int(p.count); i++ { + elem := p.branchPageElement(uint16(i)) + b._forEachPageNode(elem.pgid, depth+1, fn) + } + } + } else { + if !n.isLeaf { + for _, inode := range n.inodes { + b._forEachPageNode(inode.pgid, depth+1, fn) + } + } + } +} + +// spill writes all the nodes for this bucket to dirty pages. +func (b *Bucket) spill() error { + // Spill all child buckets first. + for name, child := range b.buckets { + // If the child bucket is small enough and it has no child buckets then + // write it inline into the parent bucket's page. Otherwise spill it + // like a normal bucket and make the parent value a pointer to the page. + var value []byte + if child.inlineable() { + child.free() + value = child.write() + } else { + if err := child.spill(); err != nil { + return err + } + + // Update the child bucket header in this bucket. 
+ value = make([]byte, unsafe.Sizeof(bucket{})) + var bucket = (*bucket)(unsafe.Pointer(&value[0])) + *bucket = *child.bucket + } + + // Skip writing the bucket if there are no materialized nodes. + if child.rootNode == nil { + continue + } + + // Update parent node. + var c = b.Cursor() + k, _, flags := c.seek([]byte(name)) + if !bytes.Equal([]byte(name), k) { + panic(fmt.Sprintf("misplaced bucket header: %x -> %x", []byte(name), k)) + } + if flags&bucketLeafFlag == 0 { + panic(fmt.Sprintf("unexpected bucket header flag: %x", flags)) + } + c.node().put([]byte(name), []byte(name), value, 0, bucketLeafFlag) + } + + // Ignore if there's not a materialized root node. + if b.rootNode == nil { + return nil + } + + // Spill nodes. + if err := b.rootNode.spill(); err != nil { + return err + } + b.rootNode = b.rootNode.root() + + // Update the root node for this bucket. + if b.rootNode.pgid >= b.tx.meta.pgid { + panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", b.rootNode.pgid, b.tx.meta.pgid)) + } + b.root = b.rootNode.pgid + + return nil +} + +// inlineable returns true if a bucket is small enough to be written inline +// and if it contains no subbuckets. Otherwise returns false. +func (b *Bucket) inlineable() bool { + var n = b.rootNode + + // Bucket must only contain a single leaf node. + if n == nil || !n.isLeaf { + return false + } + + // Bucket is not inlineable if it contains subbuckets or if it goes beyond + // our threshold for inline bucket size. + var size = pageHeaderSize + for _, inode := range n.inodes { + size += leafPageElementSize + len(inode.key) + len(inode.value) + + if inode.flags&bucketLeafFlag != 0 { + return false + } else if size > b.maxInlineBucketSize() { + return false + } + } + + return true +} + +// Returns the maximum total size of a bucket to make it a candidate for inlining. +func (b *Bucket) maxInlineBucketSize() int { + return b.tx.db.pageSize / 4 +} + +// write allocates and writes a bucket to a byte slice. 
+func (b *Bucket) write() []byte { + // Allocate the appropriate size. + var n = b.rootNode + var value = make([]byte, bucketHeaderSize+n.size()) + + // Write a bucket header. + var bucket = (*bucket)(unsafe.Pointer(&value[0])) + *bucket = *b.bucket + + // Convert byte slice to a fake page and write the root node. + var p = (*page)(unsafe.Pointer(&value[bucketHeaderSize])) + n.write(p) + + return value +} + +// rebalance attempts to balance all nodes. +func (b *Bucket) rebalance() { + for _, n := range b.nodes { + n.rebalance() + } + for _, child := range b.buckets { + child.rebalance() + } +} + +// node creates a node from a page and associates it with a given parent. +func (b *Bucket) node(pgid pgid, parent *node) *node { + _assert(b.nodes != nil, "nodes map expected") + + // Retrieve node if it's already been created. + if n := b.nodes[pgid]; n != nil { + return n + } + + // Otherwise create a node and cache it. + n := &node{bucket: b, parent: parent} + if parent == nil { + b.rootNode = n + } else { + parent.children = append(parent.children, n) + } + + // Use the inline page if this is an inline bucket. + var p = b.page + if p == nil { + p = b.tx.page(pgid) + } + + // Read the page into the node and cache it. + n.read(p) + b.nodes[pgid] = n + + // Update statistics. + b.tx.stats.NodeCount++ + + return n +} + +// free recursively frees all pages in the bucket. +func (b *Bucket) free() { + if b.root == 0 { + return + } + + var tx = b.tx + b.forEachPageNode(func(p *page, n *node, _ int) { + if p != nil { + tx.db.freelist.free(tx.meta.txid, p) + } else { + n.free() + } + }) + b.root = 0 +} + +// dereference removes all references to the old mmap. +func (b *Bucket) dereference() { + if b.rootNode != nil { + b.rootNode.root().dereference() + } + + for _, child := range b.buckets { + child.dereference() + } +} + +// pageNode returns the in-memory node, if it exists. +// Otherwise returns the underlying page. 
+func (b *Bucket) pageNode(id pgid) (*page, *node) { + // Inline buckets have a fake page embedded in their value so treat them + // differently. We'll return the rootNode (if available) or the fake page. + if b.root == 0 { + if id != 0 { + panic(fmt.Sprintf("inline bucket non-zero page access(2): %d != 0", id)) + } + if b.rootNode != nil { + return nil, b.rootNode + } + return b.page, nil + } + + // Check the node cache for non-inline buckets. + if b.nodes != nil { + if n := b.nodes[id]; n != nil { + return nil, n + } + } + + // Finally lookup the page from the transaction if no node is materialized. + return b.tx.page(id), nil +} + +// BucketStats records statistics about resources used by a bucket. +type BucketStats struct { + // Page count statistics. + BranchPageN int // number of logical branch pages + BranchOverflowN int // number of physical branch overflow pages + LeafPageN int // number of logical leaf pages + LeafOverflowN int // number of physical leaf overflow pages + + // Tree statistics. + KeyN int // number of keys/value pairs + Depth int // number of levels in B+tree + + // Page size utilization. 
+ BranchAlloc int // bytes allocated for physical branch pages + BranchInuse int // bytes actually used for branch data + LeafAlloc int // bytes allocated for physical leaf pages + LeafInuse int // bytes actually used for leaf data + + // Bucket statistics + BucketN int // total number of buckets including the top bucket + InlineBucketN int // total number on inlined buckets + InlineBucketInuse int // bytes used for inlined buckets (also accounted for in LeafInuse) +} + +func (s *BucketStats) Add(other BucketStats) { + s.BranchPageN += other.BranchPageN + s.BranchOverflowN += other.BranchOverflowN + s.LeafPageN += other.LeafPageN + s.LeafOverflowN += other.LeafOverflowN + s.KeyN += other.KeyN + if s.Depth < other.Depth { + s.Depth = other.Depth + } + s.BranchAlloc += other.BranchAlloc + s.BranchInuse += other.BranchInuse + s.LeafAlloc += other.LeafAlloc + s.LeafInuse += other.LeafInuse + + s.BucketN += other.BucketN + s.InlineBucketN += other.InlineBucketN + s.InlineBucketInuse += other.InlineBucketInuse +} + +// cloneBytes returns a copy of a given slice. +func cloneBytes(v []byte) []byte { + var clone = make([]byte, len(v)) + copy(clone, v) + return clone +} diff --git a/vendor/github.com/boltdb/bolt/cursor.go b/vendor/github.com/boltdb/bolt/cursor.go new file mode 100644 index 00000000000..1be9f35e3ef --- /dev/null +++ b/vendor/github.com/boltdb/bolt/cursor.go @@ -0,0 +1,400 @@ +package bolt + +import ( + "bytes" + "fmt" + "sort" +) + +// Cursor represents an iterator that can traverse over all key/value pairs in a bucket in sorted order. +// Cursors see nested buckets with value == nil. +// Cursors can be obtained from a transaction and are valid as long as the transaction is open. +// +// Keys and values returned from the cursor are only valid for the life of the transaction. +// +// Changing data while traversing with a cursor may cause it to be invalidated +// and return unexpected keys and/or values. You must reposition your cursor +// after mutating data. 
+type Cursor struct { + bucket *Bucket + stack []elemRef +} + +// Bucket returns the bucket that this cursor was created from. +func (c *Cursor) Bucket() *Bucket { + return c.bucket +} + +// First moves the cursor to the first item in the bucket and returns its key and value. +// If the bucket is empty then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) First() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + c.stack = c.stack[:0] + p, n := c.bucket.pageNode(c.bucket.root) + c.stack = append(c.stack, elemRef{page: p, node: n, index: 0}) + c.first() + + // If we land on an empty page then move to the next value. + // https://github.com/boltdb/bolt/issues/450 + if c.stack[len(c.stack)-1].count() == 0 { + c.next() + } + + k, v, flags := c.keyValue() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v + +} + +// Last moves the cursor to the last item in the bucket and returns its key and value. +// If the bucket is empty then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) Last() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + c.stack = c.stack[:0] + p, n := c.bucket.pageNode(c.bucket.root) + ref := elemRef{page: p, node: n} + ref.index = ref.count() - 1 + c.stack = append(c.stack, ref) + c.last() + k, v, flags := c.keyValue() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Next moves the cursor to the next item in the bucket and returns its key and value. +// If the cursor is at the end of the bucket then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. 
+func (c *Cursor) Next() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + k, v, flags := c.next() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Prev moves the cursor to the previous item in the bucket and returns its key and value. +// If the cursor is at the beginning of the bucket then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) Prev() (key []byte, value []byte) { + _assert(c.bucket.tx.db != nil, "tx closed") + + // Attempt to move back one element until we're successful. + // Move up the stack as we hit the beginning of each page in our stack. + for i := len(c.stack) - 1; i >= 0; i-- { + elem := &c.stack[i] + if elem.index > 0 { + elem.index-- + break + } + c.stack = c.stack[:i] + } + + // If we've hit the end then return nil. + if len(c.stack) == 0 { + return nil, nil + } + + // Move down the stack to find the last element of the last leaf under this branch. + c.last() + k, v, flags := c.keyValue() + if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Seek moves the cursor to a given key and returns it. +// If the key does not exist then the next key is used. If no keys +// follow, a nil key is returned. +// The returned key and value are only valid for the life of the transaction. +func (c *Cursor) Seek(seek []byte) (key []byte, value []byte) { + k, v, flags := c.seek(seek) + + // If we ended up after the last element of a page then move to the next one. + if ref := &c.stack[len(c.stack)-1]; ref.index >= ref.count() { + k, v, flags = c.next() + } + + if k == nil { + return nil, nil + } else if (flags & uint32(bucketLeafFlag)) != 0 { + return k, nil + } + return k, v +} + +// Delete removes the current key/value under the cursor from the bucket. +// Delete fails if current key/value is a bucket or if the transaction is not writable. 
+func (c *Cursor) Delete() error { + if c.bucket.tx.db == nil { + return ErrTxClosed + } else if !c.bucket.Writable() { + return ErrTxNotWritable + } + + key, _, flags := c.keyValue() + // Return an error if current value is a bucket. + if (flags & bucketLeafFlag) != 0 { + return ErrIncompatibleValue + } + c.node().del(key) + + return nil +} + +// seek moves the cursor to a given key and returns it. +// If the key does not exist then the next key is used. +func (c *Cursor) seek(seek []byte) (key []byte, value []byte, flags uint32) { + _assert(c.bucket.tx.db != nil, "tx closed") + + // Start from root page/node and traverse to correct page. + c.stack = c.stack[:0] + c.search(seek, c.bucket.root) + ref := &c.stack[len(c.stack)-1] + + // If the cursor is pointing to the end of page/node then return nil. + if ref.index >= ref.count() { + return nil, nil, 0 + } + + // If this is a bucket then return a nil value. + return c.keyValue() +} + +// first moves the cursor to the first leaf element under the last page in the stack. +func (c *Cursor) first() { + for { + // Exit when we hit a leaf page. + var ref = &c.stack[len(c.stack)-1] + if ref.isLeaf() { + break + } + + // Keep adding pages pointing to the first element to the stack. + var pgid pgid + if ref.node != nil { + pgid = ref.node.inodes[ref.index].pgid + } else { + pgid = ref.page.branchPageElement(uint16(ref.index)).pgid + } + p, n := c.bucket.pageNode(pgid) + c.stack = append(c.stack, elemRef{page: p, node: n, index: 0}) + } +} + +// last moves the cursor to the last leaf element under the last page in the stack. +func (c *Cursor) last() { + for { + // Exit when we hit a leaf page. + ref := &c.stack[len(c.stack)-1] + if ref.isLeaf() { + break + } + + // Keep adding pages pointing to the last element in the stack. 
+ var pgid pgid + if ref.node != nil { + pgid = ref.node.inodes[ref.index].pgid + } else { + pgid = ref.page.branchPageElement(uint16(ref.index)).pgid + } + p, n := c.bucket.pageNode(pgid) + + var nextRef = elemRef{page: p, node: n} + nextRef.index = nextRef.count() - 1 + c.stack = append(c.stack, nextRef) + } +} + +// next moves to the next leaf element and returns the key and value. +// If the cursor is at the last leaf element then it stays there and returns nil. +func (c *Cursor) next() (key []byte, value []byte, flags uint32) { + for { + // Attempt to move over one element until we're successful. + // Move up the stack as we hit the end of each page in our stack. + var i int + for i = len(c.stack) - 1; i >= 0; i-- { + elem := &c.stack[i] + if elem.index < elem.count()-1 { + elem.index++ + break + } + } + + // If we've hit the root page then stop and return. This will leave the + // cursor on the last element of the last page. + if i == -1 { + return nil, nil, 0 + } + + // Otherwise start from where we left off in the stack and find the + // first element of the first leaf page. + c.stack = c.stack[:i+1] + c.first() + + // If this is an empty page then restart and move back up the stack. + // https://github.com/boltdb/bolt/issues/450 + if c.stack[len(c.stack)-1].count() == 0 { + continue + } + + return c.keyValue() + } +} + +// search recursively performs a binary search against a given page/node until it finds a given key. +func (c *Cursor) search(key []byte, pgid pgid) { + p, n := c.bucket.pageNode(pgid) + if p != nil && (p.flags&(branchPageFlag|leafPageFlag)) == 0 { + panic(fmt.Sprintf("invalid page type: %d: %x", p.id, p.flags)) + } + e := elemRef{page: p, node: n} + c.stack = append(c.stack, e) + + // If we're on a leaf page/node then find the specific node. 
+ if e.isLeaf() { + c.nsearch(key) + return + } + + if n != nil { + c.searchNode(key, n) + return + } + c.searchPage(key, p) +} + +func (c *Cursor) searchNode(key []byte, n *node) { + var exact bool + index := sort.Search(len(n.inodes), func(i int) bool { + // TODO(benbjohnson): Optimize this range search. It's a bit hacky right now. + // sort.Search() finds the lowest index where f() != -1 but we need the highest index. + ret := bytes.Compare(n.inodes[i].key, key) + if ret == 0 { + exact = true + } + return ret != -1 + }) + if !exact && index > 0 { + index-- + } + c.stack[len(c.stack)-1].index = index + + // Recursively search to the next page. + c.search(key, n.inodes[index].pgid) +} + +func (c *Cursor) searchPage(key []byte, p *page) { + // Binary search for the correct range. + inodes := p.branchPageElements() + + var exact bool + index := sort.Search(int(p.count), func(i int) bool { + // TODO(benbjohnson): Optimize this range search. It's a bit hacky right now. + // sort.Search() finds the lowest index where f() != -1 but we need the highest index. + ret := bytes.Compare(inodes[i].key(), key) + if ret == 0 { + exact = true + } + return ret != -1 + }) + if !exact && index > 0 { + index-- + } + c.stack[len(c.stack)-1].index = index + + // Recursively search to the next page. + c.search(key, inodes[index].pgid) +} + +// nsearch searches the leaf node on the top of the stack for a key. +func (c *Cursor) nsearch(key []byte) { + e := &c.stack[len(c.stack)-1] + p, n := e.page, e.node + + // If we have a node then search its inodes. + if n != nil { + index := sort.Search(len(n.inodes), func(i int) bool { + return bytes.Compare(n.inodes[i].key, key) != -1 + }) + e.index = index + return + } + + // If we have a page then search its leaf elements. 
+ inodes := p.leafPageElements() + index := sort.Search(int(p.count), func(i int) bool { + return bytes.Compare(inodes[i].key(), key) != -1 + }) + e.index = index +} + +// keyValue returns the key and value of the current leaf element. +func (c *Cursor) keyValue() ([]byte, []byte, uint32) { + ref := &c.stack[len(c.stack)-1] + if ref.count() == 0 || ref.index >= ref.count() { + return nil, nil, 0 + } + + // Retrieve value from node. + if ref.node != nil { + inode := &ref.node.inodes[ref.index] + return inode.key, inode.value, inode.flags + } + + // Or retrieve value from page. + elem := ref.page.leafPageElement(uint16(ref.index)) + return elem.key(), elem.value(), elem.flags +} + +// node returns the node that the cursor is currently positioned on. +func (c *Cursor) node() *node { + _assert(len(c.stack) > 0, "accessing a node with a zero-length cursor stack") + + // If the top of the stack is a leaf node then just return it. + if ref := &c.stack[len(c.stack)-1]; ref.node != nil && ref.isLeaf() { + return ref.node + } + + // Start from root and traverse down the hierarchy. + var n = c.stack[0].node + if n == nil { + n = c.bucket.node(c.stack[0].page.id, nil) + } + for _, ref := range c.stack[:len(c.stack)-1] { + _assert(!n.isLeaf, "expected branch node") + n = n.childAt(int(ref.index)) + } + _assert(n.isLeaf, "expected leaf node") + return n +} + +// elemRef represents a reference to an element on a given page/node. +type elemRef struct { + page *page + node *node + index int +} + +// isLeaf returns whether the ref is pointing at a leaf page/node. +func (r *elemRef) isLeaf() bool { + if r.node != nil { + return r.node.isLeaf + } + return (r.page.flags & leafPageFlag) != 0 +} + +// count returns the number of inodes or page elements. 
+func (r *elemRef) count() int { + if r.node != nil { + return len(r.node.inodes) + } + return int(r.page.count) +} diff --git a/vendor/github.com/boltdb/bolt/db.go b/vendor/github.com/boltdb/bolt/db.go new file mode 100644 index 00000000000..f352ff14fe4 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/db.go @@ -0,0 +1,1039 @@ +package bolt + +import ( + "errors" + "fmt" + "hash/fnv" + "log" + "os" + "runtime" + "runtime/debug" + "strings" + "sync" + "time" + "unsafe" +) + +// The largest step that can be taken when remapping the mmap. +const maxMmapStep = 1 << 30 // 1GB + +// The data file format version. +const version = 2 + +// Represents a marker value to indicate that a file is a Bolt DB. +const magic uint32 = 0xED0CDAED + +// IgnoreNoSync specifies whether the NoSync field of a DB is ignored when +// syncing changes to a file. This is required as some operating systems, +// such as OpenBSD, do not have a unified buffer cache (UBC) and writes +// must be synchronized using the msync(2) syscall. +const IgnoreNoSync = runtime.GOOS == "openbsd" + +// Default values if not set in a DB instance. +const ( + DefaultMaxBatchSize int = 1000 + DefaultMaxBatchDelay = 10 * time.Millisecond + DefaultAllocSize = 16 * 1024 * 1024 +) + +// default page size for db is set to the OS page size. +var defaultPageSize = os.Getpagesize() + +// DB represents a collection of buckets persisted to a file on disk. +// All data access is performed through transactions which can be obtained through the DB. +// All the functions on DB will return a ErrDatabaseNotOpen if accessed before Open() is called. +type DB struct { + // When enabled, the database will perform a Check() after every commit. + // A panic is issued if the database is in an inconsistent state. This + // flag has a large performance impact so it should only be used for + // debugging purposes. + StrictMode bool + + // Setting the NoSync flag will cause the database to skip fsync() + // calls after each commit. 
This can be useful when bulk loading data + // into a database and you can restart the bulk load in the event of + // a system failure or database corruption. Do not set this flag for + // normal use. + // + // If the package global IgnoreNoSync constant is true, this value is + // ignored. See the comment on that constant for more details. + // + // THIS IS UNSAFE. PLEASE USE WITH CAUTION. + NoSync bool + + // When true, skips the truncate call when growing the database. + // Setting this to true is only safe on non-ext3/ext4 systems. + // Skipping truncation avoids preallocation of hard drive space and + // bypasses a truncate() and fsync() syscall on remapping. + // + // https://github.com/boltdb/bolt/issues/284 + NoGrowSync bool + + // If you want to read the entire database fast, you can set MmapFlag to + // syscall.MAP_POPULATE on Linux 2.6.23+ for sequential read-ahead. + MmapFlags int + + // MaxBatchSize is the maximum size of a batch. Default value is + // copied from DefaultMaxBatchSize in Open. + // + // If <=0, disables batching. + // + // Do not change concurrently with calls to Batch. + MaxBatchSize int + + // MaxBatchDelay is the maximum delay before a batch starts. + // Default value is copied from DefaultMaxBatchDelay in Open. + // + // If <=0, effectively disables batching. + // + // Do not change concurrently with calls to Batch. + MaxBatchDelay time.Duration + + // AllocSize is the amount of space allocated when the database + // needs to create new pages. This is done to amortize the cost + // of truncate() and fsync() when growing the data file. 
+ AllocSize int + + path string + file *os.File + lockfile *os.File // windows only + dataref []byte // mmap'ed readonly, write throws SEGV + data *[maxMapSize]byte + datasz int + filesz int // current on disk file size + meta0 *meta + meta1 *meta + pageSize int + opened bool + rwtx *Tx + txs []*Tx + freelist *freelist + stats Stats + + pagePool sync.Pool + + batchMu sync.Mutex + batch *batch + + rwlock sync.Mutex // Allows only one writer at a time. + metalock sync.Mutex // Protects meta page access. + mmaplock sync.RWMutex // Protects mmap access during remapping. + statlock sync.RWMutex // Protects stats access. + + ops struct { + writeAt func(b []byte, off int64) (n int, err error) + } + + // Read only mode. + // When true, Update() and Begin(true) return ErrDatabaseReadOnly immediately. + readOnly bool +} + +// Path returns the path to currently open database file. +func (db *DB) Path() string { + return db.path +} + +// GoString returns the Go string representation of the database. +func (db *DB) GoString() string { + return fmt.Sprintf("bolt.DB{path:%q}", db.path) +} + +// String returns the string representation of the database. +func (db *DB) String() string { + return fmt.Sprintf("DB<%q>", db.path) +} + +// Open creates and opens a database at the given path. +// If the file does not exist then it will be created automatically. +// Passing in nil options will cause Bolt to open the database with the default options. +func Open(path string, mode os.FileMode, options *Options) (*DB, error) { + var db = &DB{opened: true} + + // Set default options if no options are provided. + if options == nil { + options = DefaultOptions + } + db.NoGrowSync = options.NoGrowSync + db.MmapFlags = options.MmapFlags + + // Set default values for later DB operations. 
+ db.MaxBatchSize = DefaultMaxBatchSize + db.MaxBatchDelay = DefaultMaxBatchDelay + db.AllocSize = DefaultAllocSize + + flag := os.O_RDWR + if options.ReadOnly { + flag = os.O_RDONLY + db.readOnly = true + } + + // Open data file and separate sync handler for metadata writes. + db.path = path + var err error + if db.file, err = os.OpenFile(db.path, flag|os.O_CREATE, mode); err != nil { + _ = db.close() + return nil, err + } + + // Lock file so that other processes using Bolt in read-write mode cannot + // use the database at the same time. This would cause corruption since + // the two processes would write meta pages and free pages separately. + // The database file is locked exclusively (only one process can grab the lock) + // if !options.ReadOnly. + // The database file is locked using the shared lock (more than one process may + // hold a lock at the same time) otherwise (options.ReadOnly is set). + if err := flock(db, mode, !db.readOnly, options.Timeout); err != nil { + _ = db.close() + return nil, err + } + + // Default values for test hooks + db.ops.writeAt = db.file.WriteAt + + // Initialize the database if it doesn't exist. + if info, err := db.file.Stat(); err != nil { + return nil, err + } else if info.Size() == 0 { + // Initialize new files with meta pages. + if err := db.init(); err != nil { + return nil, err + } + } else { + // Read the first meta page to determine the page size. + var buf [0x1000]byte + if _, err := db.file.ReadAt(buf[:], 0); err == nil { + m := db.pageInBuffer(buf[:], 0).meta() + if err := m.validate(); err != nil { + // If we can't read the page size, we can assume it's the same + // as the OS -- since that's how the page size was chosen in the + // first place. + // + // If the first page is invalid and this OS uses a different + // page size than what the database was created with then we + // are out of luck and cannot access the database. 
+ db.pageSize = os.Getpagesize() + } else { + db.pageSize = int(m.pageSize) + } + } + } + + // Initialize page pool. + db.pagePool = sync.Pool{ + New: func() interface{} { + return make([]byte, db.pageSize) + }, + } + + // Memory map the data file. + if err := db.mmap(options.InitialMmapSize); err != nil { + _ = db.close() + return nil, err + } + + // Read in the freelist. + db.freelist = newFreelist() + db.freelist.read(db.page(db.meta().freelist)) + + // Mark the database as opened and return. + return db, nil +} + +// mmap opens the underlying memory-mapped file and initializes the meta references. +// minsz is the minimum size that the new mmap can be. +func (db *DB) mmap(minsz int) error { + db.mmaplock.Lock() + defer db.mmaplock.Unlock() + + info, err := db.file.Stat() + if err != nil { + return fmt.Errorf("mmap stat error: %s", err) + } else if int(info.Size()) < db.pageSize*2 { + return fmt.Errorf("file size too small") + } + + // Ensure the size is at least the minimum size. + var size = int(info.Size()) + if size < minsz { + size = minsz + } + size, err = db.mmapSize(size) + if err != nil { + return err + } + + // Dereference all mmap references before unmapping. + if db.rwtx != nil { + db.rwtx.root.dereference() + } + + // Unmap existing data before continuing. + if err := db.munmap(); err != nil { + return err + } + + // Memory-map the data file as a byte slice. + if err := mmap(db, size); err != nil { + return err + } + + // Save references to the meta pages. + db.meta0 = db.page(0).meta() + db.meta1 = db.page(1).meta() + + // Validate the meta pages. We only return an error if both meta pages fail + // validation, since meta0 failing validation means that it wasn't saved + // properly -- but we can recover using meta1. And vice-versa. + err0 := db.meta0.validate() + err1 := db.meta1.validate() + if err0 != nil && err1 != nil { + return err0 + } + + return nil +} + +// munmap unmaps the data file from memory. 
+func (db *DB) munmap() error {
+	if err := munmap(db); err != nil {
+		return fmt.Errorf("unmap error: " + err.Error())
+	}
+	return nil
+}
+
+// mmapSize determines the appropriate size for the mmap given the current size
+// of the database. The minimum size is 32KB and doubles until it reaches 1GB.
+// Returns an error if the new mmap size is greater than the max allowed.
+func (db *DB) mmapSize(size int) (int, error) {
+	// Double the size from 32KB until 1GB.
+	for i := uint(15); i <= 30; i++ {
+		if size <= 1<<i {
+			return 1 << i, nil
+		}
+	}
+
+	// Verify the requested size is not above the maximum allowed.
+	if size > maxMapSize {
+		return 0, fmt.Errorf("mmap too large")
+	}
+
+	// If larger than 1GB then grow by 1GB at a time.
+	sz := int64(size)
+	if remainder := sz % int64(maxMmapStep); remainder > 0 {
+		sz += int64(maxMmapStep) - remainder
+	}
+
+	// Ensure that the mmap size is a multiple of the page size.
+	// This should always be true since we're incrementing in MBs.
+	pageSize := int64(db.pageSize)
+	if (sz % pageSize) != 0 {
+		sz = ((sz / pageSize) + 1) * pageSize
+	}
+
+	// If we've exceeded the max size then only grow up to the max size.
+	if sz > maxMapSize {
+		sz = maxMapSize
+	}
+
+	return int(sz), nil
+}
+
+// init creates a new database file and initializes its meta pages.
+func (db *DB) init() error {
+	// Set the page size to the OS page size.
+	db.pageSize = os.Getpagesize()
+
+	// Create two meta pages on a buffer.
+	buf := make([]byte, db.pageSize*4)
+	for i := 0; i < 2; i++ {
+		p := db.pageInBuffer(buf[:], pgid(i))
+		p.id = pgid(i)
+		p.flags = metaPageFlag
+
+		// Initialize the meta page.
+		m := p.meta()
+		m.magic = magic
+		m.version = version
+		m.pageSize = uint32(db.pageSize)
+		m.freelist = 2
+		m.root = bucket{root: 3}
+		m.pgid = 4
+		m.txid = txid(i)
+		m.checksum = m.sum64()
+	}
+
+	// Write an empty freelist at page 3.
+	p := db.pageInBuffer(buf[:], pgid(2))
+	p.id = pgid(2)
+	p.flags = freelistPageFlag
+	p.count = 0
+
+	// Write an empty leaf page at page 4.
+ p = db.pageInBuffer(buf[:], pgid(3)) + p.id = pgid(3) + p.flags = leafPageFlag + p.count = 0 + + // Write the buffer to our data file. + if _, err := db.ops.writeAt(buf, 0); err != nil { + return err + } + if err := fdatasync(db); err != nil { + return err + } + + return nil +} + +// Close releases all database resources. +// All transactions must be closed before closing the database. +func (db *DB) Close() error { + db.rwlock.Lock() + defer db.rwlock.Unlock() + + db.metalock.Lock() + defer db.metalock.Unlock() + + db.mmaplock.RLock() + defer db.mmaplock.RUnlock() + + return db.close() +} + +func (db *DB) close() error { + if !db.opened { + return nil + } + + db.opened = false + + db.freelist = nil + + // Clear ops. + db.ops.writeAt = nil + + // Close the mmap. + if err := db.munmap(); err != nil { + return err + } + + // Close file handles. + if db.file != nil { + // No need to unlock read-only file. + if !db.readOnly { + // Unlock the file. + if err := funlock(db); err != nil { + log.Printf("bolt.Close(): funlock error: %s", err) + } + } + + // Close the file descriptor. + if err := db.file.Close(); err != nil { + return fmt.Errorf("db file close: %s", err) + } + db.file = nil + } + + db.path = "" + return nil +} + +// Begin starts a new transaction. +// Multiple read-only transactions can be used concurrently but only one +// write transaction can be used at a time. Starting multiple write transactions +// will cause the calls to block and be serialized until the current write +// transaction finishes. +// +// Transactions should not be dependent on one another. Opening a read +// transaction and a write transaction in the same goroutine can cause the +// writer to deadlock because the database periodically needs to re-mmap itself +// as it grows and it cannot do that while a read transaction is open. 
+// +// If a long running read transaction (for example, a snapshot transaction) is +// needed, you might want to set DB.InitialMmapSize to a large enough value +// to avoid potential blocking of write transaction. +// +// IMPORTANT: You must close read-only transactions after you are finished or +// else the database will not reclaim old pages. +func (db *DB) Begin(writable bool) (*Tx, error) { + if writable { + return db.beginRWTx() + } + return db.beginTx() +} + +func (db *DB) beginTx() (*Tx, error) { + // Lock the meta pages while we initialize the transaction. We obtain + // the meta lock before the mmap lock because that's the order that the + // write transaction will obtain them. + db.metalock.Lock() + + // Obtain a read-only lock on the mmap. When the mmap is remapped it will + // obtain a write lock so all transactions must finish before it can be + // remapped. + db.mmaplock.RLock() + + // Exit if the database is not open yet. + if !db.opened { + db.mmaplock.RUnlock() + db.metalock.Unlock() + return nil, ErrDatabaseNotOpen + } + + // Create a transaction associated with the database. + t := &Tx{} + t.init(db) + + // Keep track of transaction until it closes. + db.txs = append(db.txs, t) + n := len(db.txs) + + // Unlock the meta pages. + db.metalock.Unlock() + + // Update the transaction stats. + db.statlock.Lock() + db.stats.TxN++ + db.stats.OpenTxN = n + db.statlock.Unlock() + + return t, nil +} + +func (db *DB) beginRWTx() (*Tx, error) { + // If the database was opened with Options.ReadOnly, return an error. + if db.readOnly { + return nil, ErrDatabaseReadOnly + } + + // Obtain writer lock. This is released by the transaction when it closes. + // This enforces only one writer transaction at a time. + db.rwlock.Lock() + + // Once we have the writer lock then we can lock the meta pages so that + // we can set up the transaction. + db.metalock.Lock() + defer db.metalock.Unlock() + + // Exit if the database is not open yet. 
+ if !db.opened { + db.rwlock.Unlock() + return nil, ErrDatabaseNotOpen + } + + // Create a transaction associated with the database. + t := &Tx{writable: true} + t.init(db) + db.rwtx = t + + // Free any pages associated with closed read-only transactions. + var minid txid = 0xFFFFFFFFFFFFFFFF + for _, t := range db.txs { + if t.meta.txid < minid { + minid = t.meta.txid + } + } + if minid > 0 { + db.freelist.release(minid - 1) + } + + return t, nil +} + +// removeTx removes a transaction from the database. +func (db *DB) removeTx(tx *Tx) { + // Release the read lock on the mmap. + db.mmaplock.RUnlock() + + // Use the meta lock to restrict access to the DB object. + db.metalock.Lock() + + // Remove the transaction. + for i, t := range db.txs { + if t == tx { + last := len(db.txs) - 1 + db.txs[i] = db.txs[last] + db.txs[last] = nil + db.txs = db.txs[:last] + break + } + } + n := len(db.txs) + + // Unlock the meta pages. + db.metalock.Unlock() + + // Merge statistics. + db.statlock.Lock() + db.stats.OpenTxN = n + db.stats.TxStats.add(&tx.stats) + db.statlock.Unlock() +} + +// Update executes a function within the context of a read-write managed transaction. +// If no error is returned from the function then the transaction is committed. +// If an error is returned then the entire transaction is rolled back. +// Any error that is returned from the function or returned from the commit is +// returned from the Update() method. +// +// Attempting to manually commit or rollback within the function will cause a panic. +func (db *DB) Update(fn func(*Tx) error) error { + t, err := db.Begin(true) + if err != nil { + return err + } + + // Make sure the transaction rolls back in the event of a panic. + defer func() { + if t.db != nil { + t.rollback() + } + }() + + // Mark as a managed tx so that the inner function cannot manually commit. + t.managed = true + + // If an error is returned from the function then rollback and return error. 
+ err = fn(t) + t.managed = false + if err != nil { + _ = t.Rollback() + return err + } + + return t.Commit() +} + +// View executes a function within the context of a managed read-only transaction. +// Any error that is returned from the function is returned from the View() method. +// +// Attempting to manually rollback within the function will cause a panic. +func (db *DB) View(fn func(*Tx) error) error { + t, err := db.Begin(false) + if err != nil { + return err + } + + // Make sure the transaction rolls back in the event of a panic. + defer func() { + if t.db != nil { + t.rollback() + } + }() + + // Mark as a managed tx so that the inner function cannot manually rollback. + t.managed = true + + // If an error is returned from the function then pass it through. + err = fn(t) + t.managed = false + if err != nil { + _ = t.Rollback() + return err + } + + if err := t.Rollback(); err != nil { + return err + } + + return nil +} + +// Batch calls fn as part of a batch. It behaves similar to Update, +// except: +// +// 1. concurrent Batch calls can be combined into a single Bolt +// transaction. +// +// 2. the function passed to Batch may be called multiple times, +// regardless of whether it returns error or not. +// +// This means that Batch function side effects must be idempotent and +// take permanent effect only after a successful return is seen in +// caller. +// +// The maximum batch size and delay can be adjusted with DB.MaxBatchSize +// and DB.MaxBatchDelay, respectively. +// +// Batch is only useful when there are multiple goroutines calling it. +func (db *DB) Batch(fn func(*Tx) error) error { + errCh := make(chan error, 1) + + db.batchMu.Lock() + if (db.batch == nil) || (db.batch != nil && len(db.batch.calls) >= db.MaxBatchSize) { + // There is no existing batch, or the existing batch is full; start a new one. 
+ db.batch = &batch{ + db: db, + } + db.batch.timer = time.AfterFunc(db.MaxBatchDelay, db.batch.trigger) + } + db.batch.calls = append(db.batch.calls, call{fn: fn, err: errCh}) + if len(db.batch.calls) >= db.MaxBatchSize { + // wake up batch, it's ready to run + go db.batch.trigger() + } + db.batchMu.Unlock() + + err := <-errCh + if err == trySolo { + err = db.Update(fn) + } + return err +} + +type call struct { + fn func(*Tx) error + err chan<- error +} + +type batch struct { + db *DB + timer *time.Timer + start sync.Once + calls []call +} + +// trigger runs the batch if it hasn't already been run. +func (b *batch) trigger() { + b.start.Do(b.run) +} + +// run performs the transactions in the batch and communicates results +// back to DB.Batch. +func (b *batch) run() { + b.db.batchMu.Lock() + b.timer.Stop() + // Make sure no new work is added to this batch, but don't break + // other batches. + if b.db.batch == b { + b.db.batch = nil + } + b.db.batchMu.Unlock() + +retry: + for len(b.calls) > 0 { + var failIdx = -1 + err := b.db.Update(func(tx *Tx) error { + for i, c := range b.calls { + if err := safelyCall(c.fn, tx); err != nil { + failIdx = i + return err + } + } + return nil + }) + + if failIdx >= 0 { + // take the failing transaction out of the batch. it's + // safe to shorten b.calls here because db.batch no longer + // points to us, and we hold the mutex anyway. + c := b.calls[failIdx] + b.calls[failIdx], b.calls = b.calls[len(b.calls)-1], b.calls[:len(b.calls)-1] + // tell the submitter re-run it solo, continue with the rest of the batch + c.err <- trySolo + continue retry + } + + // pass success, or bolt internal errors, to all callers + for _, c := range b.calls { + if c.err != nil { + c.err <- err + } + } + break retry + } +} + +// trySolo is a special sentinel error value used for signaling that a +// transaction function should be re-run. It should never be seen by +// callers. 
+var trySolo = errors.New("batch function returned an error and should be re-run solo") + +type panicked struct { + reason interface{} +} + +func (p panicked) Error() string { + if err, ok := p.reason.(error); ok { + return err.Error() + } + return fmt.Sprintf("panic: %v", p.reason) +} + +func safelyCall(fn func(*Tx) error, tx *Tx) (err error) { + defer func() { + if p := recover(); p != nil { + err = panicked{p} + } + }() + return fn(tx) +} + +// Sync executes fdatasync() against the database file handle. +// +// This is not necessary under normal operation, however, if you use NoSync +// then it allows you to force the database file to sync against the disk. +func (db *DB) Sync() error { return fdatasync(db) } + +// Stats retrieves ongoing performance stats for the database. +// This is only updated when a transaction closes. +func (db *DB) Stats() Stats { + db.statlock.RLock() + defer db.statlock.RUnlock() + return db.stats +} + +// This is for internal access to the raw data bytes from the C cursor, use +// carefully, or not at all. +func (db *DB) Info() *Info { + return &Info{uintptr(unsafe.Pointer(&db.data[0])), db.pageSize} +} + +// page retrieves a page reference from the mmap based on the current page size. +func (db *DB) page(id pgid) *page { + pos := id * pgid(db.pageSize) + return (*page)(unsafe.Pointer(&db.data[pos])) +} + +// pageInBuffer retrieves a page reference from a given byte array based on the current page size. +func (db *DB) pageInBuffer(b []byte, id pgid) *page { + return (*page)(unsafe.Pointer(&b[id*pgid(db.pageSize)])) +} + +// meta retrieves the current meta page reference. +func (db *DB) meta() *meta { + // We have to return the meta with the highest txid which doesn't fail + // validation. Otherwise, we can cause errors when in fact the database is + // in a consistent state. metaA is the one with the higher txid. 
+ metaA := db.meta0 + metaB := db.meta1 + if db.meta1.txid > db.meta0.txid { + metaA = db.meta1 + metaB = db.meta0 + } + + // Use higher meta page if valid. Otherwise fallback to previous, if valid. + if err := metaA.validate(); err == nil { + return metaA + } else if err := metaB.validate(); err == nil { + return metaB + } + + // This should never be reached, because both meta1 and meta0 were validated + // on mmap() and we do fsync() on every write. + panic("bolt.DB.meta(): invalid meta pages") +} + +// allocate returns a contiguous block of memory starting at a given page. +func (db *DB) allocate(count int) (*page, error) { + // Allocate a temporary buffer for the page. + var buf []byte + if count == 1 { + buf = db.pagePool.Get().([]byte) + } else { + buf = make([]byte, count*db.pageSize) + } + p := (*page)(unsafe.Pointer(&buf[0])) + p.overflow = uint32(count - 1) + + // Use pages from the freelist if they are available. + if p.id = db.freelist.allocate(count); p.id != 0 { + return p, nil + } + + // Resize mmap() if we're at the end. + p.id = db.rwtx.meta.pgid + var minsz = int((p.id+pgid(count))+1) * db.pageSize + if minsz >= db.datasz { + if err := db.mmap(minsz); err != nil { + return nil, fmt.Errorf("mmap allocate error: %s", err) + } + } + + // Move the page id high water mark. + db.rwtx.meta.pgid += pgid(count) + + return p, nil +} + +// grow grows the size of the database to the given sz. +func (db *DB) grow(sz int) error { + // Ignore if the new size is less than available file size. + if sz <= db.filesz { + return nil + } + + // If the data is smaller than the alloc size then only allocate what's needed. + // Once it goes over the allocation size then allocate in chunks. + if db.datasz < db.AllocSize { + sz = db.datasz + } else { + sz += db.AllocSize + } + + // Truncate and fsync to ensure file size metadata is flushed. 
+ // https://github.com/boltdb/bolt/issues/284 + if !db.NoGrowSync && !db.readOnly { + if runtime.GOOS != "windows" { + if err := db.file.Truncate(int64(sz)); err != nil { + return fmt.Errorf("file resize error: %s", err) + } + } + if err := db.file.Sync(); err != nil { + return fmt.Errorf("file sync error: %s", err) + } + } + + db.filesz = sz + return nil +} + +func (db *DB) IsReadOnly() bool { + return db.readOnly +} + +// Options represents the options that can be set when opening a database. +type Options struct { + // Timeout is the amount of time to wait to obtain a file lock. + // When set to zero it will wait indefinitely. This option is only + // available on Darwin and Linux. + Timeout time.Duration + + // Sets the DB.NoGrowSync flag before memory mapping the file. + NoGrowSync bool + + // Open database in read-only mode. Uses flock(..., LOCK_SH |LOCK_NB) to + // grab a shared lock (UNIX). + ReadOnly bool + + // Sets the DB.MmapFlags flag before memory mapping the file. + MmapFlags int + + // InitialMmapSize is the initial mmap size of the database + // in bytes. Read transactions won't block write transaction + // if the InitialMmapSize is large enough to hold database mmap + // size. (See DB.Begin for more information) + // + // If <=0, the initial map size is 0. + // If initialMmapSize is smaller than the previous database size, + // it takes no effect. + InitialMmapSize int +} + +// DefaultOptions represent the options used if nil options are passed into Open(). +// No timeout is used which will cause Bolt to wait indefinitely for a lock. +var DefaultOptions = &Options{ + Timeout: 0, + NoGrowSync: false, +} + +// Stats represents statistics about the database. 
+type Stats struct { + // Freelist stats + FreePageN int // total number of free pages on the freelist + PendingPageN int // total number of pending pages on the freelist + FreeAlloc int // total bytes allocated in free pages + FreelistInuse int // total bytes used by the freelist + + // Transaction stats + TxN int // total number of started read transactions + OpenTxN int // number of currently open read transactions + + TxStats TxStats // global, ongoing stats. +} + +// Sub calculates and returns the difference between two sets of database stats. +// This is useful when obtaining stats at two different points and time and +// you need the performance counters that occurred within that time span. +func (s *Stats) Sub(other *Stats) Stats { + if other == nil { + return *s + } + var diff Stats + diff.FreePageN = s.FreePageN + diff.PendingPageN = s.PendingPageN + diff.FreeAlloc = s.FreeAlloc + diff.FreelistInuse = s.FreelistInuse + diff.TxN = s.TxN - other.TxN + diff.TxStats = s.TxStats.Sub(&other.TxStats) + return diff +} + +func (s *Stats) add(other *Stats) { + s.TxStats.add(&other.TxStats) +} + +type Info struct { + Data uintptr + PageSize int +} + +type meta struct { + magic uint32 + version uint32 + pageSize uint32 + flags uint32 + root bucket + freelist pgid + pgid pgid + txid txid + checksum uint64 +} + +// validate checks the marker bytes and version of the meta page to ensure it matches this binary. +func (m *meta) validate() error { + if m.magic != magic { + return ErrInvalid + } else if m.version != version { + return ErrVersionMismatch + } else if m.checksum != 0 && m.checksum != m.sum64() { + return ErrChecksum + } + return nil +} + +// copy copies one meta object to another. +func (m *meta) copy(dest *meta) { + *dest = *m +} + +// write writes the meta onto a page. 
+func (m *meta) write(p *page) { + if m.root.root >= m.pgid { + panic(fmt.Sprintf("root bucket pgid (%d) above high water mark (%d)", m.root.root, m.pgid)) + } else if m.freelist >= m.pgid { + panic(fmt.Sprintf("freelist pgid (%d) above high water mark (%d)", m.freelist, m.pgid)) + } + + // Page id is either going to be 0 or 1 which we can determine by the transaction ID. + p.id = pgid(m.txid % 2) + p.flags |= metaPageFlag + + // Calculate the checksum. + m.checksum = m.sum64() + + m.copy(p.meta()) +} + +// generates the checksum for the meta. +func (m *meta) sum64() uint64 { + var h = fnv.New64a() + _, _ = h.Write((*[unsafe.Offsetof(meta{}.checksum)]byte)(unsafe.Pointer(m))[:]) + return h.Sum64() +} + +// _assert will panic with a given formatted message if the given condition is false. +func _assert(condition bool, msg string, v ...interface{}) { + if !condition { + panic(fmt.Sprintf("assertion failed: "+msg, v...)) + } +} + +func warn(v ...interface{}) { fmt.Fprintln(os.Stderr, v...) } +func warnf(msg string, v ...interface{}) { fmt.Fprintf(os.Stderr, msg+"\n", v...) } + +func printstack() { + stack := strings.Join(strings.Split(string(debug.Stack()), "\n")[2:], "\n") + fmt.Fprintln(os.Stderr, stack) +} diff --git a/vendor/github.com/boltdb/bolt/doc.go b/vendor/github.com/boltdb/bolt/doc.go new file mode 100644 index 00000000000..cc937845dba --- /dev/null +++ b/vendor/github.com/boltdb/bolt/doc.go @@ -0,0 +1,44 @@ +/* +Package bolt implements a low-level key/value store in pure Go. It supports +fully serializable transactions, ACID semantics, and lock-free MVCC with +multiple readers and a single writer. Bolt can be used for projects that +want a simple data store without the need to add large dependencies such as +Postgres or MySQL. + +Bolt is a single-level, zero-copy, B+tree data store. This means that Bolt is +optimized for fast read access and does not require recovery in the event of a +system crash. 
Transactions which have not finished committing will simply be +rolled back in the event of a crash. + +The design of Bolt is based on Howard Chu's LMDB database project. + +Bolt currently works on Windows, Mac OS X, and Linux. + + +Basics + +There are only a few types in Bolt: DB, Bucket, Tx, and Cursor. The DB is +a collection of buckets and is represented by a single file on disk. A bucket is +a collection of unique keys that are associated with values. + +Transactions provide either read-only or read-write access to the database. +Read-only transactions can retrieve key/value pairs and can use Cursors to +iterate over the dataset sequentially. Read-write transactions can create and +delete buckets and can insert and remove keys. Only one read-write transaction +is allowed at a time. + + +Caveats + +The database uses a read-only, memory-mapped data file to ensure that +applications cannot corrupt the database, however, this means that keys and +values returned from Bolt cannot be changed. Writing to a read-only byte slice +will cause Go to panic. + +Keys and values retrieved from the database are only valid for the life of +the transaction. When used outside the transaction, these byte slices can +point to different data or can point to invalid memory which will cause a panic. + + +*/ +package bolt diff --git a/vendor/github.com/boltdb/bolt/errors.go b/vendor/github.com/boltdb/bolt/errors.go new file mode 100644 index 00000000000..a3620a3ebb2 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/errors.go @@ -0,0 +1,71 @@ +package bolt + +import "errors" + +// These errors can be returned when opening or calling methods on a DB. +var ( + // ErrDatabaseNotOpen is returned when a DB instance is accessed before it + // is opened or after it is closed. + ErrDatabaseNotOpen = errors.New("database not open") + + // ErrDatabaseOpen is returned when opening a database that is + // already open. 
+ ErrDatabaseOpen = errors.New("database already open") + + // ErrInvalid is returned when both meta pages on a database are invalid. + // This typically occurs when a file is not a bolt database. + ErrInvalid = errors.New("invalid database") + + // ErrVersionMismatch is returned when the data file was created with a + // different version of Bolt. + ErrVersionMismatch = errors.New("version mismatch") + + // ErrChecksum is returned when either meta page checksum does not match. + ErrChecksum = errors.New("checksum error") + + // ErrTimeout is returned when a database cannot obtain an exclusive lock + // on the data file after the timeout passed to Open(). + ErrTimeout = errors.New("timeout") +) + +// These errors can occur when beginning or committing a Tx. +var ( + // ErrTxNotWritable is returned when performing a write operation on a + // read-only transaction. + ErrTxNotWritable = errors.New("tx not writable") + + // ErrTxClosed is returned when committing or rolling back a transaction + // that has already been committed or rolled back. + ErrTxClosed = errors.New("tx closed") + + // ErrDatabaseReadOnly is returned when a mutating transaction is started on a + // read-only database. + ErrDatabaseReadOnly = errors.New("database is in read-only mode") +) + +// These errors can occur when putting or deleting a value or a bucket. +var ( + // ErrBucketNotFound is returned when trying to access a bucket that has + // not been created yet. + ErrBucketNotFound = errors.New("bucket not found") + + // ErrBucketExists is returned when creating a bucket that already exists. + ErrBucketExists = errors.New("bucket already exists") + + // ErrBucketNameRequired is returned when creating a bucket with a blank name. + ErrBucketNameRequired = errors.New("bucket name required") + + // ErrKeyRequired is returned when inserting a zero-length key. + ErrKeyRequired = errors.New("key required") + + // ErrKeyTooLarge is returned when inserting a key that is larger than MaxKeySize. 
+ ErrKeyTooLarge = errors.New("key too large") + + // ErrValueTooLarge is returned when inserting a value that is larger than MaxValueSize. + ErrValueTooLarge = errors.New("value too large") + + // ErrIncompatibleValue is returned when trying create or delete a bucket + // on an existing non-bucket key or when trying to create or delete a + // non-bucket key on an existing bucket key. + ErrIncompatibleValue = errors.New("incompatible value") +) diff --git a/vendor/github.com/boltdb/bolt/freelist.go b/vendor/github.com/boltdb/bolt/freelist.go new file mode 100644 index 00000000000..aba48f58c62 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/freelist.go @@ -0,0 +1,252 @@ +package bolt + +import ( + "fmt" + "sort" + "unsafe" +) + +// freelist represents a list of all pages that are available for allocation. +// It also tracks pages that have been freed but are still in use by open transactions. +type freelist struct { + ids []pgid // all free and available free page ids. + pending map[txid][]pgid // mapping of soon-to-be free page ids by tx. + cache map[pgid]bool // fast lookup of all free and pending page ids. +} + +// newFreelist returns an empty, initialized freelist. +func newFreelist() *freelist { + return &freelist{ + pending: make(map[txid][]pgid), + cache: make(map[pgid]bool), + } +} + +// size returns the size of the page after serialization. +func (f *freelist) size() int { + n := f.count() + if n >= 0xFFFF { + // The first element will be used to store the count. See freelist.write. 
+ n++ + } + return pageHeaderSize + (int(unsafe.Sizeof(pgid(0))) * n) +} + +// count returns count of pages on the freelist +func (f *freelist) count() int { + return f.free_count() + f.pending_count() +} + +// free_count returns count of free pages +func (f *freelist) free_count() int { + return len(f.ids) +} + +// pending_count returns count of pending pages +func (f *freelist) pending_count() int { + var count int + for _, list := range f.pending { + count += len(list) + } + return count +} + +// copyall copies into dst a list of all free ids and all pending ids in one sorted list. +// f.count returns the minimum length required for dst. +func (f *freelist) copyall(dst []pgid) { + m := make(pgids, 0, f.pending_count()) + for _, list := range f.pending { + m = append(m, list...) + } + sort.Sort(m) + mergepgids(dst, f.ids, m) +} + +// allocate returns the starting page id of a contiguous list of pages of a given size. +// If a contiguous block cannot be found then 0 is returned. +func (f *freelist) allocate(n int) pgid { + if len(f.ids) == 0 { + return 0 + } + + var initial, previd pgid + for i, id := range f.ids { + if id <= 1 { + panic(fmt.Sprintf("invalid page allocation: %d", id)) + } + + // Reset initial page if this is not contiguous. + if previd == 0 || id-previd != 1 { + initial = id + } + + // If we found a contiguous block then remove it and return it. + if (id-initial)+1 == pgid(n) { + // If we're allocating off the beginning then take the fast path + // and just adjust the existing slice. This will use extra memory + // temporarily but the append() in free() will realloc the slice + // as is necessary. + if (i + 1) == n { + f.ids = f.ids[i+1:] + } else { + copy(f.ids[i-n+1:], f.ids[i+1:]) + f.ids = f.ids[:len(f.ids)-n] + } + + // Remove from the free cache. 
+ for i := pgid(0); i < pgid(n); i++ { + delete(f.cache, initial+i) + } + + return initial + } + + previd = id + } + return 0 +} + +// free releases a page and its overflow for a given transaction id. +// If the page is already free then a panic will occur. +func (f *freelist) free(txid txid, p *page) { + if p.id <= 1 { + panic(fmt.Sprintf("cannot free page 0 or 1: %d", p.id)) + } + + // Free page and all its overflow pages. + var ids = f.pending[txid] + for id := p.id; id <= p.id+pgid(p.overflow); id++ { + // Verify that page is not already free. + if f.cache[id] { + panic(fmt.Sprintf("page %d already freed", id)) + } + + // Add to the freelist and cache. + ids = append(ids, id) + f.cache[id] = true + } + f.pending[txid] = ids +} + +// release moves all page ids for a transaction id (or older) to the freelist. +func (f *freelist) release(txid txid) { + m := make(pgids, 0) + for tid, ids := range f.pending { + if tid <= txid { + // Move transaction's pending pages to the available freelist. + // Don't remove from the cache since the page is still free. + m = append(m, ids...) + delete(f.pending, tid) + } + } + sort.Sort(m) + f.ids = pgids(f.ids).merge(m) +} + +// rollback removes the pages from a given pending tx. +func (f *freelist) rollback(txid txid) { + // Remove page ids from cache. + for _, id := range f.pending[txid] { + delete(f.cache, id) + } + + // Remove pages from pending list. + delete(f.pending, txid) +} + +// freed returns whether a given page is in the free list. +func (f *freelist) freed(pgid pgid) bool { + return f.cache[pgid] +} + +// read initializes the freelist from a freelist page. +func (f *freelist) read(p *page) { + // If the page.count is at the max uint16 value (64k) then it's considered + // an overflow and the size of the freelist is stored as the first element. 
+ idx, count := 0, int(p.count) + if count == 0xFFFF { + idx = 1 + count = int(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[0]) + } + + // Copy the list of page ids from the freelist. + if count == 0 { + f.ids = nil + } else { + ids := ((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[idx:count] + f.ids = make([]pgid, len(ids)) + copy(f.ids, ids) + + // Make sure they're sorted. + sort.Sort(pgids(f.ids)) + } + + // Rebuild the page cache. + f.reindex() +} + +// write writes the page ids onto a freelist page. All free and pending ids are +// saved to disk since in the event of a program crash, all pending ids will +// become free. +func (f *freelist) write(p *page) error { + // Combine the old free pgids and pgids waiting on an open transaction. + + // Update the header flag. + p.flags |= freelistPageFlag + + // The page.count can only hold up to 64k elements so if we overflow that + // number then we handle it by putting the size in the first element. + lenids := f.count() + if lenids == 0 { + p.count = uint16(lenids) + } else if lenids < 0xFFFF { + p.count = uint16(lenids) + f.copyall(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[:]) + } else { + p.count = 0xFFFF + ((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[0] = pgid(lenids) + f.copyall(((*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)))[1:]) + } + + return nil +} + +// reload reads the freelist from a page and filters out pending items. +func (f *freelist) reload(p *page) { + f.read(p) + + // Build a cache of only pending pages. + pcache := make(map[pgid]bool) + for _, pendingIDs := range f.pending { + for _, pendingID := range pendingIDs { + pcache[pendingID] = true + } + } + + // Check each page in the freelist and build a new available freelist + // with any pages not in the pending lists. 
+ var a []pgid + for _, id := range f.ids { + if !pcache[id] { + a = append(a, id) + } + } + f.ids = a + + // Once the available list is rebuilt then rebuild the free cache so that + // it includes the available and pending free pages. + f.reindex() +} + +// reindex rebuilds the free cache based on available and pending free lists. +func (f *freelist) reindex() { + f.cache = make(map[pgid]bool, len(f.ids)) + for _, id := range f.ids { + f.cache[id] = true + } + for _, pendingIDs := range f.pending { + for _, pendingID := range pendingIDs { + f.cache[pendingID] = true + } + } +} diff --git a/vendor/github.com/boltdb/bolt/node.go b/vendor/github.com/boltdb/bolt/node.go new file mode 100644 index 00000000000..159318b229c --- /dev/null +++ b/vendor/github.com/boltdb/bolt/node.go @@ -0,0 +1,604 @@ +package bolt + +import ( + "bytes" + "fmt" + "sort" + "unsafe" +) + +// node represents an in-memory, deserialized page. +type node struct { + bucket *Bucket + isLeaf bool + unbalanced bool + spilled bool + key []byte + pgid pgid + parent *node + children nodes + inodes inodes +} + +// root returns the top-level node this node is attached to. +func (n *node) root() *node { + if n.parent == nil { + return n + } + return n.parent.root() +} + +// minKeys returns the minimum number of inodes this node should have. +func (n *node) minKeys() int { + if n.isLeaf { + return 1 + } + return 2 +} + +// size returns the size of the node after serialization. +func (n *node) size() int { + sz, elsz := pageHeaderSize, n.pageElementSize() + for i := 0; i < len(n.inodes); i++ { + item := &n.inodes[i] + sz += elsz + len(item.key) + len(item.value) + } + return sz +} + +// sizeLessThan returns true if the node is less than a given size. +// This is an optimization to avoid calculating a large node when we only need +// to know if it fits inside a certain page size. 
+func (n *node) sizeLessThan(v int) bool { + sz, elsz := pageHeaderSize, n.pageElementSize() + for i := 0; i < len(n.inodes); i++ { + item := &n.inodes[i] + sz += elsz + len(item.key) + len(item.value) + if sz >= v { + return false + } + } + return true +} + +// pageElementSize returns the size of each page element based on the type of node. +func (n *node) pageElementSize() int { + if n.isLeaf { + return leafPageElementSize + } + return branchPageElementSize +} + +// childAt returns the child node at a given index. +func (n *node) childAt(index int) *node { + if n.isLeaf { + panic(fmt.Sprintf("invalid childAt(%d) on a leaf node", index)) + } + return n.bucket.node(n.inodes[index].pgid, n) +} + +// childIndex returns the index of a given child node. +func (n *node) childIndex(child *node) int { + index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, child.key) != -1 }) + return index +} + +// numChildren returns the number of children. +func (n *node) numChildren() int { + return len(n.inodes) +} + +// nextSibling returns the next node with the same parent. +func (n *node) nextSibling() *node { + if n.parent == nil { + return nil + } + index := n.parent.childIndex(n) + if index >= n.parent.numChildren()-1 { + return nil + } + return n.parent.childAt(index + 1) +} + +// prevSibling returns the previous node with the same parent. +func (n *node) prevSibling() *node { + if n.parent == nil { + return nil + } + index := n.parent.childIndex(n) + if index == 0 { + return nil + } + return n.parent.childAt(index - 1) +} + +// put inserts a key/value. +func (n *node) put(oldKey, newKey, value []byte, pgid pgid, flags uint32) { + if pgid >= n.bucket.tx.meta.pgid { + panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", pgid, n.bucket.tx.meta.pgid)) + } else if len(oldKey) <= 0 { + panic("put: zero-length old key") + } else if len(newKey) <= 0 { + panic("put: zero-length new key") + } + + // Find insertion index. 
+ index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, oldKey) != -1 }) + + // Add capacity and shift nodes if we don't have an exact match and need to insert. + exact := (len(n.inodes) > 0 && index < len(n.inodes) && bytes.Equal(n.inodes[index].key, oldKey)) + if !exact { + n.inodes = append(n.inodes, inode{}) + copy(n.inodes[index+1:], n.inodes[index:]) + } + + inode := &n.inodes[index] + inode.flags = flags + inode.key = newKey + inode.value = value + inode.pgid = pgid + _assert(len(inode.key) > 0, "put: zero-length inode key") +} + +// del removes a key from the node. +func (n *node) del(key []byte) { + // Find index of key. + index := sort.Search(len(n.inodes), func(i int) bool { return bytes.Compare(n.inodes[i].key, key) != -1 }) + + // Exit if the key isn't found. + if index >= len(n.inodes) || !bytes.Equal(n.inodes[index].key, key) { + return + } + + // Delete inode from the node. + n.inodes = append(n.inodes[:index], n.inodes[index+1:]...) + + // Mark the node as needing rebalancing. + n.unbalanced = true +} + +// read initializes the node from a page. +func (n *node) read(p *page) { + n.pgid = p.id + n.isLeaf = ((p.flags & leafPageFlag) != 0) + n.inodes = make(inodes, int(p.count)) + + for i := 0; i < int(p.count); i++ { + inode := &n.inodes[i] + if n.isLeaf { + elem := p.leafPageElement(uint16(i)) + inode.flags = elem.flags + inode.key = elem.key() + inode.value = elem.value() + } else { + elem := p.branchPageElement(uint16(i)) + inode.pgid = elem.pgid + inode.key = elem.key() + } + _assert(len(inode.key) > 0, "read: zero-length inode key") + } + + // Save first key so we can find the node in the parent when we spill. + if len(n.inodes) > 0 { + n.key = n.inodes[0].key + _assert(len(n.key) > 0, "read: zero-length node key") + } else { + n.key = nil + } +} + +// write writes the items onto one or more pages. +func (n *node) write(p *page) { + // Initialize page. 
+ if n.isLeaf { + p.flags |= leafPageFlag + } else { + p.flags |= branchPageFlag + } + + if len(n.inodes) >= 0xFFFF { + panic(fmt.Sprintf("inode overflow: %d (pgid=%d)", len(n.inodes), p.id)) + } + p.count = uint16(len(n.inodes)) + + // Stop here if there are no items to write. + if p.count == 0 { + return + } + + // Loop over each item and write it to the page. + b := (*[maxAllocSize]byte)(unsafe.Pointer(&p.ptr))[n.pageElementSize()*len(n.inodes):] + for i, item := range n.inodes { + _assert(len(item.key) > 0, "write: zero-length inode key") + + // Write the page element. + if n.isLeaf { + elem := p.leafPageElement(uint16(i)) + elem.pos = uint32(uintptr(unsafe.Pointer(&b[0])) - uintptr(unsafe.Pointer(elem))) + elem.flags = item.flags + elem.ksize = uint32(len(item.key)) + elem.vsize = uint32(len(item.value)) + } else { + elem := p.branchPageElement(uint16(i)) + elem.pos = uint32(uintptr(unsafe.Pointer(&b[0])) - uintptr(unsafe.Pointer(elem))) + elem.ksize = uint32(len(item.key)) + elem.pgid = item.pgid + _assert(elem.pgid != p.id, "write: circular dependency occurred") + } + + // If the length of key+value is larger than the max allocation size + // then we need to reallocate the byte array pointer. + // + // See: https://github.com/boltdb/bolt/pull/335 + klen, vlen := len(item.key), len(item.value) + if len(b) < klen+vlen { + b = (*[maxAllocSize]byte)(unsafe.Pointer(&b[0]))[:] + } + + // Write data for the element to the end of the page. + copy(b[0:], item.key) + b = b[klen:] + copy(b[0:], item.value) + b = b[vlen:] + } + + // DEBUG ONLY: n.dump() +} + +// split breaks up a node into multiple smaller nodes, if appropriate. +// This should only be called from the spill() function. +func (n *node) split(pageSize int) []*node { + var nodes []*node + + node := n + for { + // Split node into two. + a, b := node.splitTwo(pageSize) + nodes = append(nodes, a) + + // If we can't split then exit the loop. 
+ if b == nil { + break + } + + // Set node to b so it gets split on the next iteration. + node = b + } + + return nodes +} + +// splitTwo breaks up a node into two smaller nodes, if appropriate. +// This should only be called from the split() function. +func (n *node) splitTwo(pageSize int) (*node, *node) { + // Ignore the split if the page doesn't have at least enough nodes for + // two pages or if the nodes can fit in a single page. + if len(n.inodes) <= (minKeysPerPage*2) || n.sizeLessThan(pageSize) { + return n, nil + } + + // Determine the threshold before starting a new node. + var fillPercent = n.bucket.FillPercent + if fillPercent < minFillPercent { + fillPercent = minFillPercent + } else if fillPercent > maxFillPercent { + fillPercent = maxFillPercent + } + threshold := int(float64(pageSize) * fillPercent) + + // Determine split position and sizes of the two pages. + splitIndex, _ := n.splitIndex(threshold) + + // Split node into two separate nodes. + // If there's no parent then we'll need to create one. + if n.parent == nil { + n.parent = &node{bucket: n.bucket, children: []*node{n}} + } + + // Create a new node and add it to the parent. + next := &node{bucket: n.bucket, isLeaf: n.isLeaf, parent: n.parent} + n.parent.children = append(n.parent.children, next) + + // Split inodes across two nodes. + next.inodes = n.inodes[splitIndex:] + n.inodes = n.inodes[:splitIndex] + + // Update the statistics. + n.bucket.tx.stats.Split++ + + return n, next +} + +// splitIndex finds the position where a page will fill a given threshold. +// It returns the index as well as the size of the first page. +// This is only be called from split(). +func (n *node) splitIndex(threshold int) (index, sz int) { + sz = pageHeaderSize + + // Loop until we only have the minimum number of keys required for the second page. 
+ for i := 0; i < len(n.inodes)-minKeysPerPage; i++ { + index = i + inode := n.inodes[i] + elsize := n.pageElementSize() + len(inode.key) + len(inode.value) + + // If we have at least the minimum number of keys and adding another + // node would put us over the threshold then exit and return. + if i >= minKeysPerPage && sz+elsize > threshold { + break + } + + // Add the element size to the total size. + sz += elsize + } + + return +} + +// spill writes the nodes to dirty pages and splits nodes as it goes. +// Returns an error if dirty pages cannot be allocated. +func (n *node) spill() error { + var tx = n.bucket.tx + if n.spilled { + return nil + } + + // Spill child nodes first. Child nodes can materialize sibling nodes in + // the case of split-merge so we cannot use a range loop. We have to check + // the children size on every loop iteration. + sort.Sort(n.children) + for i := 0; i < len(n.children); i++ { + if err := n.children[i].spill(); err != nil { + return err + } + } + + // We no longer need the child list because it's only used for spill tracking. + n.children = nil + + // Split nodes into appropriate sizes. The first node will always be n. + var nodes = n.split(tx.db.pageSize) + for _, node := range nodes { + // Add node's page to the freelist if it's not new. + if node.pgid > 0 { + tx.db.freelist.free(tx.meta.txid, tx.page(node.pgid)) + node.pgid = 0 + } + + // Allocate contiguous space for the node. + p, err := tx.allocate((node.size() / tx.db.pageSize) + 1) + if err != nil { + return err + } + + // Write the node. + if p.id >= tx.meta.pgid { + panic(fmt.Sprintf("pgid (%d) above high water mark (%d)", p.id, tx.meta.pgid)) + } + node.pgid = p.id + node.write(p) + node.spilled = true + + // Insert into parent inodes. 
+ if node.parent != nil { + var key = node.key + if key == nil { + key = node.inodes[0].key + } + + node.parent.put(key, node.inodes[0].key, nil, node.pgid, 0) + node.key = node.inodes[0].key + _assert(len(node.key) > 0, "spill: zero-length node key") + } + + // Update the statistics. + tx.stats.Spill++ + } + + // If the root node split and created a new root then we need to spill that + // as well. We'll clear out the children to make sure it doesn't try to respill. + if n.parent != nil && n.parent.pgid == 0 { + n.children = nil + return n.parent.spill() + } + + return nil +} + +// rebalance attempts to combine the node with sibling nodes if the node fill +// size is below a threshold or if there are not enough keys. +func (n *node) rebalance() { + if !n.unbalanced { + return + } + n.unbalanced = false + + // Update statistics. + n.bucket.tx.stats.Rebalance++ + + // Ignore if node is above threshold (25%) and has enough keys. + var threshold = n.bucket.tx.db.pageSize / 4 + if n.size() > threshold && len(n.inodes) > n.minKeys() { + return + } + + // Root node has special handling. + if n.parent == nil { + // If root node is a branch and only has one node then collapse it. + if !n.isLeaf && len(n.inodes) == 1 { + // Move root's child up. + child := n.bucket.node(n.inodes[0].pgid, n) + n.isLeaf = child.isLeaf + n.inodes = child.inodes[:] + n.children = child.children + + // Reparent all child nodes being moved. + for _, inode := range n.inodes { + if child, ok := n.bucket.nodes[inode.pgid]; ok { + child.parent = n + } + } + + // Remove old child. + child.parent = nil + delete(n.bucket.nodes, child.pgid) + child.free() + } + + return + } + + // If node has no keys then just remove it. 
+ if n.numChildren() == 0 { + n.parent.del(n.key) + n.parent.removeChild(n) + delete(n.bucket.nodes, n.pgid) + n.free() + n.parent.rebalance() + return + } + + _assert(n.parent.numChildren() > 1, "parent must have at least 2 children") + + // Destination node is right sibling if idx == 0, otherwise left sibling. + var target *node + var useNextSibling = (n.parent.childIndex(n) == 0) + if useNextSibling { + target = n.nextSibling() + } else { + target = n.prevSibling() + } + + // If both this node and the target node are too small then merge them. + if useNextSibling { + // Reparent all child nodes being moved. + for _, inode := range target.inodes { + if child, ok := n.bucket.nodes[inode.pgid]; ok { + child.parent.removeChild(child) + child.parent = n + child.parent.children = append(child.parent.children, child) + } + } + + // Copy over inodes from target and remove target. + n.inodes = append(n.inodes, target.inodes...) + n.parent.del(target.key) + n.parent.removeChild(target) + delete(n.bucket.nodes, target.pgid) + target.free() + } else { + // Reparent all child nodes being moved. + for _, inode := range n.inodes { + if child, ok := n.bucket.nodes[inode.pgid]; ok { + child.parent.removeChild(child) + child.parent = target + child.parent.children = append(child.parent.children, child) + } + } + + // Copy over inodes to target and remove node. + target.inodes = append(target.inodes, n.inodes...) + n.parent.del(n.key) + n.parent.removeChild(n) + delete(n.bucket.nodes, n.pgid) + n.free() + } + + // Either this node or the target node was deleted from the parent so rebalance it. + n.parent.rebalance() +} + +// removes a node from the list of in-memory children. +// This does not affect the inodes. +func (n *node) removeChild(target *node) { + for i, child := range n.children { + if child == target { + n.children = append(n.children[:i], n.children[i+1:]...) 
+ return + } + } +} + +// dereference causes the node to copy all its inode key/value references to heap memory. +// This is required when the mmap is reallocated so inodes are not pointing to stale data. +func (n *node) dereference() { + if n.key != nil { + key := make([]byte, len(n.key)) + copy(key, n.key) + n.key = key + _assert(n.pgid == 0 || len(n.key) > 0, "dereference: zero-length node key on existing node") + } + + for i := range n.inodes { + inode := &n.inodes[i] + + key := make([]byte, len(inode.key)) + copy(key, inode.key) + inode.key = key + _assert(len(inode.key) > 0, "dereference: zero-length inode key") + + value := make([]byte, len(inode.value)) + copy(value, inode.value) + inode.value = value + } + + // Recursively dereference children. + for _, child := range n.children { + child.dereference() + } + + // Update statistics. + n.bucket.tx.stats.NodeDeref++ +} + +// free adds the node's underlying page to the freelist. +func (n *node) free() { + if n.pgid != 0 { + n.bucket.tx.db.freelist.free(n.bucket.tx.meta.txid, n.bucket.tx.page(n.pgid)) + n.pgid = 0 + } +} + +// dump writes the contents of the node to STDERR for debugging purposes. +/* +func (n *node) dump() { + // Write node header. + var typ = "branch" + if n.isLeaf { + typ = "leaf" + } + warnf("[NODE %d {type=%s count=%d}]", n.pgid, typ, len(n.inodes)) + + // Write out abbreviated version of each item. 
+ for _, item := range n.inodes { + if n.isLeaf { + if item.flags&bucketLeafFlag != 0 { + bucket := (*bucket)(unsafe.Pointer(&item.value[0])) + warnf("+L %08x -> (bucket root=%d)", trunc(item.key, 4), bucket.root) + } else { + warnf("+L %08x -> %08x", trunc(item.key, 4), trunc(item.value, 4)) + } + } else { + warnf("+B %08x -> pgid=%d", trunc(item.key, 4), item.pgid) + } + } + warn("") +} +*/ + +type nodes []*node + +func (s nodes) Len() int { return len(s) } +func (s nodes) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s nodes) Less(i, j int) bool { return bytes.Compare(s[i].inodes[0].key, s[j].inodes[0].key) == -1 } + +// inode represents an internal node inside of a node. +// It can be used to point to elements in a page or point +// to an element which hasn't been added to a page yet. +type inode struct { + flags uint32 + pgid pgid + key []byte + value []byte +} + +type inodes []inode diff --git a/vendor/github.com/boltdb/bolt/page.go b/vendor/github.com/boltdb/bolt/page.go new file mode 100644 index 00000000000..cde403ae86d --- /dev/null +++ b/vendor/github.com/boltdb/bolt/page.go @@ -0,0 +1,197 @@ +package bolt + +import ( + "fmt" + "os" + "sort" + "unsafe" +) + +const pageHeaderSize = int(unsafe.Offsetof(((*page)(nil)).ptr)) + +const minKeysPerPage = 2 + +const branchPageElementSize = int(unsafe.Sizeof(branchPageElement{})) +const leafPageElementSize = int(unsafe.Sizeof(leafPageElement{})) + +const ( + branchPageFlag = 0x01 + leafPageFlag = 0x02 + metaPageFlag = 0x04 + freelistPageFlag = 0x10 +) + +const ( + bucketLeafFlag = 0x01 +) + +type pgid uint64 + +type page struct { + id pgid + flags uint16 + count uint16 + overflow uint32 + ptr uintptr +} + +// typ returns a human readable page type string used for debugging. 
+func (p *page) typ() string { + if (p.flags & branchPageFlag) != 0 { + return "branch" + } else if (p.flags & leafPageFlag) != 0 { + return "leaf" + } else if (p.flags & metaPageFlag) != 0 { + return "meta" + } else if (p.flags & freelistPageFlag) != 0 { + return "freelist" + } + return fmt.Sprintf("unknown<%02x>", p.flags) +} + +// meta returns a pointer to the metadata section of the page. +func (p *page) meta() *meta { + return (*meta)(unsafe.Pointer(&p.ptr)) +} + +// leafPageElement retrieves the leaf node by index +func (p *page) leafPageElement(index uint16) *leafPageElement { + n := &((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[index] + return n +} + +// leafPageElements retrieves a list of leaf nodes. +func (p *page) leafPageElements() []leafPageElement { + if p.count == 0 { + return nil + } + return ((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[:] +} + +// branchPageElement retrieves the branch node by index +func (p *page) branchPageElement(index uint16) *branchPageElement { + return &((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[index] +} + +// branchPageElements retrieves a list of branch nodes. +func (p *page) branchPageElements() []branchPageElement { + if p.count == 0 { + return nil + } + return ((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[:] +} + +// dump writes n bytes of the page to STDERR as hex output. +func (p *page) hexdump(n int) { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(p))[:n] + fmt.Fprintf(os.Stderr, "%x\n", buf) +} + +type pages []*page + +func (s pages) Len() int { return len(s) } +func (s pages) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s pages) Less(i, j int) bool { return s[i].id < s[j].id } + +// branchPageElement represents a node on a branch page. +type branchPageElement struct { + pos uint32 + ksize uint32 + pgid pgid +} + +// key returns a byte slice of the node key. 
+func (n *branchPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize] +} + +// leafPageElement represents a node on a leaf page. +type leafPageElement struct { + flags uint32 + pos uint32 + ksize uint32 + vsize uint32 +} + +// key returns a byte slice of the node key. +func (n *leafPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize:n.ksize] +} + +// value returns a byte slice of the node value. +func (n *leafPageElement) value() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos+n.ksize]))[:n.vsize:n.vsize] +} + +// PageInfo represents human readable information about a page. +type PageInfo struct { + ID int + Type string + Count int + OverflowCount int +} + +type pgids []pgid + +func (s pgids) Len() int { return len(s) } +func (s pgids) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s pgids) Less(i, j int) bool { return s[i] < s[j] } + +// merge returns the sorted union of a and b. +func (a pgids) merge(b pgids) pgids { + // Return the opposite slice if one is nil. + if len(a) == 0 { + return b + } + if len(b) == 0 { + return a + } + merged := make(pgids, len(a)+len(b)) + mergepgids(merged, a, b) + return merged +} + +// mergepgids copies the sorted union of a and b into dst. +// If dst is too small, it panics. +func mergepgids(dst, a, b pgids) { + if len(dst) < len(a)+len(b) { + panic(fmt.Errorf("mergepgids bad len %d < %d + %d", len(dst), len(a), len(b))) + } + // Copy in the opposite slice if one is nil. + if len(a) == 0 { + copy(dst, b) + return + } + if len(b) == 0 { + copy(dst, a) + return + } + + // Merged will hold all elements from both lists. + merged := dst[:0] + + // Assign lead to the slice with a lower starting value, follow to the higher value. 
+ lead, follow := a, b + if b[0] < a[0] { + lead, follow = b, a + } + + // Continue while there are elements in the lead. + for len(lead) > 0 { + // Merge largest prefix of lead that is ahead of follow[0]. + n := sort.Search(len(lead), func(i int) bool { return lead[i] > follow[0] }) + merged = append(merged, lead[:n]...) + if n >= len(lead) { + break + } + + // Swap lead and follow. + lead, follow = follow, lead[n:] + } + + // Append what's left in follow. + _ = append(merged, follow...) +} diff --git a/vendor/github.com/boltdb/bolt/tx.go b/vendor/github.com/boltdb/bolt/tx.go new file mode 100644 index 00000000000..6700308a290 --- /dev/null +++ b/vendor/github.com/boltdb/bolt/tx.go @@ -0,0 +1,684 @@ +package bolt + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + "time" + "unsafe" +) + +// txid represents the internal transaction identifier. +type txid uint64 + +// Tx represents a read-only or read/write transaction on the database. +// Read-only transactions can be used for retrieving values for keys and creating cursors. +// Read/write transactions can create and remove buckets and create and remove keys. +// +// IMPORTANT: You must commit or rollback transactions when you are done with +// them. Pages can not be reclaimed by the writer until no more transactions +// are using them. A long running read transaction can cause the database to +// quickly grow. +type Tx struct { + writable bool + managed bool + db *DB + meta *meta + root Bucket + pages map[pgid]*page + stats TxStats + commitHandlers []func() + + // WriteFlag specifies the flag for write-related methods like WriteTo(). + // Tx opens the database file with the specified flag to copy the data. + // + // By default, the flag is unset, which works well for mostly in-memory + // workloads. For databases that are much larger than available RAM, + // set the flag to syscall.O_DIRECT to avoid trashing the page cache. + WriteFlag int +} + +// init initializes the transaction. 
+func (tx *Tx) init(db *DB) { + tx.db = db + tx.pages = nil + + // Copy the meta page since it can be changed by the writer. + tx.meta = &meta{} + db.meta().copy(tx.meta) + + // Copy over the root bucket. + tx.root = newBucket(tx) + tx.root.bucket = &bucket{} + *tx.root.bucket = tx.meta.root + + // Increment the transaction id and add a page cache for writable transactions. + if tx.writable { + tx.pages = make(map[pgid]*page) + tx.meta.txid += txid(1) + } +} + +// ID returns the transaction id. +func (tx *Tx) ID() int { + return int(tx.meta.txid) +} + +// DB returns a reference to the database that created the transaction. +func (tx *Tx) DB() *DB { + return tx.db +} + +// Size returns current database size in bytes as seen by this transaction. +func (tx *Tx) Size() int64 { + return int64(tx.meta.pgid) * int64(tx.db.pageSize) +} + +// Writable returns whether the transaction can perform write operations. +func (tx *Tx) Writable() bool { + return tx.writable +} + +// Cursor creates a cursor associated with the root bucket. +// All items in the cursor will return a nil value because all root bucket keys point to buckets. +// The cursor is only valid as long as the transaction is open. +// Do not use a cursor after the transaction is closed. +func (tx *Tx) Cursor() *Cursor { + return tx.root.Cursor() +} + +// Stats retrieves a copy of the current transaction statistics. +func (tx *Tx) Stats() TxStats { + return tx.stats +} + +// Bucket retrieves a bucket by name. +// Returns nil if the bucket does not exist. +// The bucket instance is only valid for the lifetime of the transaction. +func (tx *Tx) Bucket(name []byte) *Bucket { + return tx.root.Bucket(name) +} + +// CreateBucket creates a new bucket. +// Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. 
+func (tx *Tx) CreateBucket(name []byte) (*Bucket, error) { + return tx.root.CreateBucket(name) +} + +// CreateBucketIfNotExists creates a new bucket if it doesn't already exist. +// Returns an error if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. +func (tx *Tx) CreateBucketIfNotExists(name []byte) (*Bucket, error) { + return tx.root.CreateBucketIfNotExists(name) +} + +// DeleteBucket deletes a bucket. +// Returns an error if the bucket cannot be found or if the key represents a non-bucket value. +func (tx *Tx) DeleteBucket(name []byte) error { + return tx.root.DeleteBucket(name) +} + +// ForEach executes a function for each bucket in the root. +// If the provided function returns an error then the iteration is stopped and +// the error is returned to the caller. +func (tx *Tx) ForEach(fn func(name []byte, b *Bucket) error) error { + return tx.root.ForEach(func(k, v []byte) error { + if err := fn(k, tx.root.Bucket(k)); err != nil { + return err + } + return nil + }) +} + +// OnCommit adds a handler function to be executed after the transaction successfully commits. +func (tx *Tx) OnCommit(fn func()) { + tx.commitHandlers = append(tx.commitHandlers, fn) +} + +// Commit writes all changes to disk and updates the meta page. +// Returns an error if a disk write error occurs, or if Commit is +// called on a read-only transaction. +func (tx *Tx) Commit() error { + _assert(!tx.managed, "managed tx commit not allowed") + if tx.db == nil { + return ErrTxClosed + } else if !tx.writable { + return ErrTxNotWritable + } + + // TODO(benbjohnson): Use vectorized I/O to write out dirty pages. + + // Rebalance nodes which have had deletions. + var startTime = time.Now() + tx.root.rebalance() + if tx.stats.Rebalance > 0 { + tx.stats.RebalanceTime += time.Since(startTime) + } + + // spill data onto dirty pages. 
+ startTime = time.Now() + if err := tx.root.spill(); err != nil { + tx.rollback() + return err + } + tx.stats.SpillTime += time.Since(startTime) + + // Free the old root bucket. + tx.meta.root.root = tx.root.root + + opgid := tx.meta.pgid + + // Free the freelist and allocate new pages for it. This will overestimate + // the size of the freelist but not underestimate the size (which would be bad). + tx.db.freelist.free(tx.meta.txid, tx.db.page(tx.meta.freelist)) + p, err := tx.allocate((tx.db.freelist.size() / tx.db.pageSize) + 1) + if err != nil { + tx.rollback() + return err + } + if err := tx.db.freelist.write(p); err != nil { + tx.rollback() + return err + } + tx.meta.freelist = p.id + + // If the high water mark has moved up then attempt to grow the database. + if tx.meta.pgid > opgid { + if err := tx.db.grow(int(tx.meta.pgid+1) * tx.db.pageSize); err != nil { + tx.rollback() + return err + } + } + + // Write dirty pages to disk. + startTime = time.Now() + if err := tx.write(); err != nil { + tx.rollback() + return err + } + + // If strict mode is enabled then perform a consistency check. + // Only the first consistency error is reported in the panic. + if tx.db.StrictMode { + ch := tx.Check() + var errs []string + for { + err, ok := <-ch + if !ok { + break + } + errs = append(errs, err.Error()) + } + if len(errs) > 0 { + panic("check fail: " + strings.Join(errs, "\n")) + } + } + + // Write meta to disk. + if err := tx.writeMeta(); err != nil { + tx.rollback() + return err + } + tx.stats.WriteTime += time.Since(startTime) + + // Finalize the transaction. + tx.close() + + // Execute commit handlers now that the locks have been removed. + for _, fn := range tx.commitHandlers { + fn() + } + + return nil +} + +// Rollback closes the transaction and ignores all previous updates. Read-only +// transactions must be rolled back and not committed. 
+func (tx *Tx) Rollback() error { + _assert(!tx.managed, "managed tx rollback not allowed") + if tx.db == nil { + return ErrTxClosed + } + tx.rollback() + return nil +} + +func (tx *Tx) rollback() { + if tx.db == nil { + return + } + if tx.writable { + tx.db.freelist.rollback(tx.meta.txid) + tx.db.freelist.reload(tx.db.page(tx.db.meta().freelist)) + } + tx.close() +} + +func (tx *Tx) close() { + if tx.db == nil { + return + } + if tx.writable { + // Grab freelist stats. + var freelistFreeN = tx.db.freelist.free_count() + var freelistPendingN = tx.db.freelist.pending_count() + var freelistAlloc = tx.db.freelist.size() + + // Remove transaction ref & writer lock. + tx.db.rwtx = nil + tx.db.rwlock.Unlock() + + // Merge statistics. + tx.db.statlock.Lock() + tx.db.stats.FreePageN = freelistFreeN + tx.db.stats.PendingPageN = freelistPendingN + tx.db.stats.FreeAlloc = (freelistFreeN + freelistPendingN) * tx.db.pageSize + tx.db.stats.FreelistInuse = freelistAlloc + tx.db.stats.TxStats.add(&tx.stats) + tx.db.statlock.Unlock() + } else { + tx.db.removeTx(tx) + } + + // Clear all references. + tx.db = nil + tx.meta = nil + tx.root = Bucket{tx: tx} + tx.pages = nil +} + +// Copy writes the entire database to a writer. +// This function exists for backwards compatibility. Use WriteTo() instead. +func (tx *Tx) Copy(w io.Writer) error { + _, err := tx.WriteTo(w) + return err +} + +// WriteTo writes the entire database to a writer. +// If err == nil then exactly tx.Size() bytes will be written into the writer. +func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) { + // Attempt to open reader with WriteFlag + f, err := os.OpenFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0) + if err != nil { + return 0, err + } + defer func() { _ = f.Close() }() + + // Generate a meta page. We use the same page data for both meta pages. + buf := make([]byte, tx.db.pageSize) + page := (*page)(unsafe.Pointer(&buf[0])) + page.flags = metaPageFlag + *page.meta() = *tx.meta + + // Write meta 0. 
+ page.id = 0 + page.meta().checksum = page.meta().sum64() + nn, err := w.Write(buf) + n += int64(nn) + if err != nil { + return n, fmt.Errorf("meta 0 copy: %s", err) + } + + // Write meta 1 with a lower transaction id. + page.id = 1 + page.meta().txid -= 1 + page.meta().checksum = page.meta().sum64() + nn, err = w.Write(buf) + n += int64(nn) + if err != nil { + return n, fmt.Errorf("meta 1 copy: %s", err) + } + + // Move past the meta pages in the file. + if _, err := f.Seek(int64(tx.db.pageSize*2), os.SEEK_SET); err != nil { + return n, fmt.Errorf("seek: %s", err) + } + + // Copy data pages. + wn, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2)) + n += wn + if err != nil { + return n, err + } + + return n, f.Close() +} + +// CopyFile copies the entire database to file at the given path. +// A reader transaction is maintained during the copy so it is safe to continue +// using the database while a copy is in progress. +func (tx *Tx) CopyFile(path string, mode os.FileMode) error { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) + if err != nil { + return err + } + + err = tx.Copy(f) + if err != nil { + _ = f.Close() + return err + } + return f.Close() +} + +// Check performs several consistency checks on the database for this transaction. +// An error is returned if any inconsistency is found. +// +// It can be safely run concurrently on a writable transaction. However, this +// incurs a high cost for large databases and databases with a lot of subbuckets +// because of caching. This overhead can be removed if running on a read-only +// transaction, however, it is not safe to execute other writer transactions at +// the same time. +func (tx *Tx) Check() <-chan error { + ch := make(chan error) + go tx.check(ch) + return ch +} + +func (tx *Tx) check(ch chan error) { + // Check if any pages are double freed. 
+ freed := make(map[pgid]bool) + all := make([]pgid, tx.db.freelist.count()) + tx.db.freelist.copyall(all) + for _, id := range all { + if freed[id] { + ch <- fmt.Errorf("page %d: already freed", id) + } + freed[id] = true + } + + // Track every reachable page. + reachable := make(map[pgid]*page) + reachable[0] = tx.page(0) // meta0 + reachable[1] = tx.page(1) // meta1 + for i := uint32(0); i <= tx.page(tx.meta.freelist).overflow; i++ { + reachable[tx.meta.freelist+pgid(i)] = tx.page(tx.meta.freelist) + } + + // Recursively check buckets. + tx.checkBucket(&tx.root, reachable, freed, ch) + + // Ensure all pages below high water mark are either reachable or freed. + for i := pgid(0); i < tx.meta.pgid; i++ { + _, isReachable := reachable[i] + if !isReachable && !freed[i] { + ch <- fmt.Errorf("page %d: unreachable unfreed", int(i)) + } + } + + // Close the channel to signal completion. + close(ch) +} + +func (tx *Tx) checkBucket(b *Bucket, reachable map[pgid]*page, freed map[pgid]bool, ch chan error) { + // Ignore inline buckets. + if b.root == 0 { + return + } + + // Check every page used by this bucket. + b.tx.forEachPage(b.root, 0, func(p *page, _ int) { + if p.id > tx.meta.pgid { + ch <- fmt.Errorf("page %d: out of bounds: %d", int(p.id), int(b.tx.meta.pgid)) + } + + // Ensure each page is only referenced once. + for i := pgid(0); i <= pgid(p.overflow); i++ { + var id = p.id + i + if _, ok := reachable[id]; ok { + ch <- fmt.Errorf("page %d: multiple references", int(id)) + } + reachable[id] = p + } + + // We should only encounter un-freed leaf and branch pages. + if freed[p.id] { + ch <- fmt.Errorf("page %d: reachable freed", int(p.id)) + } else if (p.flags&branchPageFlag) == 0 && (p.flags&leafPageFlag) == 0 { + ch <- fmt.Errorf("page %d: invalid type: %s", int(p.id), p.typ()) + } + }) + + // Check each bucket within this bucket. 
+ _ = b.ForEach(func(k, v []byte) error { + if child := b.Bucket(k); child != nil { + tx.checkBucket(child, reachable, freed, ch) + } + return nil + }) +} + +// allocate returns a contiguous block of memory starting at a given page. +func (tx *Tx) allocate(count int) (*page, error) { + p, err := tx.db.allocate(count) + if err != nil { + return nil, err + } + + // Save to our page cache. + tx.pages[p.id] = p + + // Update statistics. + tx.stats.PageCount++ + tx.stats.PageAlloc += count * tx.db.pageSize + + return p, nil +} + +// write writes any dirty pages to disk. +func (tx *Tx) write() error { + // Sort pages by id. + pages := make(pages, 0, len(tx.pages)) + for _, p := range tx.pages { + pages = append(pages, p) + } + // Clear out page cache early. + tx.pages = make(map[pgid]*page) + sort.Sort(pages) + + // Write pages to disk in order. + for _, p := range pages { + size := (int(p.overflow) + 1) * tx.db.pageSize + offset := int64(p.id) * int64(tx.db.pageSize) + + // Write out page in "max allocation" sized chunks. + ptr := (*[maxAllocSize]byte)(unsafe.Pointer(p)) + for { + // Limit our write to our max allocation size. + sz := size + if sz > maxAllocSize-1 { + sz = maxAllocSize - 1 + } + + // Write chunk to disk. + buf := ptr[:sz] + if _, err := tx.db.ops.writeAt(buf, offset); err != nil { + return err + } + + // Update statistics. + tx.stats.Write++ + + // Exit inner for loop if we've written all the chunks. + size -= sz + if size == 0 { + break + } + + // Otherwise move offset forward and move pointer to next chunk. + offset += int64(sz) + ptr = (*[maxAllocSize]byte)(unsafe.Pointer(&ptr[sz])) + } + } + + // Ignore file sync if flag is set on DB. + if !tx.db.NoSync || IgnoreNoSync { + if err := fdatasync(tx.db); err != nil { + return err + } + } + + // Put small pages back to page pool. + for _, p := range pages { + // Ignore page sizes over 1 page. + // These are allocated using make() instead of the page pool. 
+ if int(p.overflow) != 0 { + continue + } + + buf := (*[maxAllocSize]byte)(unsafe.Pointer(p))[:tx.db.pageSize] + + // See https://go.googlesource.com/go/+/f03c9202c43e0abb130669852082117ca50aa9b1 + for i := range buf { + buf[i] = 0 + } + tx.db.pagePool.Put(buf) + } + + return nil +} + +// writeMeta writes the meta to the disk. +func (tx *Tx) writeMeta() error { + // Create a temporary buffer for the meta page. + buf := make([]byte, tx.db.pageSize) + p := tx.db.pageInBuffer(buf, 0) + tx.meta.write(p) + + // Write the meta page to file. + if _, err := tx.db.ops.writeAt(buf, int64(p.id)*int64(tx.db.pageSize)); err != nil { + return err + } + if !tx.db.NoSync || IgnoreNoSync { + if err := fdatasync(tx.db); err != nil { + return err + } + } + + // Update statistics. + tx.stats.Write++ + + return nil +} + +// page returns a reference to the page with a given id. +// If page has been written to then a temporary buffered page is returned. +func (tx *Tx) page(id pgid) *page { + // Check the dirty pages first. + if tx.pages != nil { + if p, ok := tx.pages[id]; ok { + return p + } + } + + // Otherwise return directly from the mmap. + return tx.db.page(id) +} + +// forEachPage iterates over every page within a given page and executes a function. +func (tx *Tx) forEachPage(pgid pgid, depth int, fn func(*page, int)) { + p := tx.page(pgid) + + // Execute function. + fn(p, depth) + + // Recursively loop over children. + if (p.flags & branchPageFlag) != 0 { + for i := 0; i < int(p.count); i++ { + elem := p.branchPageElement(uint16(i)) + tx.forEachPage(elem.pgid, depth+1, fn) + } + } +} + +// Page returns page information for a given page number. +// This is only safe for concurrent use when used by a writable transaction. +func (tx *Tx) Page(id int) (*PageInfo, error) { + if tx.db == nil { + return nil, ErrTxClosed + } else if pgid(id) >= tx.meta.pgid { + return nil, nil + } + + // Build the page info. 
+ p := tx.db.page(pgid(id)) + info := &PageInfo{ + ID: id, + Count: int(p.count), + OverflowCount: int(p.overflow), + } + + // Determine the type (or if it's free). + if tx.db.freelist.freed(pgid(id)) { + info.Type = "free" + } else { + info.Type = p.typ() + } + + return info, nil +} + +// TxStats represents statistics about the actions performed by the transaction. +type TxStats struct { + // Page statistics. + PageCount int // number of page allocations + PageAlloc int // total bytes allocated + + // Cursor statistics. + CursorCount int // number of cursors created + + // Node statistics + NodeCount int // number of node allocations + NodeDeref int // number of node dereferences + + // Rebalance statistics. + Rebalance int // number of node rebalances + RebalanceTime time.Duration // total time spent rebalancing + + // Split/Spill statistics. + Split int // number of nodes split + Spill int // number of nodes spilled + SpillTime time.Duration // total time spent spilling + + // Write statistics. + Write int // number of writes performed + WriteTime time.Duration // total time spent writing to disk +} + +func (s *TxStats) add(other *TxStats) { + s.PageCount += other.PageCount + s.PageAlloc += other.PageAlloc + s.CursorCount += other.CursorCount + s.NodeCount += other.NodeCount + s.NodeDeref += other.NodeDeref + s.Rebalance += other.Rebalance + s.RebalanceTime += other.RebalanceTime + s.Split += other.Split + s.Spill += other.Spill + s.SpillTime += other.SpillTime + s.Write += other.Write + s.WriteTime += other.WriteTime +} + +// Sub calculates and returns the difference between two sets of transaction stats. +// This is useful when obtaining stats at two different points and time and +// you need the performance counters that occurred within that time span. 
+func (s *TxStats) Sub(other *TxStats) TxStats { + var diff TxStats + diff.PageCount = s.PageCount - other.PageCount + diff.PageAlloc = s.PageAlloc - other.PageAlloc + diff.CursorCount = s.CursorCount - other.CursorCount + diff.NodeCount = s.NodeCount - other.NodeCount + diff.NodeDeref = s.NodeDeref - other.NodeDeref + diff.Rebalance = s.Rebalance - other.Rebalance + diff.RebalanceTime = s.RebalanceTime - other.RebalanceTime + diff.Split = s.Split - other.Split + diff.Spill = s.Spill - other.Spill + diff.SpillTime = s.SpillTime - other.SpillTime + diff.Write = s.Write - other.Write + diff.WriteTime = s.WriteTime - other.WriteTime + return diff +} diff --git a/vendor/github.com/gogo/protobuf/gogoproto/doc.go b/vendor/github.com/gogo/protobuf/gogoproto/doc.go new file mode 100644 index 00000000000..147b5ecc62f --- /dev/null +++ b/vendor/github.com/gogo/protobuf/gogoproto/doc.go @@ -0,0 +1,169 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +Package gogoproto provides extensions for protocol buffers to achieve: + + - fast marshalling and unmarshalling. + - peace of mind by optionally generating test and benchmark code. + - more canonical Go structures. + - less typing by optionally generating extra helper code. + - goprotobuf compatibility + +More Canonical Go Structures + +A lot of time working with a goprotobuf struct will lead you to a place where you create another struct that is easier to work with and then have a function to copy the values between the two structs. +You might also find that basic structs that started their life as part of an API need to be sent over the wire. With gob, you could just send it. With goprotobuf, you need to make a parallel struct. +Gogoprotobuf tries to fix these problems with the nullable, embed, customtype and customname field extensions. + + - nullable, if false, a field is generated without a pointer (see warning below). + - embed, if true, the field is generated as an embedded field. + - customtype, It works with the Marshal and Unmarshal methods, to allow you to have your own types in your struct, but marshal to bytes. For example, custom.Uuid or custom.Fixed128 + - customname (beta), Changes the generated fieldname. This is especially useful when generated methods conflict with fieldnames. + - casttype (beta), Changes the generated fieldtype. 
All generated code assumes that this type is castable to the protocol buffer field type. It does not work for structs or enums. + - castkey (beta), Changes the generated fieldtype for a map key. All generated code assumes that this type is castable to the protocol buffer field type. Only supported on maps. + - castvalue (beta), Changes the generated fieldtype for a map value. All generated code assumes that this type is castable to the protocol buffer field type. Only supported on maps. + +Warning about nullable: According to the Protocol Buffer specification, you should be able to tell whether a field is set or unset. With the option nullable=false this feature is lost, since your non-nullable fields will always be set. It can be seen as a layer on top of Protocol Buffers, where before and after marshalling all non-nullable fields are set and they cannot be unset. + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +for a quicker overview. + +The following message: + + package test; + + import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + + message A { + optional string Description = 1 [(gogoproto.nullable) = false]; + optional int64 Number = 2 [(gogoproto.nullable) = false]; + optional bytes Id = 3 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uuid", (gogoproto.nullable) = false]; + } + +Will generate a go struct which looks a lot like this: + + type A struct { + Description string + Number int64 + Id github_com_gogo_protobuf_test_custom.Uuid + } + +You will see there are no pointers, since all fields are non-nullable. +You will also see a custom type which marshals to a string. +Be warned it is your responsibility to test your custom types thoroughly. +You should think of every possible empty and nil case for your marshaling, unmarshaling and size methods. + +Next we will embed the message A in message B. 
+ + message B { + optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true]; + repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false]; + } + +See below that A is embedded in B. + + type B struct { + A + G []github_com_gogo_protobuf_test_custom.Uint128 + } + +Also see the repeated custom type. + + type Uint128 [2]uint64 + +Next we will create a custom name for one of our fields. + + message C { + optional int64 size = 1 [(gogoproto.customname) = "MySize"]; + } + +See below that the field's name is MySize and not Size. + + type C struct { + MySize *int64 + } + +The is useful when having a protocol buffer message with a field name which conflicts with a generated method. +As an example, having a field name size and using the sizer plugin to generate a Size method will cause a go compiler error. +Using customname you can fix this error without changing the field name. +This is typically useful when working with a protocol buffer that was designed before these methods and/or the go language were avialable. + +Gogoprotobuf also has some more subtle changes, these could be changed back: + + - the generated package name for imports do not have the extra /filename.pb, + but are actually the imports specified in the .proto file. + +Gogoprotobuf also has lost some features which should be brought back with time: + + - Marshalling and unmarshalling with reflect and without the unsafe package, + this requires work in pointer_reflect.go + +Why does nullable break protocol buffer specifications: + +The protocol buffer specification states, somewhere, that you should be able to tell whether a +field is set or unset. With the option nullable=false this feature is lost, +since your non-nullable fields will always be set. It can be seen as a layer on top of +protocol buffers, where before and after marshalling all non-nullable fields are set +and they cannot be unset. 
+ +Goprotobuf Compatibility: + +Gogoprotobuf is compatible with Goprotobuf, because it is compatible with protocol buffers. +Gogoprotobuf generates the same code as goprotobuf if no extensions are used. +The enumprefix, getters and stringer extensions can be used to remove some of the unnecessary code generated by goprotobuf: + + - gogoproto_import, if false, the generated code imports github.com/golang/protobuf/proto instead of github.com/gogo/protobuf/proto. + - goproto_enum_prefix, if false, generates the enum constant names without the messagetype prefix + - goproto_enum_stringer (experimental), if false, the enum is generated without the default string method, this is useful for rather using enum_stringer, or allowing you to write your own string method. + - goproto_getters, if false, the message is generated without get methods, this is useful when you would rather want to use face + - goproto_stringer, if false, the message is generated without the default string method, this is useful for rather using stringer, or allowing you to write your own string method. + - goproto_extensions_map (beta), if false, the extensions field is generated as type []byte instead of type map[int32]proto.Extension + - goproto_unrecognized (beta), if false, XXX_unrecognized field is not generated. This is useful in conjunction with gogoproto.nullable=false, to generate structures completely devoid of pointers and reduce GC pressure at the cost of losing information about unrecognized fields. + - goproto_registration (beta), if true, the generated files will register all messages and types against both gogo/protobuf and golang/protobuf. This is necessary when using third-party packages which read registrations from golang/protobuf (such as the grpc-gateway). 
+ +Less Typing and Peace of Mind is explained in their specific plugin folders godoc: + + - github.com/gogo/protobuf/plugin/ + +If you do not use any of these extension the code that is generated +will be the same as if goprotobuf has generated it. + +The most complete way to see examples is to look at + + github.com/gogo/protobuf/test/thetest.proto + +Gogoprototest is a seperate project, +because we want to keep gogoprotobuf independant of goprotobuf, +but we still want to test it thoroughly. + +*/ +package gogoproto diff --git a/vendor/github.com/gogo/protobuf/gogoproto/gogo.pb.go b/vendor/github.com/gogo/protobuf/gogoproto/gogo.pb.go new file mode 100644 index 00000000000..5765acb1530 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/gogoproto/gogo.pb.go @@ -0,0 +1,804 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: gogo.proto + +/* +Package gogoproto is a generated protocol buffer package. + +It is generated from these files: + gogo.proto + +It has these top-level messages: +*/ +package gogoproto + +import proto "github.com/gogo/protobuf/proto" +import fmt "fmt" +import math "math" +import google_protobuf "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +var E_GoprotoEnumPrefix = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.EnumOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 62001, + Name: "gogoproto.goproto_enum_prefix", + Tag: "varint,62001,opt,name=goproto_enum_prefix,json=goprotoEnumPrefix", + Filename: "gogo.proto", +} + +var E_GoprotoEnumStringer = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.EnumOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 62021, + Name: "gogoproto.goproto_enum_stringer", + Tag: "varint,62021,opt,name=goproto_enum_stringer,json=goprotoEnumStringer", + Filename: "gogo.proto", +} + +var E_EnumStringer = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.EnumOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 62022, + Name: "gogoproto.enum_stringer", + Tag: "varint,62022,opt,name=enum_stringer,json=enumStringer", + Filename: "gogo.proto", +} + +var E_EnumCustomname = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.EnumOptions)(nil), + ExtensionType: (*string)(nil), + Field: 62023, + Name: "gogoproto.enum_customname", + Tag: "bytes,62023,opt,name=enum_customname,json=enumCustomname", + Filename: "gogo.proto", +} + +var E_Enumdecl = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.EnumOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 62024, + Name: "gogoproto.enumdecl", + Tag: "varint,62024,opt,name=enumdecl", + Filename: "gogo.proto", +} + +var E_EnumvalueCustomname = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.EnumValueOptions)(nil), + ExtensionType: (*string)(nil), + Field: 66001, + Name: "gogoproto.enumvalue_customname", + Tag: "bytes,66001,opt,name=enumvalue_customname,json=enumvalueCustomname", + Filename: "gogo.proto", +} + +var E_GoprotoGettersAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63001, + Name: "gogoproto.goproto_getters_all", + Tag: 
"varint,63001,opt,name=goproto_getters_all,json=goprotoGettersAll", + Filename: "gogo.proto", +} + +var E_GoprotoEnumPrefixAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63002, + Name: "gogoproto.goproto_enum_prefix_all", + Tag: "varint,63002,opt,name=goproto_enum_prefix_all,json=goprotoEnumPrefixAll", + Filename: "gogo.proto", +} + +var E_GoprotoStringerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63003, + Name: "gogoproto.goproto_stringer_all", + Tag: "varint,63003,opt,name=goproto_stringer_all,json=goprotoStringerAll", + Filename: "gogo.proto", +} + +var E_VerboseEqualAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63004, + Name: "gogoproto.verbose_equal_all", + Tag: "varint,63004,opt,name=verbose_equal_all,json=verboseEqualAll", + Filename: "gogo.proto", +} + +var E_FaceAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63005, + Name: "gogoproto.face_all", + Tag: "varint,63005,opt,name=face_all,json=faceAll", + Filename: "gogo.proto", +} + +var E_GostringAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63006, + Name: "gogoproto.gostring_all", + Tag: "varint,63006,opt,name=gostring_all,json=gostringAll", + Filename: "gogo.proto", +} + +var E_PopulateAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63007, + Name: "gogoproto.populate_all", + Tag: "varint,63007,opt,name=populate_all,json=populateAll", + Filename: "gogo.proto", +} + +var E_StringerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63008, + Name: "gogoproto.stringer_all", + Tag: 
"varint,63008,opt,name=stringer_all,json=stringerAll", + Filename: "gogo.proto", +} + +var E_OnlyoneAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63009, + Name: "gogoproto.onlyone_all", + Tag: "varint,63009,opt,name=onlyone_all,json=onlyoneAll", + Filename: "gogo.proto", +} + +var E_EqualAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63013, + Name: "gogoproto.equal_all", + Tag: "varint,63013,opt,name=equal_all,json=equalAll", + Filename: "gogo.proto", +} + +var E_DescriptionAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63014, + Name: "gogoproto.description_all", + Tag: "varint,63014,opt,name=description_all,json=descriptionAll", + Filename: "gogo.proto", +} + +var E_TestgenAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63015, + Name: "gogoproto.testgen_all", + Tag: "varint,63015,opt,name=testgen_all,json=testgenAll", + Filename: "gogo.proto", +} + +var E_BenchgenAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63016, + Name: "gogoproto.benchgen_all", + Tag: "varint,63016,opt,name=benchgen_all,json=benchgenAll", + Filename: "gogo.proto", +} + +var E_MarshalerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63017, + Name: "gogoproto.marshaler_all", + Tag: "varint,63017,opt,name=marshaler_all,json=marshalerAll", + Filename: "gogo.proto", +} + +var E_UnmarshalerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63018, + Name: "gogoproto.unmarshaler_all", + Tag: "varint,63018,opt,name=unmarshaler_all,json=unmarshalerAll", + Filename: "gogo.proto", +} + +var 
E_StableMarshalerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63019, + Name: "gogoproto.stable_marshaler_all", + Tag: "varint,63019,opt,name=stable_marshaler_all,json=stableMarshalerAll", + Filename: "gogo.proto", +} + +var E_SizerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63020, + Name: "gogoproto.sizer_all", + Tag: "varint,63020,opt,name=sizer_all,json=sizerAll", + Filename: "gogo.proto", +} + +var E_GoprotoEnumStringerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63021, + Name: "gogoproto.goproto_enum_stringer_all", + Tag: "varint,63021,opt,name=goproto_enum_stringer_all,json=goprotoEnumStringerAll", + Filename: "gogo.proto", +} + +var E_EnumStringerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63022, + Name: "gogoproto.enum_stringer_all", + Tag: "varint,63022,opt,name=enum_stringer_all,json=enumStringerAll", + Filename: "gogo.proto", +} + +var E_UnsafeMarshalerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63023, + Name: "gogoproto.unsafe_marshaler_all", + Tag: "varint,63023,opt,name=unsafe_marshaler_all,json=unsafeMarshalerAll", + Filename: "gogo.proto", +} + +var E_UnsafeUnmarshalerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63024, + Name: "gogoproto.unsafe_unmarshaler_all", + Tag: "varint,63024,opt,name=unsafe_unmarshaler_all,json=unsafeUnmarshalerAll", + Filename: "gogo.proto", +} + +var E_GoprotoExtensionsMapAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63025, + Name: "gogoproto.goproto_extensions_map_all", + Tag: 
"varint,63025,opt,name=goproto_extensions_map_all,json=goprotoExtensionsMapAll", + Filename: "gogo.proto", +} + +var E_GoprotoUnrecognizedAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63026, + Name: "gogoproto.goproto_unrecognized_all", + Tag: "varint,63026,opt,name=goproto_unrecognized_all,json=goprotoUnrecognizedAll", + Filename: "gogo.proto", +} + +var E_GogoprotoImport = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63027, + Name: "gogoproto.gogoproto_import", + Tag: "varint,63027,opt,name=gogoproto_import,json=gogoprotoImport", + Filename: "gogo.proto", +} + +var E_ProtosizerAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63028, + Name: "gogoproto.protosizer_all", + Tag: "varint,63028,opt,name=protosizer_all,json=protosizerAll", + Filename: "gogo.proto", +} + +var E_CompareAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63029, + Name: "gogoproto.compare_all", + Tag: "varint,63029,opt,name=compare_all,json=compareAll", + Filename: "gogo.proto", +} + +var E_TypedeclAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63030, + Name: "gogoproto.typedecl_all", + Tag: "varint,63030,opt,name=typedecl_all,json=typedeclAll", + Filename: "gogo.proto", +} + +var E_EnumdeclAll = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63031, + Name: "gogoproto.enumdecl_all", + Tag: "varint,63031,opt,name=enumdecl_all,json=enumdeclAll", + Filename: "gogo.proto", +} + +var E_GoprotoRegistration = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FileOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 63032, + Name: "gogoproto.goproto_registration", + Tag: 
"varint,63032,opt,name=goproto_registration,json=goprotoRegistration", + Filename: "gogo.proto", +} + +var E_GoprotoGetters = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64001, + Name: "gogoproto.goproto_getters", + Tag: "varint,64001,opt,name=goproto_getters,json=goprotoGetters", + Filename: "gogo.proto", +} + +var E_GoprotoStringer = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64003, + Name: "gogoproto.goproto_stringer", + Tag: "varint,64003,opt,name=goproto_stringer,json=goprotoStringer", + Filename: "gogo.proto", +} + +var E_VerboseEqual = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64004, + Name: "gogoproto.verbose_equal", + Tag: "varint,64004,opt,name=verbose_equal,json=verboseEqual", + Filename: "gogo.proto", +} + +var E_Face = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64005, + Name: "gogoproto.face", + Tag: "varint,64005,opt,name=face", + Filename: "gogo.proto", +} + +var E_Gostring = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64006, + Name: "gogoproto.gostring", + Tag: "varint,64006,opt,name=gostring", + Filename: "gogo.proto", +} + +var E_Populate = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64007, + Name: "gogoproto.populate", + Tag: "varint,64007,opt,name=populate", + Filename: "gogo.proto", +} + +var E_Stringer = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 67008, + Name: "gogoproto.stringer", + Tag: "varint,67008,opt,name=stringer", + Filename: "gogo.proto", +} + +var E_Onlyone = &proto.ExtensionDesc{ + ExtendedType: 
(*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64009, + Name: "gogoproto.onlyone", + Tag: "varint,64009,opt,name=onlyone", + Filename: "gogo.proto", +} + +var E_Equal = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64013, + Name: "gogoproto.equal", + Tag: "varint,64013,opt,name=equal", + Filename: "gogo.proto", +} + +var E_Description = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64014, + Name: "gogoproto.description", + Tag: "varint,64014,opt,name=description", + Filename: "gogo.proto", +} + +var E_Testgen = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64015, + Name: "gogoproto.testgen", + Tag: "varint,64015,opt,name=testgen", + Filename: "gogo.proto", +} + +var E_Benchgen = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64016, + Name: "gogoproto.benchgen", + Tag: "varint,64016,opt,name=benchgen", + Filename: "gogo.proto", +} + +var E_Marshaler = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64017, + Name: "gogoproto.marshaler", + Tag: "varint,64017,opt,name=marshaler", + Filename: "gogo.proto", +} + +var E_Unmarshaler = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64018, + Name: "gogoproto.unmarshaler", + Tag: "varint,64018,opt,name=unmarshaler", + Filename: "gogo.proto", +} + +var E_StableMarshaler = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64019, + Name: "gogoproto.stable_marshaler", + Tag: "varint,64019,opt,name=stable_marshaler,json=stableMarshaler", + Filename: "gogo.proto", +} + +var E_Sizer = &proto.ExtensionDesc{ + 
ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64020, + Name: "gogoproto.sizer", + Tag: "varint,64020,opt,name=sizer", + Filename: "gogo.proto", +} + +var E_UnsafeMarshaler = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64023, + Name: "gogoproto.unsafe_marshaler", + Tag: "varint,64023,opt,name=unsafe_marshaler,json=unsafeMarshaler", + Filename: "gogo.proto", +} + +var E_UnsafeUnmarshaler = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64024, + Name: "gogoproto.unsafe_unmarshaler", + Tag: "varint,64024,opt,name=unsafe_unmarshaler,json=unsafeUnmarshaler", + Filename: "gogo.proto", +} + +var E_GoprotoExtensionsMap = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64025, + Name: "gogoproto.goproto_extensions_map", + Tag: "varint,64025,opt,name=goproto_extensions_map,json=goprotoExtensionsMap", + Filename: "gogo.proto", +} + +var E_GoprotoUnrecognized = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64026, + Name: "gogoproto.goproto_unrecognized", + Tag: "varint,64026,opt,name=goproto_unrecognized,json=goprotoUnrecognized", + Filename: "gogo.proto", +} + +var E_Protosizer = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64028, + Name: "gogoproto.protosizer", + Tag: "varint,64028,opt,name=protosizer", + Filename: "gogo.proto", +} + +var E_Compare = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 64029, + Name: "gogoproto.compare", + Tag: "varint,64029,opt,name=compare", + Filename: "gogo.proto", +} + +var E_Typedecl = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.MessageOptions)(nil), + 
ExtensionType: (*bool)(nil), + Field: 64030, + Name: "gogoproto.typedecl", + Tag: "varint,64030,opt,name=typedecl", + Filename: "gogo.proto", +} + +var E_Nullable = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 65001, + Name: "gogoproto.nullable", + Tag: "varint,65001,opt,name=nullable", + Filename: "gogo.proto", +} + +var E_Embed = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 65002, + Name: "gogoproto.embed", + Tag: "varint,65002,opt,name=embed", + Filename: "gogo.proto", +} + +var E_Customtype = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 65003, + Name: "gogoproto.customtype", + Tag: "bytes,65003,opt,name=customtype", + Filename: "gogo.proto", +} + +var E_Customname = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 65004, + Name: "gogoproto.customname", + Tag: "bytes,65004,opt,name=customname", + Filename: "gogo.proto", +} + +var E_Jsontag = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 65005, + Name: "gogoproto.jsontag", + Tag: "bytes,65005,opt,name=jsontag", + Filename: "gogo.proto", +} + +var E_Moretags = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 65006, + Name: "gogoproto.moretags", + Tag: "bytes,65006,opt,name=moretags", + Filename: "gogo.proto", +} + +var E_Casttype = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 65007, + Name: "gogoproto.casttype", + Tag: "bytes,65007,opt,name=casttype", + Filename: "gogo.proto", +} + +var E_Castkey = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 65008, + 
Name: "gogoproto.castkey", + Tag: "bytes,65008,opt,name=castkey", + Filename: "gogo.proto", +} + +var E_Castvalue = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*string)(nil), + Field: 65009, + Name: "gogoproto.castvalue", + Tag: "bytes,65009,opt,name=castvalue", + Filename: "gogo.proto", +} + +var E_Stdtime = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 65010, + Name: "gogoproto.stdtime", + Tag: "varint,65010,opt,name=stdtime", + Filename: "gogo.proto", +} + +var E_Stdduration = &proto.ExtensionDesc{ + ExtendedType: (*google_protobuf.FieldOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 65011, + Name: "gogoproto.stdduration", + Tag: "varint,65011,opt,name=stdduration", + Filename: "gogo.proto", +} + +func init() { + proto.RegisterExtension(E_GoprotoEnumPrefix) + proto.RegisterExtension(E_GoprotoEnumStringer) + proto.RegisterExtension(E_EnumStringer) + proto.RegisterExtension(E_EnumCustomname) + proto.RegisterExtension(E_Enumdecl) + proto.RegisterExtension(E_EnumvalueCustomname) + proto.RegisterExtension(E_GoprotoGettersAll) + proto.RegisterExtension(E_GoprotoEnumPrefixAll) + proto.RegisterExtension(E_GoprotoStringerAll) + proto.RegisterExtension(E_VerboseEqualAll) + proto.RegisterExtension(E_FaceAll) + proto.RegisterExtension(E_GostringAll) + proto.RegisterExtension(E_PopulateAll) + proto.RegisterExtension(E_StringerAll) + proto.RegisterExtension(E_OnlyoneAll) + proto.RegisterExtension(E_EqualAll) + proto.RegisterExtension(E_DescriptionAll) + proto.RegisterExtension(E_TestgenAll) + proto.RegisterExtension(E_BenchgenAll) + proto.RegisterExtension(E_MarshalerAll) + proto.RegisterExtension(E_UnmarshalerAll) + proto.RegisterExtension(E_StableMarshalerAll) + proto.RegisterExtension(E_SizerAll) + proto.RegisterExtension(E_GoprotoEnumStringerAll) + proto.RegisterExtension(E_EnumStringerAll) + proto.RegisterExtension(E_UnsafeMarshalerAll) + 
proto.RegisterExtension(E_UnsafeUnmarshalerAll) + proto.RegisterExtension(E_GoprotoExtensionsMapAll) + proto.RegisterExtension(E_GoprotoUnrecognizedAll) + proto.RegisterExtension(E_GogoprotoImport) + proto.RegisterExtension(E_ProtosizerAll) + proto.RegisterExtension(E_CompareAll) + proto.RegisterExtension(E_TypedeclAll) + proto.RegisterExtension(E_EnumdeclAll) + proto.RegisterExtension(E_GoprotoRegistration) + proto.RegisterExtension(E_GoprotoGetters) + proto.RegisterExtension(E_GoprotoStringer) + proto.RegisterExtension(E_VerboseEqual) + proto.RegisterExtension(E_Face) + proto.RegisterExtension(E_Gostring) + proto.RegisterExtension(E_Populate) + proto.RegisterExtension(E_Stringer) + proto.RegisterExtension(E_Onlyone) + proto.RegisterExtension(E_Equal) + proto.RegisterExtension(E_Description) + proto.RegisterExtension(E_Testgen) + proto.RegisterExtension(E_Benchgen) + proto.RegisterExtension(E_Marshaler) + proto.RegisterExtension(E_Unmarshaler) + proto.RegisterExtension(E_StableMarshaler) + proto.RegisterExtension(E_Sizer) + proto.RegisterExtension(E_UnsafeMarshaler) + proto.RegisterExtension(E_UnsafeUnmarshaler) + proto.RegisterExtension(E_GoprotoExtensionsMap) + proto.RegisterExtension(E_GoprotoUnrecognized) + proto.RegisterExtension(E_Protosizer) + proto.RegisterExtension(E_Compare) + proto.RegisterExtension(E_Typedecl) + proto.RegisterExtension(E_Nullable) + proto.RegisterExtension(E_Embed) + proto.RegisterExtension(E_Customtype) + proto.RegisterExtension(E_Customname) + proto.RegisterExtension(E_Jsontag) + proto.RegisterExtension(E_Moretags) + proto.RegisterExtension(E_Casttype) + proto.RegisterExtension(E_Castkey) + proto.RegisterExtension(E_Castvalue) + proto.RegisterExtension(E_Stdtime) + proto.RegisterExtension(E_Stdduration) +} + +func init() { proto.RegisterFile("gogo.proto", fileDescriptorGogo) } + +var fileDescriptorGogo = []byte{ + // 1220 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x98, 
0x4b, 0x6f, 0x1c, 0x45, + 0x10, 0x80, 0x85, 0x48, 0x14, 0x6f, 0xd9, 0x8e, 0xf1, 0xda, 0x98, 0x10, 0x81, 0x08, 0x9c, 0x38, + 0xd9, 0xa7, 0x08, 0xa5, 0xad, 0xc8, 0x72, 0x2c, 0xc7, 0x4a, 0x84, 0xc1, 0x98, 0x38, 0xbc, 0x0e, + 0xab, 0xd9, 0xdd, 0xf6, 0x78, 0x60, 0x66, 0x7a, 0x98, 0xe9, 0x89, 0xe2, 0xdc, 0x50, 0x78, 0x08, + 0x21, 0xde, 0x48, 0x90, 0x90, 0x04, 0x38, 0xf0, 0x7e, 0x86, 0xf7, 0x91, 0x0b, 0x8f, 0x2b, 0xff, + 0x81, 0x0b, 0x60, 0xde, 0xbe, 0xf9, 0x82, 0x6a, 0xb6, 0x6a, 0xb6, 0x67, 0xbd, 0x52, 0xf7, 0xde, + 0xc6, 0xeb, 0xfe, 0xbe, 0xad, 0xa9, 0x9a, 0xae, 0xea, 0x59, 0x00, 0x5f, 0xf9, 0x6a, 0x3a, 0x49, + 0x95, 0x56, 0xf5, 0x1a, 0x5e, 0x17, 0x97, 0x07, 0x0f, 0xf9, 0x4a, 0xf9, 0xa1, 0x9c, 0x29, 0xfe, + 0x6a, 0xe6, 0xeb, 0x33, 0x6d, 0x99, 0xb5, 0xd2, 0x20, 0xd1, 0x2a, 0xed, 0x2c, 0x16, 0x77, 0xc1, + 0x04, 0x2d, 0x6e, 0xc8, 0x38, 0x8f, 0x1a, 0x49, 0x2a, 0xd7, 0x83, 0xb3, 0xf5, 0x9b, 0xa6, 0x3b, + 0xe4, 0x34, 0x93, 0xd3, 0x8b, 0x71, 0x1e, 0xdd, 0x9d, 0xe8, 0x40, 0xc5, 0xd9, 0x81, 0xab, 0xbf, + 0x5c, 0x7b, 0xe8, 0x9a, 0xdb, 0x87, 0x56, 0xc7, 0x09, 0xc5, 0xff, 0xad, 0x14, 0xa0, 0x58, 0x85, + 0xeb, 0x2b, 0xbe, 0x4c, 0xa7, 0x41, 0xec, 0xcb, 0xd4, 0x62, 0xfc, 0x9e, 0x8c, 0x13, 0x86, 0xf1, + 0x5e, 0x42, 0xc5, 0x02, 0x8c, 0x0e, 0xe2, 0xfa, 0x81, 0x5c, 0x23, 0xd2, 0x94, 0x2c, 0xc1, 0x58, + 0x21, 0x69, 0xe5, 0x99, 0x56, 0x51, 0xec, 0x45, 0xd2, 0xa2, 0xf9, 0xb1, 0xd0, 0xd4, 0x56, 0xf7, + 0x23, 0xb6, 0x50, 0x52, 0x42, 0xc0, 0x10, 0x7e, 0xd2, 0x96, 0xad, 0xd0, 0x62, 0xf8, 0x89, 0x02, + 0x29, 0xd7, 0x8b, 0xd3, 0x30, 0x89, 0xd7, 0x67, 0xbc, 0x30, 0x97, 0x66, 0x24, 0xb7, 0xf6, 0xf5, + 0x9c, 0xc6, 0x65, 0x2c, 0xfb, 0xf9, 0xfc, 0x9e, 0x22, 0x9c, 0x89, 0x52, 0x60, 0xc4, 0x64, 0x54, + 0xd1, 0x97, 0x5a, 0xcb, 0x34, 0x6b, 0x78, 0x61, 0xbf, 0xf0, 0x8e, 0x07, 0x61, 0x69, 0xbc, 0xb0, + 0x55, 0xad, 0xe2, 0x52, 0x87, 0x9c, 0x0f, 0x43, 0xb1, 0x06, 0x37, 0xf4, 0x79, 0x2a, 0x1c, 0x9c, + 0x17, 0xc9, 0x39, 0xb9, 0xeb, 0xc9, 0x40, 0xed, 0x0a, 0xf0, 0xe7, 0x65, 0x2d, 0x1d, 0x9c, 0xaf, + 0x93, 0xb3, 
0x4e, 0x2c, 0x97, 0x14, 0x8d, 0x27, 0x61, 0xfc, 0x8c, 0x4c, 0x9b, 0x2a, 0x93, 0x0d, + 0xf9, 0x68, 0xee, 0x85, 0x0e, 0xba, 0x4b, 0xa4, 0x1b, 0x23, 0x70, 0x11, 0x39, 0x74, 0x1d, 0x81, + 0xa1, 0x75, 0xaf, 0x25, 0x1d, 0x14, 0x97, 0x49, 0xb1, 0x0f, 0xd7, 0x23, 0x3a, 0x0f, 0x23, 0xbe, + 0xea, 0xdc, 0x92, 0x03, 0x7e, 0x85, 0xf0, 0x61, 0x66, 0x48, 0x91, 0xa8, 0x24, 0x0f, 0x3d, 0xed, + 0x12, 0xc1, 0x1b, 0xac, 0x60, 0x86, 0x14, 0x03, 0xa4, 0xf5, 0x4d, 0x56, 0x64, 0x46, 0x3e, 0xe7, + 0x60, 0x58, 0xc5, 0xe1, 0xa6, 0x8a, 0x5d, 0x82, 0x78, 0x8b, 0x0c, 0x40, 0x08, 0x0a, 0x66, 0xa1, + 0xe6, 0x5a, 0x88, 0xb7, 0xb7, 0x78, 0x7b, 0x70, 0x05, 0x96, 0x60, 0x8c, 0x1b, 0x54, 0xa0, 0x62, + 0x07, 0xc5, 0x3b, 0xa4, 0xd8, 0x6f, 0x60, 0x74, 0x1b, 0x5a, 0x66, 0xda, 0x97, 0x2e, 0x92, 0x77, + 0xf9, 0x36, 0x08, 0xa1, 0x54, 0x36, 0x65, 0xdc, 0xda, 0x70, 0x33, 0xbc, 0xc7, 0xa9, 0x64, 0x06, + 0x15, 0x0b, 0x30, 0x1a, 0x79, 0x69, 0xb6, 0xe1, 0x85, 0x4e, 0xe5, 0x78, 0x9f, 0x1c, 0x23, 0x25, + 0x44, 0x19, 0xc9, 0xe3, 0x41, 0x34, 0x1f, 0x70, 0x46, 0x0c, 0x8c, 0xb6, 0x5e, 0xa6, 0xbd, 0x66, + 0x28, 0x1b, 0x83, 0xd8, 0x3e, 0xe4, 0xad, 0xd7, 0x61, 0x97, 0x4d, 0xe3, 0x2c, 0xd4, 0xb2, 0xe0, + 0x9c, 0x93, 0xe6, 0x23, 0xae, 0x74, 0x01, 0x20, 0xfc, 0x00, 0xdc, 0xd8, 0x77, 0x4c, 0x38, 0xc8, + 0x3e, 0x26, 0xd9, 0x54, 0x9f, 0x51, 0x41, 0x2d, 0x61, 0x50, 0xe5, 0x27, 0xdc, 0x12, 0x64, 0x8f, + 0x6b, 0x05, 0x26, 0xf3, 0x38, 0xf3, 0xd6, 0x07, 0xcb, 0xda, 0xa7, 0x9c, 0xb5, 0x0e, 0x5b, 0xc9, + 0xda, 0x29, 0x98, 0x22, 0xe3, 0x60, 0x75, 0xfd, 0x8c, 0x1b, 0x6b, 0x87, 0x5e, 0xab, 0x56, 0xf7, + 0x21, 0x38, 0x58, 0xa6, 0xf3, 0xac, 0x96, 0x71, 0x86, 0x4c, 0x23, 0xf2, 0x12, 0x07, 0xf3, 0x55, + 0x32, 0x73, 0xc7, 0x5f, 0x2c, 0x05, 0xcb, 0x5e, 0x82, 0xf2, 0xfb, 0xe1, 0x00, 0xcb, 0xf3, 0x38, + 0x95, 0x2d, 0xe5, 0xc7, 0xc1, 0x39, 0xd9, 0x76, 0x50, 0x7f, 0xde, 0x53, 0xaa, 0x35, 0x03, 0x47, + 0xf3, 0x09, 0xb8, 0xae, 0x3c, 0xab, 0x34, 0x82, 0x28, 0x51, 0xa9, 0xb6, 0x18, 0xbf, 0xe0, 0x4a, + 0x95, 0xdc, 0x89, 0x02, 0x13, 0x8b, 0xb0, 0xbf, 
0xf8, 0xd3, 0xf5, 0x91, 0xfc, 0x92, 0x44, 0xa3, + 0x5d, 0x8a, 0x1a, 0x47, 0x4b, 0x45, 0x89, 0x97, 0xba, 0xf4, 0xbf, 0xaf, 0xb8, 0x71, 0x10, 0x42, + 0x8d, 0x43, 0x6f, 0x26, 0x12, 0xa7, 0xbd, 0x83, 0xe1, 0x6b, 0x6e, 0x1c, 0xcc, 0x90, 0x82, 0x0f, + 0x0c, 0x0e, 0x8a, 0x6f, 0x58, 0xc1, 0x0c, 0x2a, 0xee, 0xe9, 0x0e, 0xda, 0x54, 0xfa, 0x41, 0xa6, + 0x53, 0x0f, 0x57, 0x5b, 0x54, 0xdf, 0x6e, 0x55, 0x0f, 0x61, 0xab, 0x06, 0x2a, 0x4e, 0xc2, 0x58, + 0xcf, 0x11, 0xa3, 0x7e, 0xcb, 0x2e, 0xdb, 0xb2, 0xcc, 0x32, 0xcf, 0x2f, 0x85, 0x8f, 0x6d, 0x53, + 0x33, 0xaa, 0x9e, 0x30, 0xc4, 0x9d, 0x58, 0xf7, 0xea, 0x39, 0xc0, 0x2e, 0x3b, 0xbf, 0x5d, 0x96, + 0xbe, 0x72, 0x0c, 0x10, 0xc7, 0x61, 0xb4, 0x72, 0x06, 0xb0, 0xab, 0x1e, 0x27, 0xd5, 0x88, 0x79, + 0x04, 0x10, 0x87, 0x61, 0x0f, 0xce, 0x73, 0x3b, 0xfe, 0x04, 0xe1, 0xc5, 0x72, 0x71, 0x14, 0x86, + 0x78, 0x8e, 0xdb, 0xd1, 0x27, 0x09, 0x2d, 0x11, 0xc4, 0x79, 0x86, 0xdb, 0xf1, 0xa7, 0x18, 0x67, + 0x04, 0x71, 0xf7, 0x14, 0x7e, 0xf7, 0xcc, 0x1e, 0xea, 0xc3, 0x9c, 0xbb, 0x59, 0xd8, 0x47, 0xc3, + 0xdb, 0x4e, 0x3f, 0x4d, 0x5f, 0xce, 0x84, 0xb8, 0x03, 0xf6, 0x3a, 0x26, 0xfc, 0x59, 0x42, 0x3b, + 0xeb, 0xc5, 0x02, 0x0c, 0x1b, 0x03, 0xdb, 0x8e, 0x3f, 0x47, 0xb8, 0x49, 0x61, 0xe8, 0x34, 0xb0, + 0xed, 0x82, 0xe7, 0x39, 0x74, 0x22, 0x30, 0x6d, 0x3c, 0xab, 0xed, 0xf4, 0x0b, 0x9c, 0x75, 0x46, + 0xc4, 0x1c, 0xd4, 0xca, 0xfe, 0x6b, 0xe7, 0x5f, 0x24, 0xbe, 0xcb, 0x60, 0x06, 0x8c, 0xfe, 0x6f, + 0x57, 0xbc, 0xc4, 0x19, 0x30, 0x28, 0xdc, 0x46, 0xbd, 0x33, 0xdd, 0x6e, 0x7a, 0x99, 0xb7, 0x51, + 0xcf, 0x48, 0xc7, 0x6a, 0x16, 0x6d, 0xd0, 0xae, 0x78, 0x85, 0xab, 0x59, 0xac, 0xc7, 0x30, 0x7a, + 0x87, 0xa4, 0xdd, 0xf1, 0x2a, 0x87, 0xd1, 0x33, 0x23, 0xc5, 0x0a, 0xd4, 0x77, 0x0f, 0x48, 0xbb, + 0xef, 0x35, 0xf2, 0x8d, 0xef, 0x9a, 0x8f, 0xe2, 0x3e, 0x98, 0xea, 0x3f, 0x1c, 0xed, 0xd6, 0x0b, + 0xdb, 0x3d, 0xaf, 0x33, 0xe6, 0x6c, 0x14, 0xa7, 0xba, 0x5d, 0xd6, 0x1c, 0x8c, 0x76, 0xed, 0xc5, + 0xed, 0x6a, 0xa3, 0x35, 0xe7, 0xa2, 0x98, 0x07, 0xe8, 0xce, 0x24, 0xbb, 0xeb, 0x12, 
0xb9, 0x0c, + 0x08, 0xb7, 0x06, 0x8d, 0x24, 0x3b, 0x7f, 0x99, 0xb7, 0x06, 0x11, 0xb8, 0x35, 0x78, 0x1a, 0xd9, + 0xe9, 0x2b, 0xbc, 0x35, 0x18, 0x11, 0xb3, 0x30, 0x14, 0xe7, 0x61, 0x88, 0xcf, 0x56, 0xfd, 0xe6, + 0x3e, 0xe3, 0x46, 0x86, 0x6d, 0x86, 0x7f, 0xdd, 0x21, 0x98, 0x01, 0x71, 0x18, 0xf6, 0xca, 0xa8, + 0x29, 0xdb, 0x36, 0xf2, 0xb7, 0x1d, 0xee, 0x27, 0xb8, 0x5a, 0xcc, 0x01, 0x74, 0x5e, 0xa6, 0x31, + 0x0a, 0x1b, 0xfb, 0xfb, 0x4e, 0xe7, 0xbd, 0xde, 0x40, 0xba, 0x82, 0xe2, 0x6d, 0xdc, 0x22, 0xd8, + 0xaa, 0x0a, 0x8a, 0x17, 0xf0, 0x23, 0xb0, 0xef, 0xe1, 0x4c, 0xc5, 0xda, 0xf3, 0x6d, 0xf4, 0x1f, + 0x44, 0xf3, 0x7a, 0x4c, 0x58, 0xa4, 0x52, 0xa9, 0x3d, 0x3f, 0xb3, 0xb1, 0x7f, 0x12, 0x5b, 0x02, + 0x08, 0xb7, 0xbc, 0x4c, 0xbb, 0xdc, 0xf7, 0x5f, 0x0c, 0x33, 0x80, 0x41, 0xe3, 0xf5, 0x23, 0x72, + 0xd3, 0xc6, 0xfe, 0xcd, 0x41, 0xd3, 0x7a, 0x71, 0x14, 0x6a, 0x78, 0x59, 0xfc, 0x0e, 0x61, 0x83, + 0xff, 0x21, 0xb8, 0x4b, 0xe0, 0x37, 0x67, 0xba, 0xad, 0x03, 0x7b, 0xb2, 0xff, 0xa5, 0x4a, 0xf3, + 0x7a, 0x31, 0x0f, 0xc3, 0x99, 0x6e, 0xb7, 0x73, 0x3a, 0xd1, 0x58, 0xf0, 0xff, 0x76, 0xca, 0x97, + 0xdc, 0x92, 0x39, 0xb6, 0x08, 0x13, 0x2d, 0x15, 0xf5, 0x82, 0xc7, 0x60, 0x49, 0x2d, 0xa9, 0x95, + 0x62, 0x17, 0x3d, 0x78, 0x9b, 0x1f, 0xe8, 0x8d, 0xbc, 0x39, 0xdd, 0x52, 0xd1, 0x0c, 0x1e, 0x35, + 0xbb, 0xbf, 0xa0, 0x95, 0x07, 0xcf, 0xff, 0x03, 0x00, 0x00, 0xff, 0xff, 0xed, 0x5f, 0x6c, 0x20, + 0x74, 0x13, 0x00, 0x00, +} diff --git a/vendor/github.com/gogo/protobuf/gogoproto/helper.go b/vendor/github.com/gogo/protobuf/gogoproto/helper.go new file mode 100644 index 00000000000..6b851c56239 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/gogoproto/helper.go @@ -0,0 +1,357 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. 
+// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package gogoproto + +import google_protobuf "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" +import proto "github.com/gogo/protobuf/proto" + +func IsEmbed(field *google_protobuf.FieldDescriptorProto) bool { + return proto.GetBoolExtension(field.Options, E_Embed, false) +} + +func IsNullable(field *google_protobuf.FieldDescriptorProto) bool { + return proto.GetBoolExtension(field.Options, E_Nullable, true) +} + +func IsStdTime(field *google_protobuf.FieldDescriptorProto) bool { + return proto.GetBoolExtension(field.Options, E_Stdtime, false) +} + +func IsStdDuration(field *google_protobuf.FieldDescriptorProto) bool { + return proto.GetBoolExtension(field.Options, E_Stdduration, false) +} + +func NeedsNilCheck(proto3 bool, field *google_protobuf.FieldDescriptorProto) bool { + nullable := IsNullable(field) + if field.IsMessage() || IsCustomType(field) { + return nullable + } + if proto3 { + return false + } + return nullable || *field.Type == google_protobuf.FieldDescriptorProto_TYPE_BYTES +} + +func IsCustomType(field *google_protobuf.FieldDescriptorProto) bool { + typ := GetCustomType(field) + if len(typ) > 0 { + return true + } + return false +} + +func IsCastType(field *google_protobuf.FieldDescriptorProto) bool { + typ := GetCastType(field) + if len(typ) > 0 { + return true + } + return false +} + +func IsCastKey(field *google_protobuf.FieldDescriptorProto) bool { + typ := GetCastKey(field) + if len(typ) > 0 { + return true + } + return false +} + +func IsCastValue(field *google_protobuf.FieldDescriptorProto) bool { + typ := GetCastValue(field) + if len(typ) > 0 { + return true + } + return false +} + +func HasEnumDecl(file *google_protobuf.FileDescriptorProto, enum *google_protobuf.EnumDescriptorProto) bool { + return proto.GetBoolExtension(enum.Options, E_Enumdecl, proto.GetBoolExtension(file.Options, E_EnumdeclAll, true)) +} + +func HasTypeDecl(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return 
proto.GetBoolExtension(message.Options, E_Typedecl, proto.GetBoolExtension(file.Options, E_TypedeclAll, true)) +} + +func GetCustomType(field *google_protobuf.FieldDescriptorProto) string { + if field == nil { + return "" + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_Customtype) + if err == nil && v.(*string) != nil { + return *(v.(*string)) + } + } + return "" +} + +func GetCastType(field *google_protobuf.FieldDescriptorProto) string { + if field == nil { + return "" + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_Casttype) + if err == nil && v.(*string) != nil { + return *(v.(*string)) + } + } + return "" +} + +func GetCastKey(field *google_protobuf.FieldDescriptorProto) string { + if field == nil { + return "" + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_Castkey) + if err == nil && v.(*string) != nil { + return *(v.(*string)) + } + } + return "" +} + +func GetCastValue(field *google_protobuf.FieldDescriptorProto) string { + if field == nil { + return "" + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_Castvalue) + if err == nil && v.(*string) != nil { + return *(v.(*string)) + } + } + return "" +} + +func IsCustomName(field *google_protobuf.FieldDescriptorProto) bool { + name := GetCustomName(field) + if len(name) > 0 { + return true + } + return false +} + +func IsEnumCustomName(field *google_protobuf.EnumDescriptorProto) bool { + name := GetEnumCustomName(field) + if len(name) > 0 { + return true + } + return false +} + +func IsEnumValueCustomName(field *google_protobuf.EnumValueDescriptorProto) bool { + name := GetEnumValueCustomName(field) + if len(name) > 0 { + return true + } + return false +} + +func GetCustomName(field *google_protobuf.FieldDescriptorProto) string { + if field == nil { + return "" + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_Customname) + if err == nil && v.(*string) != nil 
{ + return *(v.(*string)) + } + } + return "" +} + +func GetEnumCustomName(field *google_protobuf.EnumDescriptorProto) string { + if field == nil { + return "" + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_EnumCustomname) + if err == nil && v.(*string) != nil { + return *(v.(*string)) + } + } + return "" +} + +func GetEnumValueCustomName(field *google_protobuf.EnumValueDescriptorProto) string { + if field == nil { + return "" + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_EnumvalueCustomname) + if err == nil && v.(*string) != nil { + return *(v.(*string)) + } + } + return "" +} + +func GetJsonTag(field *google_protobuf.FieldDescriptorProto) *string { + if field == nil { + return nil + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_Jsontag) + if err == nil && v.(*string) != nil { + return (v.(*string)) + } + } + return nil +} + +func GetMoreTags(field *google_protobuf.FieldDescriptorProto) *string { + if field == nil { + return nil + } + if field.Options != nil { + v, err := proto.GetExtension(field.Options, E_Moretags) + if err == nil && v.(*string) != nil { + return (v.(*string)) + } + } + return nil +} + +type EnableFunc func(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool + +func EnabledGoEnumPrefix(file *google_protobuf.FileDescriptorProto, enum *google_protobuf.EnumDescriptorProto) bool { + return proto.GetBoolExtension(enum.Options, E_GoprotoEnumPrefix, proto.GetBoolExtension(file.Options, E_GoprotoEnumPrefixAll, true)) +} + +func EnabledGoStringer(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_GoprotoStringer, proto.GetBoolExtension(file.Options, E_GoprotoStringerAll, true)) +} + +func HasGoGetters(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return 
proto.GetBoolExtension(message.Options, E_GoprotoGetters, proto.GetBoolExtension(file.Options, E_GoprotoGettersAll, true)) +} + +func IsUnion(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Onlyone, proto.GetBoolExtension(file.Options, E_OnlyoneAll, false)) +} + +func HasGoString(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Gostring, proto.GetBoolExtension(file.Options, E_GostringAll, false)) +} + +func HasEqual(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Equal, proto.GetBoolExtension(file.Options, E_EqualAll, false)) +} + +func HasVerboseEqual(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_VerboseEqual, proto.GetBoolExtension(file.Options, E_VerboseEqualAll, false)) +} + +func IsStringer(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Stringer, proto.GetBoolExtension(file.Options, E_StringerAll, false)) +} + +func IsFace(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Face, proto.GetBoolExtension(file.Options, E_FaceAll, false)) +} + +func HasDescription(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Description, proto.GetBoolExtension(file.Options, E_DescriptionAll, false)) +} + +func HasPopulate(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Populate, proto.GetBoolExtension(file.Options, E_PopulateAll, 
false)) +} + +func HasTestGen(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Testgen, proto.GetBoolExtension(file.Options, E_TestgenAll, false)) +} + +func HasBenchGen(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Benchgen, proto.GetBoolExtension(file.Options, E_BenchgenAll, false)) +} + +func IsMarshaler(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Marshaler, proto.GetBoolExtension(file.Options, E_MarshalerAll, false)) +} + +func IsUnmarshaler(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Unmarshaler, proto.GetBoolExtension(file.Options, E_UnmarshalerAll, false)) +} + +func IsStableMarshaler(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_StableMarshaler, proto.GetBoolExtension(file.Options, E_StableMarshalerAll, false)) +} + +func IsSizer(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Sizer, proto.GetBoolExtension(file.Options, E_SizerAll, false)) +} + +func IsProtoSizer(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Protosizer, proto.GetBoolExtension(file.Options, E_ProtosizerAll, false)) +} + +func IsGoEnumStringer(file *google_protobuf.FileDescriptorProto, enum *google_protobuf.EnumDescriptorProto) bool { + return proto.GetBoolExtension(enum.Options, E_GoprotoEnumStringer, proto.GetBoolExtension(file.Options, E_GoprotoEnumStringerAll, true)) +} + +func IsEnumStringer(file 
*google_protobuf.FileDescriptorProto, enum *google_protobuf.EnumDescriptorProto) bool { + return proto.GetBoolExtension(enum.Options, E_EnumStringer, proto.GetBoolExtension(file.Options, E_EnumStringerAll, false)) +} + +func IsUnsafeMarshaler(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_UnsafeMarshaler, proto.GetBoolExtension(file.Options, E_UnsafeMarshalerAll, false)) +} + +func IsUnsafeUnmarshaler(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_UnsafeUnmarshaler, proto.GetBoolExtension(file.Options, E_UnsafeUnmarshalerAll, false)) +} + +func HasExtensionsMap(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_GoprotoExtensionsMap, proto.GetBoolExtension(file.Options, E_GoprotoExtensionsMapAll, true)) +} + +func HasUnrecognized(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + if IsProto3(file) { + return false + } + return proto.GetBoolExtension(message.Options, E_GoprotoUnrecognized, proto.GetBoolExtension(file.Options, E_GoprotoUnrecognizedAll, true)) +} + +func IsProto3(file *google_protobuf.FileDescriptorProto) bool { + return file.GetSyntax() == "proto3" +} + +func ImportsGoGoProto(file *google_protobuf.FileDescriptorProto) bool { + return proto.GetBoolExtension(file.Options, E_GogoprotoImport, true) +} + +func HasCompare(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool { + return proto.GetBoolExtension(message.Options, E_Compare, proto.GetBoolExtension(file.Options, E_CompareAll, false)) +} + +func RegistersGolangProto(file *google_protobuf.FileDescriptorProto) bool { + return proto.GetBoolExtension(file.Options, E_GoprotoRegistration, false) +} diff --git 
a/vendor/github.com/gogo/protobuf/plugin/compare/compare.go b/vendor/github.com/gogo/protobuf/plugin/compare/compare.go new file mode 100644 index 00000000000..97d0a4a9aa2 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/compare/compare.go @@ -0,0 +1,526 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package compare + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "github.com/gogo/protobuf/vanity" +) + +type plugin struct { + *generator.Generator + generator.PluginImports + fmtPkg generator.Single + bytesPkg generator.Single + sortkeysPkg generator.Single + protoPkg generator.Single +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "compare" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + p.PluginImports = generator.NewPluginImports(p.Generator) + p.fmtPkg = p.NewImport("fmt") + p.bytesPkg = p.NewImport("bytes") + p.sortkeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys") + p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto") + + for _, msg := range file.Messages() { + if msg.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + if gogoproto.HasCompare(file.FileDescriptorProto, msg.DescriptorProto) { + p.generateMessage(file, msg) + } + } +} + +func (p *plugin) generateNullableField(fieldname string) { + p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`) + p.In() + p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`) + p.In() + p.P(`if *this.`, fieldname, ` < *that1.`, fieldname, `{`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + p.Out() + p.P(`} else if this.`, fieldname, ` != nil {`) + p.In() + p.P(`return 1`) + p.Out() + p.P(`} else if that1.`, fieldname, ` != nil {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) +} + +func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) { + p.P(`if that == nil {`) + p.In() + p.P(`if this == nil {`) + p.In() + p.P(`return 0`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + p.P(``) + p.P(`that1, ok := 
that.(*`, ccTypeName, `)`) + p.P(`if !ok {`) + p.In() + p.P(`that2, ok := that.(`, ccTypeName, `)`) + p.P(`if ok {`) + p.In() + p.P(`that1 = &that2`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`return 1`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P(`if that1 == nil {`) + p.In() + p.P(`if this == nil {`) + p.In() + p.P(`return 0`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`} else if this == nil {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) +} + +func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) { + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + fieldname := p.GetOneOfFieldName(message, field) + repeated := field.IsRepeated() + ctype := gogoproto.IsCustomType(field) + nullable := gogoproto.IsNullable(field) + // oneof := field.OneofIndex != nil + if !repeated { + if ctype { + if nullable { + p.P(`if that1.`, fieldname, ` == nil {`) + p.In() + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + p.P(`return 1`) + p.Out() + p.P(`}`) + p.Out() + p.P(`} else if this.`, fieldname, ` == nil {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`} else if c := this.`, fieldname, `.Compare(*that1.`, fieldname, `); c != 0 {`) + } else { + p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`) + } + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else { + if field.IsMessage() || p.IsGroup(field) { + if nullable { + p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`) + } else { + p.P(`if c := this.`, fieldname, `.Compare(&that1.`, fieldname, `); c != 0 {`) + } + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else if field.IsBytes() { + p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else if field.IsString() { + if nullable && !proto3 { + p.generateNullableField(fieldname) + } else { + p.P(`if this.`, fieldname, ` != 
that1.`, fieldname, `{`) + p.In() + p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } + } else if field.IsBool() { + if nullable && !proto3 { + p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`) + p.In() + p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`) + p.In() + p.P(`if !*this.`, fieldname, ` {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + p.Out() + p.P(`} else if this.`, fieldname, ` != nil {`) + p.In() + p.P(`return 1`) + p.Out() + p.P(`} else if that1.`, fieldname, ` != nil {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + } else { + p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) + p.In() + p.P(`if !this.`, fieldname, ` {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } + } else { + if nullable && !proto3 { + p.generateNullableField(fieldname) + } else { + p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) + p.In() + p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } + } + } + } else { + p.P(`if len(this.`, fieldname, `) != len(that1.`, fieldname, `) {`) + p.In() + p.P(`if len(this.`, fieldname, `) < len(that1.`, fieldname, `) {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + p.P(`for i := range this.`, fieldname, ` {`) + p.In() + if ctype { + p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else { + if p.IsMap(field) { + m := p.GoMapType(nil, field) + valuegoTyp, _ := p.GoType(nil, m.ValueField) + valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField) + nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) + + mapValue := 
m.ValueAliasField + if mapValue.IsMessage() || p.IsGroup(mapValue) { + if nullable && valuegoTyp == valuegoAliasTyp { + p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`) + } else { + // Compare() has a pointer receiver, but map value is a value type + a := `this.` + fieldname + `[i]` + b := `that1.` + fieldname + `[i]` + if valuegoTyp != valuegoAliasTyp { + // cast back to the type that has the generated methods on it + a = `(` + valuegoTyp + `)(` + a + `)` + b = `(` + valuegoTyp + `)(` + b + `)` + } + p.P(`a := `, a) + p.P(`b := `, b) + if nullable { + p.P(`if c := a.Compare(b); c != 0 {`) + } else { + p.P(`if c := (&a).Compare(&b); c != 0 {`) + } + } + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else if mapValue.IsBytes() { + p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else if mapValue.IsString() { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + p.In() + p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } else { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + p.In() + p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } + } else if field.IsMessage() || p.IsGroup(field) { + if nullable { + p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else { + p.P(`if c := this.`, fieldname, `[i].Compare(&that1.`, fieldname, `[i]); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } + } else if field.IsBytes() { + p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else if 
field.IsString() { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + p.In() + p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } else if field.IsBool() { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + p.In() + p.P(`if !this.`, fieldname, `[i] {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } else { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + p.In() + p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.P(`return 1`) + p.Out() + p.P(`}`) + } + } + p.Out() + p.P(`}`) + } +} + +func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor) { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`) + p.In() + p.generateMsgNullAndTypeCheck(ccTypeName) + oneofs := make(map[string]struct{}) + + for _, field := range message.Field { + oneof := field.OneofIndex != nil + if oneof { + fieldname := p.GetFieldName(message, field) + if _, ok := oneofs[fieldname]; ok { + continue + } else { + oneofs[fieldname] = struct{}{} + } + p.P(`if that1.`, fieldname, ` == nil {`) + p.In() + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + p.P(`return 1`) + p.Out() + p.P(`}`) + p.Out() + p.P(`} else if this.`, fieldname, ` == nil {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`} else if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } else { + p.generateField(file, message, field) + } + } + if message.DescriptorProto.HasExtension() { + if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`thismap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(this)`) + p.P(`thatmap := `, p.protoPkg.Use(), 
`.GetUnsafeExtensionsMap(that1)`) + p.P(`extkeys := make([]int32, 0, len(thismap)+len(thatmap))`) + p.P(`for k, _ := range thismap {`) + p.In() + p.P(`extkeys = append(extkeys, k)`) + p.Out() + p.P(`}`) + p.P(`for k, _ := range thatmap {`) + p.In() + p.P(`if _, ok := thismap[k]; !ok {`) + p.In() + p.P(`extkeys = append(extkeys, k)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P(p.sortkeysPkg.Use(), `.Int32s(extkeys)`) + p.P(`for _, k := range extkeys {`) + p.In() + p.P(`if v, ok := thismap[k]; ok {`) + p.In() + p.P(`if v2, ok := thatmap[k]; ok {`) + p.In() + p.P(`if c := v.Compare(&v2); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`return 1`) + p.Out() + p.P(`}`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`return -1`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } else { + fieldname := "XXX_extensions" + p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } + } + if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + fieldname := "XXX_unrecognized" + p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`) + p.In() + p.P(`return c`) + p.Out() + p.P(`}`) + } + p.P(`return 0`) + p.Out() + p.P(`}`) + + //Generate Compare methods for oneof fields + m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto) + for _, field := range m.Field { + oneof := field.OneofIndex != nil + if !oneof { + continue + } + ccTypeName := p.OneOfTypeName(message, field) + p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`) + p.In() + + p.generateMsgNullAndTypeCheck(ccTypeName) + vanity.TurnOffNullableForNativeTypes(field) + p.generateField(file, message, field) + + p.P(`return 0`) + p.Out() + p.P(`}`) + } +} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git 
a/vendor/github.com/gogo/protobuf/plugin/compare/comparetest.go b/vendor/github.com/gogo/protobuf/plugin/compare/comparetest.go new file mode 100644 index 00000000000..4fbdbc633cd --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/compare/comparetest.go @@ -0,0 +1,118 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package compare + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + testingPkg := imports.NewImport("testing") + protoPkg := imports.NewImport("github.com/gogo/protobuf/proto") + unsafePkg := imports.NewImport("unsafe") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = imports.NewImport("github.com/golang/protobuf/proto") + } + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if !gogoproto.HasCompare(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + hasUnsafe := gogoproto.IsUnsafeMarshaler(file.FileDescriptorProto, message.DescriptorProto) || + gogoproto.IsUnsafeUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) + p.P(`func Test`, ccTypeName, `Compare(t *`, testingPkg.Use(), `.T) {`) + p.In() + if hasUnsafe { + p.P(`var bigendian uint32 = 0x01020304`) + p.P(`if *(*byte)(`, unsafePkg.Use(), `.Pointer(&bigendian)) == 1 {`) + p.In() + p.P(`t.Skip("unsafe does not work on big endian architectures")`) + p.Out() + p.P(`}`) + } + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(`, timePkg.Use(), `.Now().UnixNano()))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, false)`) + p.P(`dAtA, err := `, protoPkg.Use(), `.Marshal(p)`) + p.P(`if err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`if err := `, 
protoPkg.Use(), `.Unmarshal(dAtA, msg); err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`if c := p.Compare(msg); c != 0 {`) + p.In() + p.P(`t.Fatalf("%#v !Compare %#v, since %d", msg, p, c)`) + p.Out() + p.P(`}`) + p.P(`p2 := NewPopulated`, ccTypeName, `(popr, false)`) + p.P(`c := p.Compare(p2)`) + p.P(`c2 := p2.Compare(p)`) + p.P(`if c != (-1 * c2) {`) + p.In() + p.P(`t.Errorf("p.Compare(p2) = %d", c)`) + p.P(`t.Errorf("p2.Compare(p) = %d", c2)`) + p.P(`t.Errorf("p = %#v", p)`) + p.P(`t.Errorf("p2 = %#v", p2)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/defaultcheck/defaultcheck.go b/vendor/github.com/gogo/protobuf/plugin/defaultcheck/defaultcheck.go new file mode 100644 index 00000000000..486f2877192 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/defaultcheck/defaultcheck.go @@ -0,0 +1,133 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The defaultcheck plugin is used to check whether nullable is not used incorrectly. +For instance: +An error is caused if a nullable field: + - has a default value, + - is an enum which does not start at zero, + - is used for an extension, + - is used for a native proto3 type, + - is used for a repeated native type. + +An error is also caused if a field with a default value is used in a message: + - which is a face. + - without getters. + +It is enabled by the following extensions: + + - nullable + +For incorrect usage of nullable with tests see: + + github.com/gogo/protobuf/test/nullableconflict + +*/ +package defaultcheck + +import ( + "fmt" + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "os" +) + +type plugin struct { + *generator.Generator +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "defaultcheck" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + for _, msg := range file.Messages() { + getters := gogoproto.HasGoGetters(file.FileDescriptorProto, msg.DescriptorProto) + face := gogoproto.IsFace(file.FileDescriptorProto, msg.DescriptorProto) + for _, field := range msg.GetField() { + if len(field.GetDefaultValue()) > 0 { + if !getters { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v 
cannot have a default value and not have a getter method", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + if face { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v cannot have a default value be in a face", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + } + if gogoproto.IsNullable(field) { + continue + } + if len(field.GetDefaultValue()) > 0 { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v cannot be non-nullable and have a default value", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + if !field.IsMessage() && !gogoproto.IsCustomType(field) { + if field.IsRepeated() { + fmt.Fprintf(os.Stderr, "WARNING: field %v.%v is a repeated non-nullable native type, nullable=false has no effect\n", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + } else if proto3 { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v is a native type and in proto3 syntax with nullable=false there exists conflicting implementations when encoding zero values", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + if field.IsBytes() { + fmt.Fprintf(os.Stderr, "WARNING: field %v.%v is a non-nullable bytes type, nullable=false has no effect\n", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + } + } + if !field.IsEnum() { + continue + } + enum := p.ObjectNamed(field.GetTypeName()).(*generator.EnumDescriptor) + if len(enum.Value) == 0 || enum.Value[0].GetNumber() != 0 { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v cannot be non-nullable and be an enum type %v which does not start with zero", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name), enum.GetName()) + os.Exit(1) + } + } + } + for _, e := range file.GetExtension() { + if !gogoproto.IsNullable(e) { + fmt.Fprintf(os.Stderr, "ERROR: extended field %v cannot be nullable %v", generator.CamelCase(e.GetName()), generator.CamelCase(*e.Name)) + os.Exit(1) + } + } +} + 
+func (p *plugin) GenerateImports(*generator.FileDescriptor) {} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/description/description.go b/vendor/github.com/gogo/protobuf/plugin/description/description.go new file mode 100644 index 00000000000..f72efba6128 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/description/description.go @@ -0,0 +1,201 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The description (experimental) plugin generates a Description method for each message. 
+The Description method returns a populated google_protobuf.FileDescriptorSet struct. +This contains the description of the files used to generate this message. + +It is enabled by the following extensions: + + - description + - description_all + +The description plugin also generates a test given it is enabled using one of the following extensions: + + - testgen + - testgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + message B { + option (gogoproto.description) = true; + optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true]; + repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false]; + } + +given to the description plugin, will generate the following code: + + func (this *B) Description() (desc *google_protobuf.FileDescriptorSet) { + return ExampleDescription() + } + +and the following test code: + + func TestDescription(t *testing9.T) { + ExampleDescription() + } + +The hope is to use this struct in some way instead of reflect. +This package is subject to change, since a use has not been figured out yet. 
+ +*/ +package description + +import ( + "bytes" + "compress/gzip" + "fmt" + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type plugin struct { + *generator.Generator + generator.PluginImports +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "description" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + used := false + localName := generator.FileName(file) + + p.PluginImports = generator.NewPluginImports(p.Generator) + descriptorPkg := p.NewImport("github.com/gogo/protobuf/protoc-gen-gogo/descriptor") + protoPkg := p.NewImport("github.com/gogo/protobuf/proto") + gzipPkg := p.NewImport("compress/gzip") + bytesPkg := p.NewImport("bytes") + ioutilPkg := p.NewImport("io/ioutil") + + for _, message := range file.Messages() { + if !gogoproto.HasDescription(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + used = true + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + p.P(`func (this *`, ccTypeName, `) Description() (desc *`, descriptorPkg.Use(), `.FileDescriptorSet) {`) + p.In() + p.P(`return `, localName, `Description()`) + p.Out() + p.P(`}`) + } + + if used { + + p.P(`func `, localName, `Description() (desc *`, descriptorPkg.Use(), `.FileDescriptorSet) {`) + p.In() + //Don't generate SourceCodeInfo, since it will create too much code. 
+ + ss := make([]*descriptor.SourceCodeInfo, 0) + for _, f := range p.Generator.AllFiles().GetFile() { + ss = append(ss, f.SourceCodeInfo) + f.SourceCodeInfo = nil + } + b, err := proto.Marshal(p.Generator.AllFiles()) + if err != nil { + panic(err) + } + for i, f := range p.Generator.AllFiles().GetFile() { + f.SourceCodeInfo = ss[i] + } + p.P(`d := &`, descriptorPkg.Use(), `.FileDescriptorSet{}`) + var buf bytes.Buffer + w, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression) + w.Write(b) + w.Close() + b = buf.Bytes() + p.P("var gzipped = []byte{") + p.In() + p.P("// ", len(b), " bytes of a gzipped FileDescriptorSet") + for len(b) > 0 { + n := 16 + if n > len(b) { + n = len(b) + } + + s := "" + for _, c := range b[:n] { + s += fmt.Sprintf("0x%02x,", c) + } + p.P(s) + + b = b[n:] + } + p.Out() + p.P("}") + p.P(`r := `, bytesPkg.Use(), `.NewReader(gzipped)`) + p.P(`gzipr, err := `, gzipPkg.Use(), `.NewReader(r)`) + p.P(`if err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`ungzipped, err := `, ioutilPkg.Use(), `.ReadAll(gzipr)`) + p.P(`if err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`if err := `, protoPkg.Use(), `.Unmarshal(ungzipped, d); err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`return d`) + p.Out() + p.P(`}`) + } +} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/description/descriptiontest.go b/vendor/github.com/gogo/protobuf/plugin/description/descriptiontest.go new file mode 100644 index 00000000000..babcd311da4 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/description/descriptiontest.go @@ -0,0 +1,73 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. 
+// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package description + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + testingPkg := imports.NewImport("testing") + for _, message := range file.Messages() { + if !gogoproto.HasDescription(file.FileDescriptorProto, message.DescriptorProto) || + !gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + used = true + } + + if used { + localName := generator.FileName(file) + p.P(`func Test`, localName, `Description(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(localName, `Description()`) + p.Out() + p.P(`}`) + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/embedcheck/embedcheck.go b/vendor/github.com/gogo/protobuf/plugin/embedcheck/embedcheck.go new file mode 100644 index 00000000000..1cb77cacb10 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/embedcheck/embedcheck.go @@ -0,0 +1,199 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The embedcheck plugin is used to check whether embed is not used incorrectly. +For instance: +An embedded message has a generated string method, but the is a member of a message which does not. +This causes a warning. +An error is caused by a namespace conflict. 
+ +It is enabled by the following extensions: + + - embed + - embed_all + +For incorrect usage of embed with tests see: + + github.com/gogo/protobuf/test/embedconflict + +*/ +package embedcheck + +import ( + "fmt" + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "os" +) + +type plugin struct { + *generator.Generator +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "embedcheck" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +var overwriters []map[string]gogoproto.EnableFunc = []map[string]gogoproto.EnableFunc{ + { + "stringer": gogoproto.IsStringer, + }, + { + "gostring": gogoproto.HasGoString, + }, + { + "equal": gogoproto.HasEqual, + }, + { + "verboseequal": gogoproto.HasVerboseEqual, + }, + { + "size": gogoproto.IsSizer, + "protosizer": gogoproto.IsProtoSizer, + }, + { + "unmarshaler": gogoproto.IsUnmarshaler, + "unsafe_unmarshaler": gogoproto.IsUnsafeUnmarshaler, + }, + { + "marshaler": gogoproto.IsMarshaler, + "unsafe_marshaler": gogoproto.IsUnsafeMarshaler, + }, +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + for _, msg := range file.Messages() { + for _, os := range overwriters { + possible := true + for _, overwriter := range os { + if overwriter(file.FileDescriptorProto, msg.DescriptorProto) { + possible = false + } + } + if possible { + p.checkOverwrite(msg, os) + } + } + p.checkNameSpace(msg) + for _, field := range msg.GetField() { + if gogoproto.IsEmbed(field) && gogoproto.IsCustomName(field) { + fmt.Fprintf(os.Stderr, "ERROR: field %v with custom name %v cannot be embedded", *field.Name, gogoproto.GetCustomName(field)) + os.Exit(1) + } + } + p.checkRepeated(msg) + } + for _, e := range file.GetExtension() { + if gogoproto.IsEmbed(e) { + fmt.Fprintf(os.Stderr, "ERROR: extended field %v cannot be embedded", generator.CamelCase(*e.Name)) + os.Exit(1) + } + } +} + +func (p *plugin) checkNameSpace(message 
*generator.Descriptor) map[string]bool { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + names := make(map[string]bool) + for _, field := range message.Field { + fieldname := generator.CamelCase(*field.Name) + if field.IsMessage() && gogoproto.IsEmbed(field) { + desc := p.ObjectNamed(field.GetTypeName()) + moreNames := p.checkNameSpace(desc.(*generator.Descriptor)) + for another := range moreNames { + if names[another] { + fmt.Fprintf(os.Stderr, "ERROR: duplicate embedded fieldname %v in type %v\n", fieldname, ccTypeName) + os.Exit(1) + } + names[another] = true + } + } else { + if names[fieldname] { + fmt.Fprintf(os.Stderr, "ERROR: duplicate embedded fieldname %v in type %v\n", fieldname, ccTypeName) + os.Exit(1) + } + names[fieldname] = true + } + } + return names +} + +func (p *plugin) checkOverwrite(message *generator.Descriptor, enablers map[string]gogoproto.EnableFunc) { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + names := []string{} + for name := range enablers { + names = append(names, name) + } + for _, field := range message.Field { + if field.IsMessage() && gogoproto.IsEmbed(field) { + fieldname := generator.CamelCase(*field.Name) + desc := p.ObjectNamed(field.GetTypeName()) + msg := desc.(*generator.Descriptor) + for errStr, enabled := range enablers { + if enabled(msg.File(), msg.DescriptorProto) { + fmt.Fprintf(os.Stderr, "WARNING: found non-%v %v with embedded %v %v\n", names, ccTypeName, errStr, fieldname) + } + } + p.checkOverwrite(msg, enablers) + } + } +} + +func (p *plugin) checkRepeated(message *generator.Descriptor) { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + for _, field := range message.Field { + if !gogoproto.IsEmbed(field) { + continue + } + if field.IsBytes() { + fieldname := generator.CamelCase(*field.Name) + fmt.Fprintf(os.Stderr, "ERROR: found embedded bytes field %s in message %s\n", fieldname, ccTypeName) + os.Exit(1) + } + if !field.IsRepeated() { + continue + } + fieldname := 
generator.CamelCase(*field.Name) + fmt.Fprintf(os.Stderr, "ERROR: found repeated embedded field %s in message %s\n", fieldname, ccTypeName) + os.Exit(1) + } +} + +func (p *plugin) GenerateImports(*generator.FileDescriptor) {} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/enumstringer/enumstringer.go b/vendor/github.com/gogo/protobuf/plugin/enumstringer/enumstringer.go new file mode 100644 index 00000000000..04d6e547fc3 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/enumstringer/enumstringer.go @@ -0,0 +1,104 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The enumstringer (experimental) plugin generates a String method for each enum. + +It is enabled by the following extensions: + + - enum_stringer + - enum_stringer_all + +This package is subject to change. + +*/ +package enumstringer + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type enumstringer struct { + *generator.Generator + generator.PluginImports + atleastOne bool + localName string +} + +func NewEnumStringer() *enumstringer { + return &enumstringer{} +} + +func (p *enumstringer) Name() string { + return "enumstringer" +} + +func (p *enumstringer) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *enumstringer) Generate(file *generator.FileDescriptor) { + p.PluginImports = generator.NewPluginImports(p.Generator) + p.atleastOne = false + + p.localName = generator.FileName(file) + + strconvPkg := p.NewImport("strconv") + + for _, enum := range file.Enums() { + if !gogoproto.IsEnumStringer(file.FileDescriptorProto, enum.EnumDescriptorProto) { + continue + } + if gogoproto.IsGoEnumStringer(file.FileDescriptorProto, enum.EnumDescriptorProto) { + panic("Go enum stringer conflicts with new enumstringer plugin: please use gogoproto.goproto_enum_stringer or gogoproto.goproto_enum_string_all and set it to false") + } + p.atleastOne = true + ccTypeName := generator.CamelCaseSlice(enum.TypeName()) + p.P("func (x ", ccTypeName, ") String() string {") + 
p.In() + p.P(`s, ok := `, ccTypeName, `_name[int32(x)]`) + p.P(`if ok {`) + p.In() + p.P(`return s`) + p.Out() + p.P(`}`) + p.P(`return `, strconvPkg.Use(), `.Itoa(int(x))`) + p.Out() + p.P(`}`) + } + + if !p.atleastOne { + return + } + +} + +func init() { + generator.RegisterPlugin(NewEnumStringer()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/equal/equal.go b/vendor/github.com/gogo/protobuf/plugin/equal/equal.go new file mode 100644 index 00000000000..41a2c97041d --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/equal/equal.go @@ -0,0 +1,631 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The equal plugin generates an Equal and a VerboseEqual method for each message. +These equal methods are quite obvious. +The only difference is that VerboseEqual returns a non nil error if it is not equal. +This error contains more detail on exactly which part of the message was not equal to the other message. +The idea is that this is useful for debugging. + +Equal is enabled using the following extensions: + + - equal + - equal_all + +While VerboseEqual is enable dusing the following extensions: + + - verbose_equal + - verbose_equal_all + +The equal plugin also generates a test given it is enabled using one of the following extensions: + + - testgen + - testgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + option (gogoproto.equal_all) = true; + option (gogoproto.verbose_equal_all) = true; + + message B { + optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true]; + repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false]; + } + +given to the equal plugin, will generate the following code: + + func (this *B) VerboseEqual(that interface{}) error { + if that == nil { + if this == nil { + return nil + } + return fmt2.Errorf("that == nil && this != nil") + } + + that1, ok := that.(*B) + if !ok { + return 
fmt2.Errorf("that is not of type *B") + } + if that1 == nil { + if this == nil { + return nil + } + return fmt2.Errorf("that is type *B but is nil && this != nil") + } else if this == nil { + return fmt2.Errorf("that is type *B but is not nil && this == nil") + } + if !this.A.Equal(&that1.A) { + return fmt2.Errorf("A this(%v) Not Equal that(%v)", this.A, that1.A) + } + if len(this.G) != len(that1.G) { + return fmt2.Errorf("G this(%v) Not Equal that(%v)", len(this.G), len(that1.G)) + } + for i := range this.G { + if !this.G[i].Equal(that1.G[i]) { + return fmt2.Errorf("G this[%v](%v) Not Equal that[%v](%v)", i, this.G[i], i, that1.G[i]) + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return fmt2.Errorf("XXX_unrecognized this(%v) Not Equal that(%v)", this.XXX_unrecognized, that1.XXX_unrecognized) + } + return nil + } + + func (this *B) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*B) + if !ok { + return false + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if !this.A.Equal(&that1.A) { + return false + } + if len(this.G) != len(that1.G) { + return false + } + for i := range this.G { + if !this.G[i].Equal(that1.G[i]) { + return false + } + } + if !bytes.Equal(this.XXX_unrecognized, that1.XXX_unrecognized) { + return false + } + return true + } + +and the following test code: + + func TestBVerboseEqual(t *testing8.T) { + popr := math_rand8.New(math_rand8.NewSource(time8.Now().UnixNano())) + p := NewPopulatedB(popr, false) + dAtA, err := github_com_gogo_protobuf_proto2.Marshal(p) + if err != nil { + panic(err) + } + msg := &B{} + if err := github_com_gogo_protobuf_proto2.Unmarshal(dAtA, msg); err != nil { + panic(err) + } + if err := p.VerboseEqual(msg); err != nil { + t.Fatalf("%#v !VerboseEqual %#v, since %v", msg, p, err) + } + +*/ +package equal + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor 
"github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "github.com/gogo/protobuf/vanity" +) + +type plugin struct { + *generator.Generator + generator.PluginImports + fmtPkg generator.Single + bytesPkg generator.Single + protoPkg generator.Single +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "equal" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + p.PluginImports = generator.NewPluginImports(p.Generator) + p.fmtPkg = p.NewImport("fmt") + p.bytesPkg = p.NewImport("bytes") + p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto") + + for _, msg := range file.Messages() { + if msg.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + if gogoproto.HasVerboseEqual(file.FileDescriptorProto, msg.DescriptorProto) { + p.generateMessage(file, msg, true) + } + if gogoproto.HasEqual(file.FileDescriptorProto, msg.DescriptorProto) { + p.generateMessage(file, msg, false) + } + } +} + +func (p *plugin) generateNullableField(fieldname string, verbose bool) { + p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`) + p.In() + p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", *this.`, fieldname, `, *that1.`, fieldname, `)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`} else if this.`, fieldname, ` != nil {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("this.`, fieldname, ` == nil && that.`, fieldname, ` != nil")`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`} else if that1.`, fieldname, ` != nil {`) +} + +func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string, verbose bool) { + p.P(`if that == nil {`) + p.In() + if verbose { + p.P(`if this == nil {`) + p.In() + 
p.P(`return nil`) + p.Out() + p.P(`}`) + p.P(`return `, p.fmtPkg.Use(), `.Errorf("that == nil && this != nil")`) + } else { + p.P(`return this == nil`) + } + p.Out() + p.P(`}`) + p.P(``) + p.P(`that1, ok := that.(*`, ccTypeName, `)`) + p.P(`if !ok {`) + p.In() + p.P(`that2, ok := that.(`, ccTypeName, `)`) + p.P(`if ok {`) + p.In() + p.P(`that1 = &that2`) + p.Out() + p.P(`} else {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("that is not of type *`, ccTypeName, `")`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P(`if that1 == nil {`) + p.In() + if verbose { + p.P(`if this == nil {`) + p.In() + p.P(`return nil`) + p.Out() + p.P(`}`) + p.P(`return `, p.fmtPkg.Use(), `.Errorf("that is type *`, ccTypeName, ` but is nil && this != nil")`) + } else { + p.P(`return this == nil`) + } + p.Out() + p.P(`} else if this == nil {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("that is type *`, ccTypeName, ` but is not nil && this == nil")`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) +} + +func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto, verbose bool) { + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + fieldname := p.GetOneOfFieldName(message, field) + repeated := field.IsRepeated() + ctype := gogoproto.IsCustomType(field) + nullable := gogoproto.IsNullable(field) + isDuration := gogoproto.IsStdDuration(field) + isTimestamp := gogoproto.IsStdTime(field) + // oneof := field.OneofIndex != nil + if !repeated { + if ctype || isTimestamp { + if nullable { + p.P(`if that1.`, fieldname, ` == nil {`) + p.In() + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("this.`, fieldname, ` != nil && that1.`, fieldname, ` == nil")`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`} else if !this.`, fieldname, `.Equal(*that1.`, 
fieldname, `) {`) + } else { + p.P(`if !this.`, fieldname, `.Equal(that1.`, fieldname, `) {`) + } + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", this.`, fieldname, `, that1.`, fieldname, `)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + } else if isDuration { + if nullable { + p.generateNullableField(fieldname, verbose) + } else { + p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) + } + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", this.`, fieldname, `, that1.`, fieldname, `)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + } else { + if field.IsMessage() || p.IsGroup(field) { + if nullable { + p.P(`if !this.`, fieldname, `.Equal(that1.`, fieldname, `) {`) + } else { + p.P(`if !this.`, fieldname, `.Equal(&that1.`, fieldname, `) {`) + } + } else if field.IsBytes() { + p.P(`if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `, that1.`, fieldname, `) {`) + } else if field.IsString() { + if nullable && !proto3 { + p.generateNullableField(fieldname, verbose) + } else { + p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) + } + } else { + if nullable && !proto3 { + p.generateNullableField(fieldname, verbose) + } else { + p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) + } + } + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", this.`, fieldname, `, that1.`, fieldname, `)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + } + } else { + p.P(`if len(this.`, fieldname, `) != len(that1.`, fieldname, `) {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", len(this.`, fieldname, `), len(that1.`, fieldname, `))`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.P(`for i := range this.`, fieldname, ` {`) + p.In() + if ctype && !p.IsMap(field) { + 
p.P(`if !this.`, fieldname, `[i].Equal(that1.`, fieldname, `[i]) {`) + } else if isTimestamp { + if nullable { + p.P(`if !this.`, fieldname, `[i].Equal(*that1.`, fieldname, `[i]) {`) + } else { + p.P(`if !this.`, fieldname, `[i].Equal(that1.`, fieldname, `[i]) {`) + } + } else if isDuration { + if nullable { + p.P(`if dthis, dthat := this.`, fieldname, `[i], that1.`, fieldname, `[i]; (dthis != nil && dthat != nil && *dthis != *dthat) || (dthis != nil && dthat == nil) || (dthis == nil && dthat != nil) {`) + } else { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + } + } else { + if p.IsMap(field) { + m := p.GoMapType(nil, field) + valuegoTyp, _ := p.GoType(nil, m.ValueField) + valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField) + nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) + + mapValue := m.ValueAliasField + if mapValue.IsMessage() || p.IsGroup(mapValue) { + if nullable && valuegoTyp == valuegoAliasTyp { + p.P(`if !this.`, fieldname, `[i].Equal(that1.`, fieldname, `[i]) {`) + } else { + // Equal() has a pointer receiver, but map value is a value type + a := `this.` + fieldname + `[i]` + b := `that1.` + fieldname + `[i]` + if valuegoTyp != valuegoAliasTyp { + // cast back to the type that has the generated methods on it + a = `(` + valuegoTyp + `)(` + a + `)` + b = `(` + valuegoTyp + `)(` + b + `)` + } + p.P(`a := `, a) + p.P(`b := `, b) + if nullable { + p.P(`if !a.Equal(b) {`) + } else { + p.P(`if !(&a).Equal(&b) {`) + } + } + } else if mapValue.IsBytes() { + if ctype { + if nullable { + p.P(`if !this.`, fieldname, `[i].Equal(*that1.`, fieldname, `[i]) { //nullable`) + } else { + p.P(`if !this.`, fieldname, `[i].Equal(that1.`, fieldname, `[i]) { //not nullable`) + } + } else { + p.P(`if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `[i], that1.`, fieldname, `[i]) {`) + } + } else if mapValue.IsString() { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) 
+ } else { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + } + } else if field.IsMessage() || p.IsGroup(field) { + if nullable { + p.P(`if !this.`, fieldname, `[i].Equal(that1.`, fieldname, `[i]) {`) + } else { + p.P(`if !this.`, fieldname, `[i].Equal(&that1.`, fieldname, `[i]) {`) + } + } else if field.IsBytes() { + p.P(`if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `[i], that1.`, fieldname, `[i]) {`) + } else if field.IsString() { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + } else { + p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) + } + } + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this[%v](%v) Not Equal that[%v](%v)", i, this.`, fieldname, `[i], i, that1.`, fieldname, `[i])`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } +} + +func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor, verbose bool) { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if verbose { + p.P(`func (this *`, ccTypeName, `) VerboseEqual(that interface{}) error {`) + } else { + p.P(`func (this *`, ccTypeName, `) Equal(that interface{}) bool {`) + } + p.In() + p.generateMsgNullAndTypeCheck(ccTypeName, verbose) + oneofs := make(map[string]struct{}) + + for _, field := range message.Field { + oneof := field.OneofIndex != nil + if oneof { + fieldname := p.GetFieldName(message, field) + if _, ok := oneofs[fieldname]; ok { + continue + } else { + oneofs[fieldname] = struct{}{} + } + p.P(`if that1.`, fieldname, ` == nil {`) + p.In() + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("this.`, fieldname, ` != nil && that1.`, fieldname, ` == nil")`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`} else if this.`, fieldname, ` == nil {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("this.`, fieldname, ` == 
nil && that1.`, fieldname, ` != nil")`) + } else { + p.P(`return false`) + } + p.Out() + if verbose { + p.P(`} else if err := this.`, fieldname, `.VerboseEqual(that1.`, fieldname, `); err != nil {`) + } else { + p.P(`} else if !this.`, fieldname, `.Equal(that1.`, fieldname, `) {`) + } + p.In() + if verbose { + p.P(`return err`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + } else { + p.generateField(file, message, field, verbose) + } + } + if message.DescriptorProto.HasExtension() { + if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { + fieldname := "XXX_InternalExtensions" + p.P(`thismap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(this)`) + p.P(`thatmap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(that1)`) + p.P(`for k, v := range thismap {`) + p.In() + p.P(`if v2, ok := thatmap[k]; ok {`) + p.In() + p.P(`if !v.Equal(&v2) {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this[%v](%v) Not Equal that[%v](%v)", k, thismap[k], k, thatmap[k])`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`} else {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, `[%v] Not In that", k)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + + p.P(`for k, _ := range thatmap {`) + p.In() + p.P(`if _, ok := thismap[k]; !ok {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, `[%v] Not In this", k)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } else { + fieldname := "XXX_extensions" + p.P(`if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `, that1.`, fieldname, `) {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", this.`, fieldname, `, that1.`, fieldname, `)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + } + } + if 
gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + fieldname := "XXX_unrecognized" + p.P(`if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `, that1.`, fieldname, `) {`) + p.In() + if verbose { + p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", this.`, fieldname, `, that1.`, fieldname, `)`) + } else { + p.P(`return false`) + } + p.Out() + p.P(`}`) + } + if verbose { + p.P(`return nil`) + } else { + p.P(`return true`) + } + p.Out() + p.P(`}`) + + //Generate Equal methods for oneof fields + m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto) + for _, field := range m.Field { + oneof := field.OneofIndex != nil + if !oneof { + continue + } + ccTypeName := p.OneOfTypeName(message, field) + if verbose { + p.P(`func (this *`, ccTypeName, `) VerboseEqual(that interface{}) error {`) + } else { + p.P(`func (this *`, ccTypeName, `) Equal(that interface{}) bool {`) + } + p.In() + + p.generateMsgNullAndTypeCheck(ccTypeName, verbose) + vanity.TurnOffNullableForNativeTypes(field) + p.generateField(file, message, field, verbose) + + if verbose { + p.P(`return nil`) + } else { + p.P(`return true`) + } + p.Out() + p.P(`}`) + } +} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/equal/equaltest.go b/vendor/github.com/gogo/protobuf/plugin/equal/equaltest.go new file mode 100644 index 00000000000..1233647a56d --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/equal/equaltest.go @@ -0,0 +1,109 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. 
+// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package equal + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + testingPkg := imports.NewImport("testing") + protoPkg := imports.NewImport("github.com/gogo/protobuf/proto") + unsafePkg := imports.NewImport("unsafe") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = imports.NewImport("github.com/golang/protobuf/proto") + } + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if !gogoproto.HasVerboseEqual(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + hasUnsafe := gogoproto.IsUnsafeMarshaler(file.FileDescriptorProto, message.DescriptorProto) || + gogoproto.IsUnsafeUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) + p.P(`func Test`, ccTypeName, `VerboseEqual(t *`, testingPkg.Use(), `.T) {`) + p.In() + if hasUnsafe { + if hasUnsafe { + p.P(`var bigendian uint32 = 0x01020304`) + p.P(`if *(*byte)(`, unsafePkg.Use(), `.Pointer(&bigendian)) == 1 {`) + p.In() + p.P(`t.Skip("unsafe does not work on big endian architectures")`) + p.Out() + p.P(`}`) + } + } + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(`, timePkg.Use(), `.Now().UnixNano()))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, false)`) + p.P(`dAtA, err := `, protoPkg.Use(), `.Marshal(p)`) + p.P(`if err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`msg := &`, 
ccTypeName, `{}`) + p.P(`if err := `, protoPkg.Use(), `.Unmarshal(dAtA, msg); err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`if err := p.VerboseEqual(msg); err != nil {`) + p.In() + p.P(`t.Fatalf("%#v !VerboseEqual %#v, since %v", msg, p, err)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/face/face.go b/vendor/github.com/gogo/protobuf/plugin/face/face.go new file mode 100644 index 00000000000..a0293452652 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/face/face.go @@ -0,0 +1,233 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The face plugin generates a function will be generated which can convert a structure which satisfies an interface (face) to the specified structure. +This interface contains getters for each of the fields in the struct. +The specified struct is also generated with the getters. +This means that getters should be turned off so as not to conflict with face getters. +This allows it to satisfy its own face. + +It is enabled by the following extensions: + + - face + - face_all + +Turn off getters by using the following extensions: + + - getters + - getters_all + +The face plugin also generates a test given it is enabled using one of the following extensions: + + - testgen + - testgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + message A { + option (gogoproto.face) = true; + option (gogoproto.goproto_getters) = false; + optional string Description = 1 [(gogoproto.nullable) = false]; + optional int64 Number = 2 [(gogoproto.nullable) = false]; + optional bytes Id = 3 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uuid", (gogoproto.nullable) = false]; + } + +given to the face plugin, will generate the following code: + + type AFace interface { + Proto() github_com_gogo_protobuf_proto.Message + GetDescription() string + GetNumber() int64 + GetId() 
github_com_gogo_protobuf_test_custom.Uuid + } + + func (this *A) Proto() github_com_gogo_protobuf_proto.Message { + return this + } + + func (this *A) TestProto() github_com_gogo_protobuf_proto.Message { + return NewAFromFace(this) + } + + func (this *A) GetDescription() string { + return this.Description + } + + func (this *A) GetNumber() int64 { + return this.Number + } + + func (this *A) GetId() github_com_gogo_protobuf_test_custom.Uuid { + return this.Id + } + + func NewAFromFace(that AFace) *A { + this := &A{} + this.Description = that.GetDescription() + this.Number = that.GetNumber() + this.Id = that.GetId() + return this + } + +and the following test code: + + func TestAFace(t *testing7.T) { + popr := math_rand7.New(math_rand7.NewSource(time7.Now().UnixNano())) + p := NewPopulatedA(popr, true) + msg := p.TestProto() + if !p.Equal(msg) { + t.Fatalf("%#v !Face Equal %#v", msg, p) + } + } + +The struct A, representing the message, will also be generated just like always. +As you can see A satisfies its own Face, AFace. + +Creating another struct which satisfies AFace is very easy. +Simply create all these methods specified in AFace. +Implementing The Proto method is done with the helper function NewAFromFace: + + func (this *MyStruct) Proto() proto.Message { + return NewAFromFace(this) + } + +just the like TestProto method which is used to test the NewAFromFace function. 
+ +*/ +package face + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type plugin struct { + *generator.Generator + generator.PluginImports +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "face" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + p.PluginImports = generator.NewPluginImports(p.Generator) + protoPkg := p.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = p.NewImport("github.com/golang/protobuf/proto") + } + for _, message := range file.Messages() { + if !gogoproto.IsFace(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + if message.DescriptorProto.HasExtension() { + panic("face does not support message with extensions") + } + if gogoproto.HasGoGetters(file.FileDescriptorProto, message.DescriptorProto) { + panic("face requires getters to be disabled please use gogoproto.getters or gogoproto.getters_all and set it to false") + } + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + p.P(`type `, ccTypeName, `Face interface{`) + p.In() + p.P(`Proto() `, protoPkg.Use(), `.Message`) + for _, field := range message.Field { + fieldname := p.GetFieldName(message, field) + goTyp, _ := p.GoType(message, field) + if p.IsMap(field) { + m := p.GoMapType(nil, field) + goTyp = m.GoType + } + p.P(`Get`, fieldname, `() `, goTyp) + } + p.Out() + p.P(`}`) + p.P(``) + p.P(`func (this *`, ccTypeName, `) Proto() `, protoPkg.Use(), `.Message {`) + p.In() + p.P(`return this`) + p.Out() + p.P(`}`) + p.P(``) + p.P(`func (this *`, ccTypeName, `) TestProto() `, protoPkg.Use(), `.Message {`) + p.In() + p.P(`return New`, ccTypeName, `FromFace(this)`) + p.Out() + p.P(`}`) + p.P(``) + for _, field := range message.Field 
{ + fieldname := p.GetFieldName(message, field) + goTyp, _ := p.GoType(message, field) + if p.IsMap(field) { + m := p.GoMapType(nil, field) + goTyp = m.GoType + } + p.P(`func (this *`, ccTypeName, `) Get`, fieldname, `() `, goTyp, `{`) + p.In() + p.P(` return this.`, fieldname) + p.Out() + p.P(`}`) + p.P(``) + } + p.P(``) + p.P(`func New`, ccTypeName, `FromFace(that `, ccTypeName, `Face) *`, ccTypeName, ` {`) + p.In() + p.P(`this := &`, ccTypeName, `{}`) + for _, field := range message.Field { + fieldname := p.GetFieldName(message, field) + p.P(`this.`, fieldname, ` = that.Get`, fieldname, `()`) + } + p.P(`return this`) + p.Out() + p.P(`}`) + p.P(``) + } +} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/face/facetest.go b/vendor/github.com/gogo/protobuf/plugin/face/facetest.go new file mode 100644 index 00000000000..467cc0a6640 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/face/facetest.go @@ -0,0 +1,82 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package face + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + testingPkg := imports.NewImport("testing") + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if !gogoproto.IsFace(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + + p.P(`func Test`, ccTypeName, `Face(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(`, timePkg.Use(), `.Now().UnixNano()))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, true)`) + p.P(`msg := p.TestProto()`) + p.P(`if !p.Equal(msg) {`) + p.In() + p.P(`t.Fatalf("%#v !Face Equal %#v", msg, p)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git 
a/vendor/github.com/gogo/protobuf/plugin/gostring/gostring.go b/vendor/github.com/gogo/protobuf/plugin/gostring/gostring.go new file mode 100644 index 00000000000..2b439469fcb --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/gostring/gostring.go @@ -0,0 +1,386 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The gostring plugin generates a GoString method for each message. 
+The GoString method is called whenever you use a fmt.Printf as such: + + fmt.Printf("%#v", mymessage) + +or whenever you actually call GoString() +The output produced by the GoString method can be copied from the output into code and used to set a variable. +It is totally valid Go Code and is populated exactly as the struct that was printed out. + +It is enabled by the following extensions: + + - gostring + - gostring_all + +The gostring plugin also generates a test given it is enabled using one of the following extensions: + + - testgen + - testgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + option (gogoproto.gostring_all) = true; + + message A { + optional string Description = 1 [(gogoproto.nullable) = false]; + optional int64 Number = 2 [(gogoproto.nullable) = false]; + optional bytes Id = 3 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uuid", (gogoproto.nullable) = false]; + } + +given to the gostring plugin, will generate the following code: + + func (this *A) GoString() string { + if this == nil { + return "nil" + } + s := strings1.Join([]string{`&test.A{` + `Description:` + fmt1.Sprintf("%#v", this.Description), `Number:` + fmt1.Sprintf("%#v", this.Number), `Id:` + fmt1.Sprintf("%#v", this.Id), `XXX_unrecognized:` + fmt1.Sprintf("%#v", this.XXX_unrecognized) + `}`}, ", ") + return s + } + +and the following test code: + + func TestAGoString(t *testing6.T) { + popr := math_rand6.New(math_rand6.NewSource(time6.Now().UnixNano())) + p := NewPopulatedA(popr, false) + s1 := p.GoString() + s2 := fmt2.Sprintf("%#v", p) + if s1 != s2 { + t.Fatalf("GoString want %v got %v", s1, s2) + } + _, err := go_parser.ParseExpr(s1) + if err != nil { + panic(err) + } + } + +Typically fmt.Printf("%#v") will stop to print when it reaches a pointer and +not print their values, while the generated GoString method will 
always print all values, recursively. + +*/ +package gostring + +import ( + "fmt" + "os" + "strconv" + "strings" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type gostring struct { + *generator.Generator + generator.PluginImports + atleastOne bool + localName string + overwrite bool +} + +func NewGoString() *gostring { + return &gostring{} +} + +func (p *gostring) Name() string { + return "gostring" +} + +func (p *gostring) Overwrite() { + p.overwrite = true +} + +func (p *gostring) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *gostring) Generate(file *generator.FileDescriptor) { + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + p.PluginImports = generator.NewPluginImports(p.Generator) + p.atleastOne = false + + p.localName = generator.FileName(file) + + fmtPkg := p.NewImport("fmt") + stringsPkg := p.NewImport("strings") + protoPkg := p.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = p.NewImport("github.com/golang/protobuf/proto") + } + sortPkg := p.NewImport("sort") + strconvPkg := p.NewImport("strconv") + reflectPkg := p.NewImport("reflect") + sortKeysPkg := p.NewImport("github.com/gogo/protobuf/sortkeys") + + extensionToGoStringUsed := false + for _, message := range file.Messages() { + if !p.overwrite && !gogoproto.HasGoString(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + p.atleastOne = true + packageName := file.PackageName() + + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + p.P(`func (this *`, ccTypeName, `) GoString() string {`) + p.In() + p.P(`if this == nil {`) + p.In() + p.P(`return "nil"`) + p.Out() + p.P(`}`) + + p.P(`s := make([]string, 0, `, strconv.Itoa(len(message.Field)+4), `)`) + p.P(`s = append(s, "&`, packageName, ".", ccTypeName, `{")`) + + oneofs := make(map[string]struct{}) + for 
_, field := range message.Field { + nullable := gogoproto.IsNullable(field) + repeated := field.IsRepeated() + fieldname := p.GetFieldName(message, field) + oneof := field.OneofIndex != nil + if oneof { + if _, ok := oneofs[fieldname]; ok { + continue + } else { + oneofs[fieldname] = struct{}{} + } + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + p.P(`s = append(s, "`, fieldname, `: " + `, fmtPkg.Use(), `.Sprintf("%#v", this.`, fieldname, `) + ",\n")`) + p.Out() + p.P(`}`) + } else if p.IsMap(field) { + m := p.GoMapType(nil, field) + mapgoTyp, keyField, keyAliasField := m.GoType, m.KeyField, m.KeyAliasField + keysName := `keysFor` + fieldname + keygoTyp, _ := p.GoType(nil, keyField) + keygoTyp = strings.Replace(keygoTyp, "*", "", 1) + keygoAliasTyp, _ := p.GoType(nil, keyAliasField) + keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1) + keyCapTyp := generator.CamelCase(keygoTyp) + p.P(keysName, ` := make([]`, keygoTyp, `, 0, len(this.`, fieldname, `))`) + p.P(`for k, _ := range this.`, fieldname, ` {`) + p.In() + if keygoAliasTyp == keygoTyp { + p.P(keysName, ` = append(`, keysName, `, k)`) + } else { + p.P(keysName, ` = append(`, keysName, `, `, keygoTyp, `(k))`) + } + p.Out() + p.P(`}`) + p.P(sortKeysPkg.Use(), `.`, keyCapTyp, `s(`, keysName, `)`) + mapName := `mapStringFor` + fieldname + p.P(mapName, ` := "`, mapgoTyp, `{"`) + p.P(`for _, k := range `, keysName, ` {`) + p.In() + if keygoAliasTyp == keygoTyp { + p.P(mapName, ` += fmt.Sprintf("%#v: %#v,", k, this.`, fieldname, `[k])`) + } else { + p.P(mapName, ` += fmt.Sprintf("%#v: %#v,", k, this.`, fieldname, `[`, keygoAliasTyp, `(k)])`) + } + p.Out() + p.P(`}`) + p.P(mapName, ` += "}"`) + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + p.P(`s = append(s, "`, fieldname, `: " + `, mapName, `+ ",\n")`) + p.Out() + p.P(`}`) + } else if (field.IsMessage() && !gogoproto.IsCustomType(field) && !gogoproto.IsStdTime(field) && !gogoproto.IsStdDuration(field)) || p.IsGroup(field) { + if nullable || repeated 
{ + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + } + if nullable { + p.P(`s = append(s, "`, fieldname, `: " + `, fmtPkg.Use(), `.Sprintf("%#v", this.`, fieldname, `) + ",\n")`) + } else if repeated { + if nullable { + p.P(`s = append(s, "`, fieldname, `: " + `, fmtPkg.Use(), `.Sprintf("%#v", this.`, fieldname, `) + ",\n")`) + } else { + goTyp, _ := p.GoType(message, field) + goTyp = strings.Replace(goTyp, "[]", "", 1) + p.P("vs := make([]*", goTyp, ", len(this.", fieldname, "))") + p.P("for i := range vs {") + p.In() + p.P("vs[i] = &this.", fieldname, "[i]") + p.Out() + p.P("}") + p.P(`s = append(s, "`, fieldname, `: " + `, fmtPkg.Use(), `.Sprintf("%#v", vs) + ",\n")`) + } + } else { + p.P(`s = append(s, "`, fieldname, `: " + `, stringsPkg.Use(), `.Replace(this.`, fieldname, `.GoString()`, ",`&`,``,1)", ` + ",\n")`) + } + if nullable || repeated { + p.Out() + p.P(`}`) + } + } else { + if !proto3 && (nullable || repeated) { + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + } + if field.IsEnum() { + if nullable && !repeated && !proto3 { + goTyp, _ := p.GoType(message, field) + p.P(`s = append(s, "`, fieldname, `: " + valueToGoString`, p.localName, `(this.`, fieldname, `,"`, generator.GoTypeToName(goTyp), `"`, `) + ",\n")`) + } else { + p.P(`s = append(s, "`, fieldname, `: " + `, fmtPkg.Use(), `.Sprintf("%#v", this.`, fieldname, `) + ",\n")`) + } + } else { + if nullable && !repeated && !proto3 { + goTyp, _ := p.GoType(message, field) + p.P(`s = append(s, "`, fieldname, `: " + valueToGoString`, p.localName, `(this.`, fieldname, `,"`, generator.GoTypeToName(goTyp), `"`, `) + ",\n")`) + } else { + p.P(`s = append(s, "`, fieldname, `: " + `, fmtPkg.Use(), `.Sprintf("%#v", this.`, fieldname, `) + ",\n")`) + } + } + if !proto3 && (nullable || repeated) { + p.Out() + p.P(`}`) + } + } + } + if message.DescriptorProto.HasExtension() { + if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`s = append(s, "XXX_InternalExtensions: " + 
extensionToGoString`, p.localName, `(this) + ",\n")`) + extensionToGoStringUsed = true + } else { + p.P(`if this.XXX_extensions != nil {`) + p.In() + p.P(`s = append(s, "XXX_extensions: " + `, fmtPkg.Use(), `.Sprintf("%#v", this.XXX_extensions) + ",\n")`) + p.Out() + p.P(`}`) + } + } + if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if this.XXX_unrecognized != nil {`) + p.In() + p.P(`s = append(s, "XXX_unrecognized:" + `, fmtPkg.Use(), `.Sprintf("%#v", this.XXX_unrecognized) + ",\n")`) + p.Out() + p.P(`}`) + } + + p.P(`s = append(s, "}")`) + p.P(`return `, stringsPkg.Use(), `.Join(s, "")`) + p.Out() + p.P(`}`) + + //Generate GoString methods for oneof fields + for _, field := range message.Field { + oneof := field.OneofIndex != nil + if !oneof { + continue + } + ccTypeName := p.OneOfTypeName(message, field) + p.P(`func (this *`, ccTypeName, `) GoString() string {`) + p.In() + p.P(`if this == nil {`) + p.In() + p.P(`return "nil"`) + p.Out() + p.P(`}`) + fieldname := p.GetOneOfFieldName(message, field) + outStr := strings.Join([]string{ + "s := ", + stringsPkg.Use(), ".Join([]string{`&", packageName, ".", ccTypeName, "{` + \n", + "`", fieldname, ":` + ", fmtPkg.Use(), `.Sprintf("%#v", this.`, fieldname, `)`, + " + `}`", + `}`, + `,", "`, + `)`}, "") + p.P(outStr) + p.P(`return s`) + p.Out() + p.P(`}`) + } + } + + if !p.atleastOne { + return + } + + p.P(`func valueToGoString`, p.localName, `(v interface{}, typ string) string {`) + p.In() + p.P(`rv := `, reflectPkg.Use(), `.ValueOf(v)`) + p.P(`if rv.IsNil() {`) + p.In() + p.P(`return "nil"`) + p.Out() + p.P(`}`) + p.P(`pv := `, reflectPkg.Use(), `.Indirect(rv).Interface()`) + p.P(`return `, fmtPkg.Use(), `.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv)`) + p.Out() + p.P(`}`) + + if extensionToGoStringUsed { + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + fmt.Fprintf(os.Stderr, "The GoString plugin for messages with extensions requires importing 
gogoprotobuf. Please see file %s", file.GetName()) + os.Exit(1) + } + p.P(`func extensionToGoString`, p.localName, `(m `, protoPkg.Use(), `.Message) string {`) + p.In() + p.P(`e := `, protoPkg.Use(), `.GetUnsafeExtensionsMap(m)`) + p.P(`if e == nil { return "nil" }`) + p.P(`s := "proto.NewUnsafeXXX_InternalExtensions(map[int32]proto.Extension{"`) + p.P(`keys := make([]int, 0, len(e))`) + p.P(`for k := range e {`) + p.In() + p.P(`keys = append(keys, int(k))`) + p.Out() + p.P(`}`) + p.P(sortPkg.Use(), `.Ints(keys)`) + p.P(`ss := []string{}`) + p.P(`for _, k := range keys {`) + p.In() + p.P(`ss = append(ss, `, strconvPkg.Use(), `.Itoa(k) + ": " + e[int32(k)].GoString())`) + p.Out() + p.P(`}`) + p.P(`s+=`, stringsPkg.Use(), `.Join(ss, ",") + "})"`) + p.P(`return s`) + p.Out() + p.P(`}`) + } +} + +func init() { + generator.RegisterPlugin(NewGoString()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/gostring/gostringtest.go b/vendor/github.com/gogo/protobuf/plugin/gostring/gostringtest.go new file mode 100644 index 00000000000..c790e590880 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/gostring/gostringtest.go @@ -0,0 +1,90 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package gostring + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + testingPkg := imports.NewImport("testing") + fmtPkg := imports.NewImport("fmt") + parserPkg := imports.NewImport("go/parser") + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if !gogoproto.HasGoString(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + p.P(`func Test`, ccTypeName, `GoString(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(`, timePkg.Use(), `.Now().UnixNano()))`) + 
p.P(`p := NewPopulated`, ccTypeName, `(popr, false)`) + p.P(`s1 := p.GoString()`) + p.P(`s2 := `, fmtPkg.Use(), `.Sprintf("%#v", p)`) + p.P(`if s1 != s2 {`) + p.In() + p.P(`t.Fatalf("GoString want %v got %v", s1, s2)`) + p.Out() + p.P(`}`) + p.P(`_, err := `, parserPkg.Use(), `.ParseExpr(s1)`) + p.P(`if err != nil {`) + p.In() + p.P(`t.Fatal(err)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/marshalto/marshalto.go b/vendor/github.com/gogo/protobuf/plugin/marshalto/marshalto.go new file mode 100644 index 00000000000..24110cb4431 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/marshalto/marshalto.go @@ -0,0 +1,1205 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The marshalto plugin generates a Marshal and MarshalTo method for each message. +The `Marshal() ([]byte, error)` method results in the fact that the message +implements the Marshaler interface. +This allows proto.Marshal to be faster by calling the generated Marshal method rather than using reflect to Marshal the struct. + +If is enabled by the following extensions: + + - marshaler + - marshaler_all + +Or the following extensions: + + - unsafe_marshaler + - unsafe_marshaler_all + +That is if you want to use the unsafe package in your generated code. +The speed up using the unsafe package is not very significant. 
+ +The generation of marshalling tests are enabled using one of the following extensions: + + - testgen + - testgen_all + +And benchmarks given it is enabled using one of the following extensions: + + - benchgen + - benchgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + +option (gogoproto.marshaler_all) = true; + +message B { + option (gogoproto.description) = true; + optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true]; + repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false]; +} + +given to the marshalto plugin, will generate the following code: + + func (m *B) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil + } + + func (m *B) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + dAtA[i] = 0xa + i++ + i = encodeVarintExample(dAtA, i, uint64(m.A.Size())) + n2, err := m.A.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n2 + if len(m.G) > 0 { + for _, msg := range m.G { + dAtA[i] = 0x12 + i++ + i = encodeVarintExample(dAtA, i, uint64(msg.Size())) + n, err := msg.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if m.XXX_unrecognized != nil { + i += copy(dAtA[i:], m.XXX_unrecognized) + } + return i, nil + } + +As shown above Marshal calculates the size of the not yet marshalled message +and allocates the appropriate buffer. +This is followed by calling the MarshalTo method which requires a preallocated buffer. +The MarshalTo method allows a user to rather preallocated a reusable buffer. + +The Size method is generated using the size plugin and the gogoproto.sizer, gogoproto.sizer_all extensions. 
+The user can also using the generated Size method to check that his reusable buffer is still big enough. + +The generated tests and benchmarks will keep you safe and show that this is really a significant speed improvement. + +An additional message-level option `stable_marshaler` (and the file-level +option `stable_marshaler_all`) exists which causes the generated marshalling +code to behave deterministically. Today, this only changes the serialization of +maps; they are serialized in sort order. +*/ +package marshalto + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "github.com/gogo/protobuf/vanity" +) + +type NumGen interface { + Next() string + Current() string +} + +type numGen struct { + index int +} + +func NewNumGen() NumGen { + return &numGen{0} +} + +func (this *numGen) Next() string { + this.index++ + return this.Current() +} + +func (this *numGen) Current() string { + return strconv.Itoa(this.index) +} + +type marshalto struct { + *generator.Generator + generator.PluginImports + atleastOne bool + errorsPkg generator.Single + protoPkg generator.Single + sortKeysPkg generator.Single + mathPkg generator.Single + typesPkg generator.Single + binaryPkg generator.Single + localName string +} + +func NewMarshal() *marshalto { + return &marshalto{} +} + +func (p *marshalto) Name() string { + return "marshalto" +} + +func (p *marshalto) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *marshalto) callFixed64(varName ...string) { + p.P(p.binaryPkg.Use(), `.LittleEndian.PutUint64(dAtA[i:], uint64(`, strings.Join(varName, ""), `))`) + p.P(`i += 8`) +} + +func (p *marshalto) callFixed32(varName ...string) { + p.P(p.binaryPkg.Use(), `.LittleEndian.PutUint32(dAtA[i:], uint32(`, strings.Join(varName, ""), `))`) + p.P(`i += 4`) +} + +func (p *marshalto) 
callVarint(varName ...string) { + p.P(`i = encodeVarint`, p.localName, `(dAtA, i, uint64(`, strings.Join(varName, ""), `))`) +} + +func (p *marshalto) encodeVarint(varName string) { + p.P(`for `, varName, ` >= 1<<7 {`) + p.In() + p.P(`dAtA[i] = uint8(uint64(`, varName, `)&0x7f|0x80)`) + p.P(varName, ` >>= 7`) + p.P(`i++`) + p.Out() + p.P(`}`) + p.P(`dAtA[i] = uint8(`, varName, `)`) + p.P(`i++`) +} + +func (p *marshalto) encodeKey(fieldNumber int32, wireType int) { + x := uint32(fieldNumber)<<3 | uint32(wireType) + i := 0 + keybuf := make([]byte, 0) + for i = 0; x > 127; i++ { + keybuf = append(keybuf, 0x80|uint8(x&0x7F)) + x >>= 7 + } + keybuf = append(keybuf, uint8(x)) + for _, b := range keybuf { + p.P(`dAtA[i] = `, fmt.Sprintf("%#v", b)) + p.P(`i++`) + } +} + +func keySize(fieldNumber int32, wireType int) int { + x := uint32(fieldNumber)<<3 | uint32(wireType) + size := 0 + for size = 0; x > 127; size++ { + x >>= 7 + } + size++ + return size +} + +func wireToType(wire string) int { + switch wire { + case "fixed64": + return proto.WireFixed64 + case "fixed32": + return proto.WireFixed32 + case "varint": + return proto.WireVarint + case "bytes": + return proto.WireBytes + case "group": + return proto.WireBytes + case "zigzag32": + return proto.WireVarint + case "zigzag64": + return proto.WireVarint + } + panic("unreachable") +} + +func (p *marshalto) mapField(numGen NumGen, field *descriptor.FieldDescriptorProto, kvField *descriptor.FieldDescriptorProto, varName string, protoSizer bool) { + switch kvField.GetType() { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(`, varName, `))`) + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(`, varName, `))`) + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_INT32, + descriptor.FieldDescriptorProto_TYPE_UINT32, + 
descriptor.FieldDescriptorProto_TYPE_ENUM: + p.callVarint(varName) + case descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + p.callFixed64(varName) + case descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + p.callFixed32(varName) + case descriptor.FieldDescriptorProto_TYPE_BOOL: + p.P(`if `, varName, ` {`) + p.In() + p.P(`dAtA[i] = 1`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`dAtA[i] = 0`) + p.Out() + p.P(`}`) + p.P(`i++`) + case descriptor.FieldDescriptorProto_TYPE_STRING, + descriptor.FieldDescriptorProto_TYPE_BYTES: + if gogoproto.IsCustomType(field) && kvField.IsBytes() { + p.callVarint(varName, `.Size()`) + p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(dAtA[i:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=n`, numGen.Current()) + } else { + p.callVarint(`len(`, varName, `)`) + p.P(`i+=copy(dAtA[i:], `, varName, `)`) + } + case descriptor.FieldDescriptorProto_TYPE_SINT32: + p.callVarint(`(uint32(`, varName, `) << 1) ^ uint32((`, varName, ` >> 31))`) + case descriptor.FieldDescriptorProto_TYPE_SINT64: + p.callVarint(`(uint64(`, varName, `) << 1) ^ uint64((`, varName, ` >> 63))`) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if gogoproto.IsStdTime(field) { + p.callVarint(p.typesPkg.Use(), `.SizeOfStdTime(*`, varName, `)`) + p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdTimeMarshalTo(*`, varName, `, dAtA[i:])`) + } else if gogoproto.IsStdDuration(field) { + p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(*`, varName, `)`) + p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(*`, varName, `, dAtA[i:])`) + } else if protoSizer { + p.callVarint(varName, `.ProtoSize()`) + p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(dAtA[i:])`) + } else { + p.callVarint(varName, `.Size()`) + p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(dAtA[i:])`) + } + 
p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=n`, numGen.Current()) + } +} + +type orderFields []*descriptor.FieldDescriptorProto + +func (this orderFields) Len() int { + return len(this) +} + +func (this orderFields) Less(i, j int) bool { + return this[i].GetNumber() < this[j].GetNumber() +} + +func (this orderFields) Swap(i, j int) { + this[i], this[j] = this[j], this[i] +} + +func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) { + fieldname := p.GetOneOfFieldName(message, field) + nullable := gogoproto.IsNullable(field) + repeated := field.IsRepeated() + required := field.IsRequired() + + protoSizer := gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) + doNilCheck := gogoproto.NeedsNilCheck(proto3, field) + if required && nullable { + p.P(`if m.`, fieldname, `== nil {`) + p.In() + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + p.P(`return 0, new(`, p.protoPkg.Use(), `.RequiredNotSetError)`) + } else { + p.P(`return 0, `, p.protoPkg.Use(), `.NewRequiredNotSetError("`, field.GetName(), `")`) + } + p.Out() + p.P(`} else {`) + } else if repeated { + p.P(`if len(m.`, fieldname, `) > 0 {`) + p.In() + } else if doNilCheck { + p.P(`if m.`, fieldname, ` != nil {`) + p.In() + } + packed := field.IsPacked() || (proto3 && field.IsPacked3()) + wireType := field.WireType() + fieldNumber := field.GetNumber() + if packed { + wireType = proto.WireBytes + } + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + if packed { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `) * 8`) + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float64bits(float64(num))`) + p.callFixed64("f" + numGen.Current()) + p.Out() + p.P(`}`) + } else if repeated { + p.P(`for _, num := range m.`, fieldname, ` {`) 
+ p.In() + p.encodeKey(fieldNumber, wireType) + p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float64bits(float64(num))`) + p.callFixed64("f" + numGen.Current()) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(m.`+fieldname, `))`) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(m.`+fieldname, `))`) + } else { + p.encodeKey(fieldNumber, wireType) + p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(*m.`+fieldname, `))`) + } + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + if packed { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `) * 4`) + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float32bits(float32(num))`) + p.callFixed32("f" + numGen.Current()) + p.Out() + p.P(`}`) + } else if repeated { + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float32bits(float32(num))`) + p.callFixed32("f" + numGen.Current()) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(m.`+fieldname, `))`) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(m.`+fieldname, `))`) + } else { + p.encodeKey(fieldNumber, wireType) + p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(*m.`+fieldname, `))`) + } + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_INT32, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM: + if packed { + jvar := "j" + 
numGen.Next() + p.P(`dAtA`, numGen.Next(), ` := make([]byte, len(m.`, fieldname, `)*10)`) + p.P(`var `, jvar, ` int`) + if *field.Type == descriptor.FieldDescriptorProto_TYPE_INT64 || + *field.Type == descriptor.FieldDescriptorProto_TYPE_INT32 { + p.P(`for _, num1 := range m.`, fieldname, ` {`) + p.In() + p.P(`num := uint64(num1)`) + } else { + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + } + p.P(`for num >= 1<<7 {`) + p.In() + p.P(`dAtA`, numGen.Current(), `[`, jvar, `] = uint8(uint64(num)&0x7f|0x80)`) + p.P(`num >>= 7`) + p.P(jvar, `++`) + p.Out() + p.P(`}`) + p.P(`dAtA`, numGen.Current(), `[`, jvar, `] = uint8(num)`) + p.P(jvar, `++`) + p.Out() + p.P(`}`) + p.encodeKey(fieldNumber, wireType) + p.callVarint(jvar) + p.P(`i += copy(dAtA[i:], dAtA`, numGen.Current(), `[:`, jvar, `])`) + } else if repeated { + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callVarint("num") + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callVarint(`m.`, fieldname) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`m.`, fieldname) + } else { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`*m.`, fieldname) + } + case descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + if packed { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `) * 8`) + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.callFixed64("num") + p.Out() + p.P(`}`) + } else if repeated { + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callFixed64("num") + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callFixed64("m." 
+ fieldname) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callFixed64("m." + fieldname) + } else { + p.encodeKey(fieldNumber, wireType) + p.callFixed64("*m." + fieldname) + } + case descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + if packed { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `) * 4`) + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.callFixed32("num") + p.Out() + p.P(`}`) + } else if repeated { + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callFixed32("num") + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callFixed32("m." + fieldname) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callFixed32("m." + fieldname) + } else { + p.encodeKey(fieldNumber, wireType) + p.callFixed32("*m." + fieldname) + } + case descriptor.FieldDescriptorProto_TYPE_BOOL: + if packed { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `)`) + p.P(`for _, b := range m.`, fieldname, ` {`) + p.In() + p.P(`if b {`) + p.In() + p.P(`dAtA[i] = 1`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`dAtA[i] = 0`) + p.Out() + p.P(`}`) + p.P(`i++`) + p.Out() + p.P(`}`) + } else if repeated { + p.P(`for _, b := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.P(`if b {`) + p.In() + p.P(`dAtA[i] = 1`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`dAtA[i] = 0`) + p.Out() + p.P(`}`) + p.P(`i++`) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.P(`if m.`, fieldname, ` {`) + p.In() + p.P(`dAtA[i] = 1`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`dAtA[i] = 0`) + p.Out() + p.P(`}`) + p.P(`i++`) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.P(`if 
m.`, fieldname, ` {`) + p.In() + p.P(`dAtA[i] = 1`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`dAtA[i] = 0`) + p.Out() + p.P(`}`) + p.P(`i++`) + } else { + p.encodeKey(fieldNumber, wireType) + p.P(`if *m.`, fieldname, ` {`) + p.In() + p.P(`dAtA[i] = 1`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`dAtA[i] = 0`) + p.Out() + p.P(`}`) + p.P(`i++`) + } + case descriptor.FieldDescriptorProto_TYPE_STRING: + if repeated { + p.P(`for _, s := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.P(`l = len(s)`) + p.encodeVarint("l") + p.P(`i+=copy(dAtA[i:], s)`) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if len(m.`, fieldname, `) > 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `)`) + p.P(`i+=copy(dAtA[i:], m.`, fieldname, `)`) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `)`) + p.P(`i+=copy(dAtA[i:], m.`, fieldname, `)`) + } else { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(*m.`, fieldname, `)`) + p.P(`i+=copy(dAtA[i:], *m.`, fieldname, `)`) + } + case descriptor.FieldDescriptorProto_TYPE_GROUP: + panic(fmt.Errorf("marshaler does not support group %v", fieldname)) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if p.IsMap(field) { + m := p.GoMapType(nil, field) + keygoTyp, keywire := p.GoType(nil, m.KeyField) + keygoAliasTyp, _ := p.GoType(nil, m.KeyAliasField) + // keys may not be pointers + keygoTyp = strings.Replace(keygoTyp, "*", "", 1) + keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1) + keyCapTyp := generator.CamelCase(keygoTyp) + valuegoTyp, valuewire := p.GoType(nil, m.ValueField) + valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField) + nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) + keyKeySize := keySize(1, wireToType(keywire)) + valueKeySize := keySize(2, wireToType(valuewire)) + if 
gogoproto.IsStableMarshaler(file.FileDescriptorProto, message.DescriptorProto) { + keysName := `keysFor` + fieldname + p.P(keysName, ` := make([]`, keygoTyp, `, 0, len(m.`, fieldname, `))`) + p.P(`for k, _ := range m.`, fieldname, ` {`) + p.In() + p.P(keysName, ` = append(`, keysName, `, `, keygoTyp, `(k))`) + p.Out() + p.P(`}`) + p.P(p.sortKeysPkg.Use(), `.`, keyCapTyp, `s(`, keysName, `)`) + p.P(`for _, k := range `, keysName, ` {`) + } else { + p.P(`for k, _ := range m.`, fieldname, ` {`) + } + p.In() + p.encodeKey(fieldNumber, wireType) + sum := []string{strconv.Itoa(keyKeySize)} + switch m.KeyField.GetType() { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE, + descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + sum = append(sum, `8`) + case descriptor.FieldDescriptorProto_TYPE_FLOAT, + descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + sum = append(sum, `4`) + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM, + descriptor.FieldDescriptorProto_TYPE_INT32: + sum = append(sum, `sov`+p.localName+`(uint64(k))`) + case descriptor.FieldDescriptorProto_TYPE_BOOL: + sum = append(sum, `1`) + case descriptor.FieldDescriptorProto_TYPE_STRING, + descriptor.FieldDescriptorProto_TYPE_BYTES: + sum = append(sum, `len(k)+sov`+p.localName+`(uint64(len(k)))`) + case descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SINT64: + sum = append(sum, `soz`+p.localName+`(uint64(k))`) + } + if gogoproto.IsStableMarshaler(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`v := m.`, fieldname, `[`, keygoAliasTyp, `(k)]`) + } else { + p.P(`v := m.`, fieldname, `[k]`) + } + accessor := `v` + switch m.ValueField.GetType() { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE, + descriptor.FieldDescriptorProto_TYPE_FIXED64, + 
descriptor.FieldDescriptorProto_TYPE_SFIXED64: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, strconv.Itoa(8)) + case descriptor.FieldDescriptorProto_TYPE_FLOAT, + descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, strconv.Itoa(4)) + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM, + descriptor.FieldDescriptorProto_TYPE_INT32: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `sov`+p.localName+`(uint64(v))`) + case descriptor.FieldDescriptorProto_TYPE_BOOL: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `1`) + case descriptor.FieldDescriptorProto_TYPE_STRING: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `len(v)+sov`+p.localName+`(uint64(len(v)))`) + case descriptor.FieldDescriptorProto_TYPE_BYTES: + if gogoproto.IsCustomType(field) { + p.P(`cSize := 0`) + if gogoproto.IsNullable(field) { + p.P(`if `, accessor, ` != nil {`) + p.In() + } + p.P(`cSize = `, accessor, `.Size()`) + p.P(`cSize += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(cSize))`) + if gogoproto.IsNullable(field) { + p.Out() + p.P(`}`) + } + sum = append(sum, `cSize`) + } else { + p.P(`byteSize := 0`) + if proto3 { + p.P(`if len(v) > 0 {`) + } else { + p.P(`if v != nil {`) + } + p.In() + p.P(`byteSize = `, strconv.Itoa(valueKeySize), ` + len(v)+sov`+p.localName+`(uint64(len(v)))`) + p.Out() + p.P(`}`) + sum = append(sum, `byteSize`) + } + case descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SINT64: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `soz`+p.localName+`(uint64(v))`) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if valuegoTyp != valuegoAliasTyp && + !gogoproto.IsStdTime(field) && + 
!gogoproto.IsStdDuration(field) { + if nullable { + // cast back to the type that has the generated methods on it + accessor = `((` + valuegoTyp + `)(` + accessor + `))` + } else { + accessor = `((*` + valuegoTyp + `)(&` + accessor + `))` + } + } else if !nullable { + accessor = `(&v)` + } + p.P(`msgSize := 0`) + p.P(`if `, accessor, ` != nil {`) + p.In() + if gogoproto.IsStdTime(field) { + p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdTime(*`, accessor, `)`) + } else if gogoproto.IsStdDuration(field) { + p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdDuration(*`, accessor, `)`) + } else if protoSizer { + p.P(`msgSize = `, accessor, `.ProtoSize()`) + } else { + p.P(`msgSize = `, accessor, `.Size()`) + } + p.P(`msgSize += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(msgSize))`) + p.Out() + p.P(`}`) + sum = append(sum, `msgSize`) + } + p.P(`mapSize := `, strings.Join(sum, " + ")) + p.callVarint("mapSize") + p.encodeKey(1, wireToType(keywire)) + p.mapField(numGen, field, m.KeyField, "k", protoSizer) + nullableMsg := nullable && (m.ValueField.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE || + gogoproto.IsCustomType(field) && m.ValueField.IsBytes()) + plainBytes := m.ValueField.IsBytes() && !gogoproto.IsCustomType(field) + if nullableMsg { + p.P(`if `, accessor, ` != nil { `) + p.In() + } else if plainBytes { + if proto3 { + p.P(`if len(`, accessor, `) > 0 {`) + } else { + p.P(`if `, accessor, ` != nil {`) + } + p.In() + } + p.encodeKey(2, wireToType(valuewire)) + p.mapField(numGen, field, m.ValueField, accessor, protoSizer) + if nullableMsg || plainBytes { + p.Out() + p.P(`}`) + } + p.Out() + p.P(`}`) + } else if repeated { + p.P(`for _, msg := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + varName := "msg" + if gogoproto.IsStdTime(field) { + if gogoproto.IsNullable(field) { + varName = "*" + varName + } + p.callVarint(p.typesPkg.Use(), `.SizeOfStdTime(`, varName, `)`) + p.P(`n, err := `, p.typesPkg.Use(), 
`.StdTimeMarshalTo(`, varName, `, dAtA[i:])`) + } else if gogoproto.IsStdDuration(field) { + if gogoproto.IsNullable(field) { + varName = "*" + varName + } + p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(`, varName, `)`) + p.P(`n, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(`, varName, `, dAtA[i:])`) + } else if protoSizer { + p.callVarint(varName, ".ProtoSize()") + p.P(`n, err := `, varName, `.MarshalTo(dAtA[i:])`) + } else { + p.callVarint(varName, ".Size()") + p.P(`n, err := `, varName, `.MarshalTo(dAtA[i:])`) + } + p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=n`) + p.Out() + p.P(`}`) + } else { + p.encodeKey(fieldNumber, wireType) + varName := `m.` + fieldname + if gogoproto.IsStdTime(field) { + if gogoproto.IsNullable(field) { + varName = "*" + varName + } + p.callVarint(p.typesPkg.Use(), `.SizeOfStdTime(`, varName, `)`) + p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdTimeMarshalTo(`, varName, `, dAtA[i:])`) + } else if gogoproto.IsStdDuration(field) { + if gogoproto.IsNullable(field) { + varName = "*" + varName + } + p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(`, varName, `)`) + p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(`, varName, `, dAtA[i:])`) + } else if protoSizer { + p.callVarint(varName, `.ProtoSize()`) + p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(dAtA[i:])`) + } else { + p.callVarint(varName, `.Size()`) + p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(dAtA[i:])`) + } + p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=n`, numGen.Current()) + } + case descriptor.FieldDescriptorProto_TYPE_BYTES: + if !gogoproto.IsCustomType(field) { + if repeated { + p.P(`for _, b := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callVarint("len(b)") + p.P(`i+=copy(dAtA[i:], b)`) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if len(m.`, fieldname, `) > 0 {`) + p.In() 
+ p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `)`) + p.P(`i+=copy(dAtA[i:], m.`, fieldname, `)`) + p.Out() + p.P(`}`) + } else { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`len(m.`, fieldname, `)`) + p.P(`i+=copy(dAtA[i:], m.`, fieldname, `)`) + } + } else { + if repeated { + p.P(`for _, msg := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + if protoSizer { + p.callVarint(`msg.ProtoSize()`) + } else { + p.callVarint(`msg.Size()`) + } + p.P(`n, err := msg.MarshalTo(dAtA[i:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=n`) + p.Out() + p.P(`}`) + } else { + p.encodeKey(fieldNumber, wireType) + if protoSizer { + p.callVarint(`m.`, fieldname, `.ProtoSize()`) + } else { + p.callVarint(`m.`, fieldname, `.Size()`) + } + p.P(`n`, numGen.Next(), `, err := m.`, fieldname, `.MarshalTo(dAtA[i:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=n`, numGen.Current()) + } + } + case descriptor.FieldDescriptorProto_TYPE_SINT32: + if packed { + datavar := "dAtA" + numGen.Next() + jvar := "j" + numGen.Next() + p.P(datavar, ` := make([]byte, len(m.`, fieldname, ")*5)") + p.P(`var `, jvar, ` int`) + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + xvar := "x" + numGen.Next() + p.P(xvar, ` := (uint32(num) << 1) ^ uint32((num >> 31))`) + p.P(`for `, xvar, ` >= 1<<7 {`) + p.In() + p.P(datavar, `[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`) + p.P(jvar, `++`) + p.P(xvar, ` >>= 7`) + p.Out() + p.P(`}`) + p.P(datavar, `[`, jvar, `] = uint8(`, xvar, `)`) + p.P(jvar, `++`) + p.Out() + p.P(`}`) + p.encodeKey(fieldNumber, wireType) + p.callVarint(jvar) + p.P(`i+=copy(dAtA[i:], `, datavar, `[:`, jvar, `])`) + } else if repeated { + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.P(`x`, numGen.Next(), ` := (uint32(num) << 1) ^ uint32((num >> 31))`) + p.encodeVarint("x" + 
numGen.Current()) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`) + } else { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`(uint32(*m.`, fieldname, `) << 1) ^ uint32((*m.`, fieldname, ` >> 31))`) + } + case descriptor.FieldDescriptorProto_TYPE_SINT64: + if packed { + jvar := "j" + numGen.Next() + xvar := "x" + numGen.Next() + datavar := "dAtA" + numGen.Next() + p.P(`var `, jvar, ` int`) + p.P(datavar, ` := make([]byte, len(m.`, fieldname, `)*10)`) + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.P(xvar, ` := (uint64(num) << 1) ^ uint64((num >> 63))`) + p.P(`for `, xvar, ` >= 1<<7 {`) + p.In() + p.P(datavar, `[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`) + p.P(jvar, `++`) + p.P(xvar, ` >>= 7`) + p.Out() + p.P(`}`) + p.P(datavar, `[`, jvar, `] = uint8(`, xvar, `)`) + p.P(jvar, `++`) + p.Out() + p.P(`}`) + p.encodeKey(fieldNumber, wireType) + p.callVarint(jvar) + p.P(`i+=copy(dAtA[i:], `, datavar, `[:`, jvar, `])`) + } else if repeated { + p.P(`for _, num := range m.`, fieldname, ` {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.P(`x`, numGen.Next(), ` := (uint64(num) << 1) ^ uint64((num >> 63))`) + p.encodeVarint("x" + numGen.Current()) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.encodeKey(fieldNumber, wireType) + p.callVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`) + p.Out() + p.P(`}`) + } else if !nullable { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`) + } else { + p.encodeKey(fieldNumber, wireType) + p.callVarint(`(uint64(*m.`, fieldname, `) << 1) ^ 
uint64((*m.`, fieldname, ` >> 63))`) + } + default: + panic("not implemented") + } + if (required && nullable) || repeated || doNilCheck { + p.Out() + p.P(`}`) + } +} + +func (p *marshalto) Generate(file *generator.FileDescriptor) { + numGen := NewNumGen() + p.PluginImports = generator.NewPluginImports(p.Generator) + + p.atleastOne = false + p.localName = generator.FileName(file) + + p.mathPkg = p.NewImport("math") + p.sortKeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys") + p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + p.protoPkg = p.NewImport("github.com/golang/protobuf/proto") + } + p.errorsPkg = p.NewImport("errors") + p.binaryPkg = p.NewImport("encoding/binary") + p.typesPkg = p.NewImport("github.com/gogo/protobuf/types") + + for _, message := range file.Messages() { + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if !gogoproto.IsMarshaler(file.FileDescriptorProto, message.DescriptorProto) && + !gogoproto.IsUnsafeMarshaler(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + p.atleastOne = true + + p.P(`func (m *`, ccTypeName, `) Marshal() (dAtA []byte, err error) {`) + p.In() + if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`size := m.ProtoSize()`) + } else { + p.P(`size := m.Size()`) + } + p.P(`dAtA = make([]byte, size)`) + p.P(`n, err := m.MarshalTo(dAtA)`) + p.P(`if err != nil {`) + p.In() + p.P(`return nil, err`) + p.Out() + p.P(`}`) + p.P(`return dAtA[:n], nil`) + p.Out() + p.P(`}`) + p.P(``) + p.P(`func (m *`, ccTypeName, `) MarshalTo(dAtA []byte) (int, error) {`) + p.In() + p.P(`var i int`) + p.P(`_ = i`) + p.P(`var l int`) + p.P(`_ = l`) + fields := orderFields(message.GetField()) + sort.Sort(fields) + oneofs := make(map[string]struct{}) + for _, field := range message.Field { + oneof := field.OneofIndex != nil + if !oneof { + 
proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + p.generateField(proto3, numGen, file, message, field) + } else { + fieldname := p.GetFieldName(message, field) + if _, ok := oneofs[fieldname]; !ok { + oneofs[fieldname] = struct{}{} + p.P(`if m.`, fieldname, ` != nil {`) + p.In() + p.P(`nn`, numGen.Next(), `, err := m.`, fieldname, `.MarshalTo(dAtA[i:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=nn`, numGen.Current()) + p.Out() + p.P(`}`) + } + } + } + if message.DescriptorProto.HasExtension() { + if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`n, err := `, p.protoPkg.Use(), `.EncodeInternalExtension(m, dAtA[i:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return 0, err`) + p.Out() + p.P(`}`) + p.P(`i+=n`) + } else { + p.P(`if m.XXX_extensions != nil {`) + p.In() + p.P(`i+=copy(dAtA[i:], m.XXX_extensions)`) + p.Out() + p.P(`}`) + } + } + if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if m.XXX_unrecognized != nil {`) + p.In() + p.P(`i+=copy(dAtA[i:], m.XXX_unrecognized)`) + p.Out() + p.P(`}`) + } + + p.P(`return i, nil`) + p.Out() + p.P(`}`) + p.P() + + //Generate MarshalTo methods for oneof fields + m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto) + for _, field := range m.Field { + oneof := field.OneofIndex != nil + if !oneof { + continue + } + ccTypeName := p.OneOfTypeName(message, field) + p.P(`func (m *`, ccTypeName, `) MarshalTo(dAtA []byte) (int, error) {`) + p.In() + p.P(`i := 0`) + vanity.TurnOffNullableForNativeTypes(field) + p.generateField(false, numGen, file, message, field) + p.P(`return i, nil`) + p.Out() + p.P(`}`) + } + } + + if p.atleastOne { + p.P(`func encodeVarint`, p.localName, `(dAtA []byte, offset int, v uint64) int {`) + p.In() + p.P(`for v >= 1<<7 {`) + p.In() + p.P(`dAtA[offset] = uint8(v&0x7f|0x80)`) + p.P(`v >>= 7`) + p.P(`offset++`) + p.Out() + p.P(`}`) + p.P(`dAtA[offset] = 
uint8(v)`) + p.P(`return offset+1`) + p.Out() + p.P(`}`) + } + +} + +func init() { + generator.RegisterPlugin(NewMarshal()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/oneofcheck/oneofcheck.go b/vendor/github.com/gogo/protobuf/plugin/oneofcheck/oneofcheck.go new file mode 100644 index 00000000000..0f822e8a8ac --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/oneofcheck/oneofcheck.go @@ -0,0 +1,93 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The oneofcheck plugin is used to check whether oneof is not used incorrectly. 
+For instance: +An error is caused if a oneof field: + - is used in a face + - is an embedded field + +*/ +package oneofcheck + +import ( + "fmt" + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "os" +) + +type plugin struct { + *generator.Generator +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "oneofcheck" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + for _, msg := range file.Messages() { + face := gogoproto.IsFace(file.FileDescriptorProto, msg.DescriptorProto) + for _, field := range msg.GetField() { + if field.OneofIndex == nil { + continue + } + if face { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v cannot be in a face and oneof\n", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + if gogoproto.IsEmbed(field) { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v cannot be in an oneof and an embedded field\n", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + if !gogoproto.IsNullable(field) { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v cannot be in an oneof and a non-nullable field\n", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + if gogoproto.IsUnion(file.FileDescriptorProto, msg.DescriptorProto) { + fmt.Fprintf(os.Stderr, "ERROR: field %v.%v cannot be in an oneof and in an union (deprecated)\n", generator.CamelCase(*msg.Name), generator.CamelCase(*field.Name)) + os.Exit(1) + } + } + } +} + +func (p *plugin) GenerateImports(*generator.FileDescriptor) {} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/populate/populate.go b/vendor/github.com/gogo/protobuf/plugin/populate/populate.go new file mode 100644 index 00000000000..cf61fe9b0ab --- /dev/null +++ 
b/vendor/github.com/gogo/protobuf/plugin/populate/populate.go @@ -0,0 +1,795 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The populate plugin generates a NewPopulated function. +This function returns a newly populated structure. 
+ +It is enabled by the following extensions: + + - populate + - populate_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + option (gogoproto.populate_all) = true; + + message B { + optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true]; + repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false]; + } + +given to the populate plugin, will generate code the following code: + + func NewPopulatedB(r randyExample, easy bool) *B { + this := &B{} + v2 := NewPopulatedA(r, easy) + this.A = *v2 + if r.Intn(10) != 0 { + v3 := r.Intn(10) + this.G = make([]github_com_gogo_protobuf_test_custom.Uint128, v3) + for i := 0; i < v3; i++ { + v4 := github_com_gogo_protobuf_test_custom.NewPopulatedUint128(r) + this.G[i] = *v4 + } + } + if !easy && r.Intn(10) != 0 { + this.XXX_unrecognized = randUnrecognizedExample(r, 3) + } + return this + } + +The idea that is useful for testing. +Most of the other plugins' generated test code uses it. +You will still be able to use the generated test code of other packages +if you turn off the popluate plugin and write your own custom NewPopulated function. + +If the easy flag is not set the XXX_unrecognized and XXX_extensions fields are also populated. +These have caused problems with JSON marshalling and unmarshalling tests. 
+ +*/ +package populate + +import ( + "fmt" + "math" + "strconv" + "strings" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "github.com/gogo/protobuf/vanity" +) + +type VarGen interface { + Next() string + Current() string +} + +type varGen struct { + index int64 +} + +func NewVarGen() VarGen { + return &varGen{0} +} + +func (this *varGen) Next() string { + this.index++ + return fmt.Sprintf("v%d", this.index) +} + +func (this *varGen) Current() string { + return fmt.Sprintf("v%d", this.index) +} + +type plugin struct { + *generator.Generator + generator.PluginImports + varGen VarGen + atleastOne bool + localName string + typesPkg generator.Single +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "populate" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g +} + +func value(typeName string, fieldType descriptor.FieldDescriptorProto_Type) string { + switch fieldType { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + return typeName + "(r.Float64())" + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + return typeName + "(r.Float32())" + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64, + descriptor.FieldDescriptorProto_TYPE_SINT64: + return typeName + "(r.Int63())" + case descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_FIXED64: + return typeName + "(uint64(r.Uint32()))" + case descriptor.FieldDescriptorProto_TYPE_INT32, + descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32, + descriptor.FieldDescriptorProto_TYPE_ENUM: + return typeName + "(r.Int31())" + case descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_FIXED32: + return typeName + "(r.Uint32())" + case 
descriptor.FieldDescriptorProto_TYPE_BOOL: + return typeName + `(bool(r.Intn(2) == 0))` + case descriptor.FieldDescriptorProto_TYPE_STRING, + descriptor.FieldDescriptorProto_TYPE_GROUP, + descriptor.FieldDescriptorProto_TYPE_MESSAGE, + descriptor.FieldDescriptorProto_TYPE_BYTES: + } + panic(fmt.Errorf("unexpected type %v", typeName)) +} + +func negative(fieldType descriptor.FieldDescriptorProto_Type) bool { + switch fieldType { + case descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_BOOL: + return false + } + return true +} + +func (p *plugin) getFuncName(goTypName string) string { + funcName := "NewPopulated" + goTypName + goTypNames := strings.Split(goTypName, ".") + if len(goTypNames) == 2 { + funcName = goTypNames[0] + ".NewPopulated" + goTypNames[1] + } else if len(goTypNames) != 1 { + panic(fmt.Errorf("unreachable: too many dots in %v", goTypName)) + } + switch funcName { + case "time.NewPopulatedTime": + funcName = p.typesPkg.Use() + ".NewPopulatedStdTime" + case "time.NewPopulatedDuration": + funcName = p.typesPkg.Use() + ".NewPopulatedStdDuration" + } + return funcName +} + +func (p *plugin) getFuncCall(goTypName string) string { + funcName := p.getFuncName(goTypName) + funcCall := funcName + "(r, easy)" + return funcCall +} + +func (p *plugin) getCustomFuncCall(goTypName string) string { + funcName := p.getFuncName(goTypName) + funcCall := funcName + "(r)" + return funcCall +} + +func (p *plugin) getEnumVal(field *descriptor.FieldDescriptorProto, goTyp string) string { + enum := p.ObjectNamed(field.GetTypeName()).(*generator.EnumDescriptor) + l := len(enum.Value) + values := make([]string, l) + for i := range enum.Value { + values[i] = strconv.Itoa(int(*enum.Value[i].Number)) + } + arr := "[]int32{" + strings.Join(values, ",") + "}" + val := 
strings.Join([]string{generator.GoTypeToName(goTyp), `(`, arr, `[r.Intn(`, fmt.Sprintf("%d", l), `)])`}, "") + return val +} + +func (p *plugin) GenerateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) { + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + goTyp, _ := p.GoType(message, field) + fieldname := p.GetOneOfFieldName(message, field) + goTypName := generator.GoTypeToName(goTyp) + if p.IsMap(field) { + m := p.GoMapType(nil, field) + keygoTyp, _ := p.GoType(nil, m.KeyField) + keygoTyp = strings.Replace(keygoTyp, "*", "", 1) + keygoAliasTyp, _ := p.GoType(nil, m.KeyAliasField) + keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1) + + valuegoTyp, _ := p.GoType(nil, m.ValueField) + valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField) + keytypName := generator.GoTypeToName(keygoTyp) + keygoAliasTyp = generator.GoTypeToName(keygoAliasTyp) + valuetypAliasName := generator.GoTypeToName(valuegoAliasTyp) + + nullable, valuegoTyp, valuegoAliasTyp := generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) + + p.P(p.varGen.Next(), ` := r.Intn(10)`) + p.P(`this.`, fieldname, ` = make(`, m.GoType, `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + keyval := "" + if m.KeyField.IsString() { + keyval = fmt.Sprintf("randString%v(r)", p.localName) + } else { + keyval = value(keytypName, m.KeyField.GetType()) + } + if keygoAliasTyp != keygoTyp { + keyval = keygoAliasTyp + `(` + keyval + `)` + } + if m.ValueField.IsMessage() || p.IsGroup(field) || + (m.ValueField.IsBytes() && gogoproto.IsCustomType(field)) { + s := `this.` + fieldname + `[` + keyval + `] = ` + if gogoproto.IsStdTime(field) || gogoproto.IsStdDuration(field) { + valuegoTyp = valuegoAliasTyp + } + funcCall := p.getCustomFuncCall(goTypName) + if !gogoproto.IsCustomType(field) { + goTypName = generator.GoTypeToName(valuegoTyp) + funcCall = p.getFuncCall(goTypName) + } + if !nullable { + funcCall = `*` 
+ funcCall + } + if valuegoTyp != valuegoAliasTyp { + funcCall = `(` + valuegoAliasTyp + `)(` + funcCall + `)` + } + s += funcCall + p.P(s) + } else if m.ValueField.IsEnum() { + s := `this.` + fieldname + `[` + keyval + `]` + ` = ` + p.getEnumVal(m.ValueField, valuegoTyp) + p.P(s) + } else if m.ValueField.IsBytes() { + count := p.varGen.Next() + p.P(count, ` := r.Intn(100)`) + p.P(p.varGen.Next(), ` := `, keyval) + p.P(`this.`, fieldname, `[`, p.varGen.Current(), `] = make(`, valuegoTyp, `, `, count, `)`) + p.P(`for i := 0; i < `, count, `; i++ {`) + p.In() + p.P(`this.`, fieldname, `[`, p.varGen.Current(), `][i] = byte(r.Intn(256))`) + p.Out() + p.P(`}`) + } else if m.ValueField.IsString() { + s := `this.` + fieldname + `[` + keyval + `]` + ` = ` + fmt.Sprintf("randString%v(r)", p.localName) + p.P(s) + } else { + p.P(p.varGen.Next(), ` := `, keyval) + p.P(`this.`, fieldname, `[`, p.varGen.Current(), `] = `, value(valuetypAliasName, m.ValueField.GetType())) + if negative(m.ValueField.GetType()) { + p.P(`if r.Intn(2) == 0 {`) + p.In() + p.P(`this.`, fieldname, `[`, p.varGen.Current(), `] *= -1`) + p.Out() + p.P(`}`) + } + } + p.Out() + p.P(`}`) + } else if gogoproto.IsCustomType(field) { + funcCall := p.getCustomFuncCall(goTypName) + if field.IsRepeated() { + p.P(p.varGen.Next(), ` := r.Intn(10)`) + p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + p.P(p.varGen.Next(), `:= `, funcCall) + p.P(`this.`, fieldname, `[i] = *`, p.varGen.Current()) + p.Out() + p.P(`}`) + } else if gogoproto.IsNullable(field) { + p.P(`this.`, fieldname, ` = `, funcCall) + } else { + p.P(p.varGen.Next(), `:= `, funcCall) + p.P(`this.`, fieldname, ` = *`, p.varGen.Current()) + } + } else if field.IsMessage() || p.IsGroup(field) { + funcCall := p.getFuncCall(goTypName) + if field.IsRepeated() { + p.P(p.varGen.Next(), ` := r.Intn(5)`) + p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), 
`)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + if gogoproto.IsNullable(field) { + p.P(`this.`, fieldname, `[i] = `, funcCall) + } else { + p.P(p.varGen.Next(), `:= `, funcCall) + p.P(`this.`, fieldname, `[i] = *`, p.varGen.Current()) + } + p.Out() + p.P(`}`) + } else { + if gogoproto.IsNullable(field) { + p.P(`this.`, fieldname, ` = `, funcCall) + } else { + p.P(p.varGen.Next(), `:= `, funcCall) + p.P(`this.`, fieldname, ` = *`, p.varGen.Current()) + } + } + } else { + if field.IsEnum() { + val := p.getEnumVal(field, goTyp) + if field.IsRepeated() { + p.P(p.varGen.Next(), ` := r.Intn(10)`) + p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + p.P(`this.`, fieldname, `[i] = `, val) + p.Out() + p.P(`}`) + } else if !gogoproto.IsNullable(field) || proto3 { + p.P(`this.`, fieldname, ` = `, val) + } else { + p.P(p.varGen.Next(), ` := `, val) + p.P(`this.`, fieldname, ` = &`, p.varGen.Current()) + } + } else if field.IsBytes() { + if field.IsRepeated() { + p.P(p.varGen.Next(), ` := r.Intn(10)`) + p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + p.P(p.varGen.Next(), ` := r.Intn(100)`) + p.P(`this.`, fieldname, `[i] = make([]byte,`, p.varGen.Current(), `)`) + p.P(`for j := 0; j < `, p.varGen.Current(), `; j++ {`) + p.In() + p.P(`this.`, fieldname, `[i][j] = byte(r.Intn(256))`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } else { + p.P(p.varGen.Next(), ` := r.Intn(100)`) + p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + p.P(`this.`, fieldname, `[i] = byte(r.Intn(256))`) + p.Out() + p.P(`}`) + } + } else if field.IsString() { + typName := generator.GoTypeToName(goTyp) + val := fmt.Sprintf("%s(randString%v(r))", typName, p.localName) + if field.IsRepeated() { + 
p.P(p.varGen.Next(), ` := r.Intn(10)`) + p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + p.P(`this.`, fieldname, `[i] = `, val) + p.Out() + p.P(`}`) + } else if !gogoproto.IsNullable(field) || proto3 { + p.P(`this.`, fieldname, ` = `, val) + } else { + p.P(p.varGen.Next(), `:= `, val) + p.P(`this.`, fieldname, ` = &`, p.varGen.Current()) + } + } else { + typName := generator.GoTypeToName(goTyp) + if field.IsRepeated() { + p.P(p.varGen.Next(), ` := r.Intn(10)`) + p.P(`this.`, fieldname, ` = make(`, goTyp, `, `, p.varGen.Current(), `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + p.P(`this.`, fieldname, `[i] = `, value(typName, field.GetType())) + if negative(field.GetType()) { + p.P(`if r.Intn(2) == 0 {`) + p.In() + p.P(`this.`, fieldname, `[i] *= -1`) + p.Out() + p.P(`}`) + } + p.Out() + p.P(`}`) + } else if !gogoproto.IsNullable(field) || proto3 { + p.P(`this.`, fieldname, ` = `, value(typName, field.GetType())) + if negative(field.GetType()) { + p.P(`if r.Intn(2) == 0 {`) + p.In() + p.P(`this.`, fieldname, ` *= -1`) + p.Out() + p.P(`}`) + } + } else { + p.P(p.varGen.Next(), ` := `, value(typName, field.GetType())) + if negative(field.GetType()) { + p.P(`if r.Intn(2) == 0 {`) + p.In() + p.P(p.varGen.Current(), ` *= -1`) + p.Out() + p.P(`}`) + } + p.P(`this.`, fieldname, ` = &`, p.varGen.Current()) + } + } + } +} + +func (p *plugin) hasLoop(pkg string, field *descriptor.FieldDescriptorProto, visited []*generator.Descriptor, excludes []*generator.Descriptor) *generator.Descriptor { + if field.IsMessage() || p.IsGroup(field) || p.IsMap(field) { + var fieldMessage *generator.Descriptor + if p.IsMap(field) { + m := p.GoMapType(nil, field) + if !m.ValueField.IsMessage() { + return nil + } + fieldMessage = p.ObjectNamed(m.ValueField.GetTypeName()).(*generator.Descriptor) + } else { + fieldMessage = 
p.ObjectNamed(field.GetTypeName()).(*generator.Descriptor) + } + fieldTypeName := generator.CamelCaseSlice(fieldMessage.TypeName()) + for _, message := range visited { + messageTypeName := generator.CamelCaseSlice(message.TypeName()) + if fieldTypeName == messageTypeName { + for _, e := range excludes { + if fieldTypeName == generator.CamelCaseSlice(e.TypeName()) { + return nil + } + } + return fieldMessage + } + } + + for _, f := range fieldMessage.Field { + if strings.HasPrefix(f.GetTypeName(), "."+pkg) { + visited = append(visited, fieldMessage) + loopTo := p.hasLoop(pkg, f, visited, excludes) + if loopTo != nil { + return loopTo + } + } + } + } + return nil +} + +func (p *plugin) loops(pkg string, field *descriptor.FieldDescriptorProto, message *generator.Descriptor) int { + //fmt.Fprintf(os.Stderr, "loops %v %v\n", field.GetTypeName(), generator.CamelCaseSlice(message.TypeName())) + excludes := []*generator.Descriptor{} + loops := 0 + for { + visited := []*generator.Descriptor{} + loopTo := p.hasLoop(pkg, field, visited, excludes) + if loopTo == nil { + break + } + //fmt.Fprintf(os.Stderr, "loopTo %v\n", generator.CamelCaseSlice(loopTo.TypeName())) + excludes = append(excludes, loopTo) + loops++ + } + return loops +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + p.atleastOne = false + p.PluginImports = generator.NewPluginImports(p.Generator) + p.varGen = NewVarGen() + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + p.typesPkg = p.NewImport("github.com/gogo/protobuf/types") + p.localName = generator.FileName(file) + protoPkg := p.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = p.NewImport("github.com/golang/protobuf/proto") + } + + for _, message := range file.Messages() { + if !gogoproto.HasPopulate(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + p.atleastOne = true + 
ccTypeName := generator.CamelCaseSlice(message.TypeName()) + loopLevels := make([]int, len(message.Field)) + maxLoopLevel := 0 + for i, field := range message.Field { + loopLevels[i] = p.loops(file.GetPackage(), field, message) + if loopLevels[i] > maxLoopLevel { + maxLoopLevel = loopLevels[i] + } + } + ranTotal := 0 + for i := range loopLevels { + ranTotal += int(math.Pow10(maxLoopLevel - loopLevels[i])) + } + p.P(`func NewPopulated`, ccTypeName, `(r randy`, p.localName, `, easy bool) *`, ccTypeName, ` {`) + p.In() + p.P(`this := &`, ccTypeName, `{}`) + if gogoproto.IsUnion(message.File(), message.DescriptorProto) && len(message.Field) > 0 { + p.P(`fieldNum := r.Intn(`, fmt.Sprintf("%d", ranTotal), `)`) + p.P(`switch fieldNum {`) + k := 0 + for i, field := range message.Field { + is := []string{} + ran := int(math.Pow10(maxLoopLevel - loopLevels[i])) + for j := 0; j < ran; j++ { + is = append(is, fmt.Sprintf("%d", j+k)) + } + k += ran + p.P(`case `, strings.Join(is, ","), `:`) + p.In() + p.GenerateField(file, message, field) + p.Out() + } + p.P(`}`) + } else { + var maxFieldNumber int32 + oneofs := make(map[string]struct{}) + for fieldIndex, field := range message.Field { + if field.GetNumber() > maxFieldNumber { + maxFieldNumber = field.GetNumber() + } + oneof := field.OneofIndex != nil + if !oneof { + if field.IsRequired() || (!gogoproto.IsNullable(field) && !field.IsRepeated()) || (proto3 && !field.IsMessage()) { + p.GenerateField(file, message, field) + } else { + if loopLevels[fieldIndex] > 0 { + p.P(`if r.Intn(10) == 0 {`) + } else { + p.P(`if r.Intn(10) != 0 {`) + } + p.In() + p.GenerateField(file, message, field) + p.Out() + p.P(`}`) + } + } else { + fieldname := p.GetFieldName(message, field) + if _, ok := oneofs[fieldname]; ok { + continue + } else { + oneofs[fieldname] = struct{}{} + } + fieldNumbers := []int32{} + for _, f := range message.Field { + fname := p.GetFieldName(message, f) + if fname == fieldname { + fieldNumbers = append(fieldNumbers, 
f.GetNumber()) + } + } + + p.P(`oneofNumber_`, fieldname, ` := `, fmt.Sprintf("%#v", fieldNumbers), `[r.Intn(`, strconv.Itoa(len(fieldNumbers)), `)]`) + p.P(`switch oneofNumber_`, fieldname, ` {`) + for _, f := range message.Field { + fname := p.GetFieldName(message, f) + if fname != fieldname { + continue + } + p.P(`case `, strconv.Itoa(int(f.GetNumber())), `:`) + p.In() + ccTypeName := p.OneOfTypeName(message, f) + p.P(`this.`, fname, ` = NewPopulated`, ccTypeName, `(r, easy)`) + p.Out() + } + p.P(`}`) + } + } + if message.DescriptorProto.HasExtension() { + p.P(`if !easy && r.Intn(10) != 0 {`) + p.In() + p.P(`l := r.Intn(5)`) + p.P(`for i := 0; i < l; i++ {`) + p.In() + if len(message.DescriptorProto.GetExtensionRange()) > 1 { + p.P(`eIndex := r.Intn(`, strconv.Itoa(len(message.DescriptorProto.GetExtensionRange())), `)`) + p.P(`fieldNumber := 0`) + p.P(`switch eIndex {`) + for i, e := range message.DescriptorProto.GetExtensionRange() { + p.P(`case `, strconv.Itoa(i), `:`) + p.In() + p.P(`fieldNumber = r.Intn(`, strconv.Itoa(int(e.GetEnd()-e.GetStart())), `) + `, strconv.Itoa(int(e.GetStart()))) + p.Out() + if e.GetEnd() > maxFieldNumber { + maxFieldNumber = e.GetEnd() + } + } + p.P(`}`) + } else { + e := message.DescriptorProto.GetExtensionRange()[0] + p.P(`fieldNumber := r.Intn(`, strconv.Itoa(int(e.GetEnd()-e.GetStart())), `) + `, strconv.Itoa(int(e.GetStart()))) + if e.GetEnd() > maxFieldNumber { + maxFieldNumber = e.GetEnd() + } + } + p.P(`wire := r.Intn(4)`) + p.P(`if wire == 3 { wire = 5 }`) + p.P(`dAtA := randField`, p.localName, `(nil, r, fieldNumber, wire)`) + p.P(protoPkg.Use(), `.SetRawExtension(this, int32(fieldNumber), dAtA)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } + + if maxFieldNumber < (1 << 10) { + p.P(`if !easy && r.Intn(10) != 0 {`) + p.In() + if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`this.XXX_unrecognized = randUnrecognized`, p.localName, `(r, `, strconv.Itoa(int(maxFieldNumber+1)), `)`) + 
} + p.Out() + p.P(`}`) + } + } + p.P(`return this`) + p.Out() + p.P(`}`) + p.P(``) + + //Generate NewPopulated functions for oneof fields + m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto) + for _, f := range m.Field { + oneof := f.OneofIndex != nil + if !oneof { + continue + } + ccTypeName := p.OneOfTypeName(message, f) + p.P(`func NewPopulated`, ccTypeName, `(r randy`, p.localName, `, easy bool) *`, ccTypeName, ` {`) + p.In() + p.P(`this := &`, ccTypeName, `{}`) + vanity.TurnOffNullableForNativeTypes(f) + p.GenerateField(file, message, f) + p.P(`return this`) + p.Out() + p.P(`}`) + } + } + + if !p.atleastOne { + return + } + + p.P(`type randy`, p.localName, ` interface {`) + p.In() + p.P(`Float32() float32`) + p.P(`Float64() float64`) + p.P(`Int63() int64`) + p.P(`Int31() int32`) + p.P(`Uint32() uint32`) + p.P(`Intn(n int) int`) + p.Out() + p.P(`}`) + + p.P(`func randUTF8Rune`, p.localName, `(r randy`, p.localName, `) rune {`) + p.In() + p.P(`ru := r.Intn(62)`) + p.P(`if ru < 10 {`) + p.In() + p.P(`return rune(ru+48)`) + p.Out() + p.P(`} else if ru < 36 {`) + p.In() + p.P(`return rune(ru+55)`) + p.Out() + p.P(`}`) + p.P(`return rune(ru+61)`) + p.Out() + p.P(`}`) + + p.P(`func randString`, p.localName, `(r randy`, p.localName, `) string {`) + p.In() + p.P(p.varGen.Next(), ` := r.Intn(100)`) + p.P(`tmps := make([]rune, `, p.varGen.Current(), `)`) + p.P(`for i := 0; i < `, p.varGen.Current(), `; i++ {`) + p.In() + p.P(`tmps[i] = randUTF8Rune`, p.localName, `(r)`) + p.Out() + p.P(`}`) + p.P(`return string(tmps)`) + p.Out() + p.P(`}`) + + p.P(`func randUnrecognized`, p.localName, `(r randy`, p.localName, `, maxFieldNumber int) (dAtA []byte) {`) + p.In() + p.P(`l := r.Intn(5)`) + p.P(`for i := 0; i < l; i++ {`) + p.In() + p.P(`wire := r.Intn(4)`) + p.P(`if wire == 3 { wire = 5 }`) + p.P(`fieldNumber := maxFieldNumber + r.Intn(100)`) + p.P(`dAtA = randField`, p.localName, `(dAtA, r, fieldNumber, wire)`) + p.Out() + p.P(`}`) + p.P(`return dAtA`) + 
p.Out() + p.P(`}`) + + p.P(`func randField`, p.localName, `(dAtA []byte, r randy`, p.localName, `, fieldNumber int, wire int) []byte {`) + p.In() + p.P(`key := uint32(fieldNumber)<<3 | uint32(wire)`) + p.P(`switch wire {`) + p.P(`case 0:`) + p.In() + p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`) + p.P(p.varGen.Next(), ` := r.Int63()`) + p.P(`if r.Intn(2) == 0 {`) + p.In() + p.P(p.varGen.Current(), ` *= -1`) + p.Out() + p.P(`}`) + p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(`, p.varGen.Current(), `))`) + p.Out() + p.P(`case 1:`) + p.In() + p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`) + p.P(`dAtA = append(dAtA, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))`) + p.Out() + p.P(`case 2:`) + p.In() + p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`) + p.P(`ll := r.Intn(100)`) + p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(ll))`) + p.P(`for j := 0; j < ll; j++ {`) + p.In() + p.P(`dAtA = append(dAtA, byte(r.Intn(256)))`) + p.Out() + p.P(`}`) + p.Out() + p.P(`default:`) + p.In() + p.P(`dAtA = encodeVarintPopulate`, p.localName, `(dAtA, uint64(key))`) + p.P(`dAtA = append(dAtA, byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)), byte(r.Intn(256)))`) + p.Out() + p.P(`}`) + p.P(`return dAtA`) + p.Out() + p.P(`}`) + + p.P(`func encodeVarintPopulate`, p.localName, `(dAtA []byte, v uint64) []byte {`) + p.In() + p.P(`for v >= 1<<7 {`) + p.In() + p.P(`dAtA = append(dAtA, uint8(uint64(v)&0x7f|0x80))`) + p.P(`v >>= 7`) + p.Out() + p.P(`}`) + p.P(`dAtA = append(dAtA, uint8(v))`) + p.P(`return dAtA`) + p.Out() + p.P(`}`) + +} + +func init() { + generator.RegisterPlugin(NewPlugin()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/size/size.go b/vendor/github.com/gogo/protobuf/plugin/size/size.go new file mode 100644 index 00000000000..79cd403be15 --- /dev/null +++ 
b/vendor/github.com/gogo/protobuf/plugin/size/size.go @@ -0,0 +1,674 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The size plugin generates a Size or ProtoSize method for each message. +This is useful with the MarshalTo method generated by the marshalto plugin and the +gogoproto.marshaler and gogoproto.marshaler_all extensions. 
+ +It is enabled by the following extensions: + + - sizer + - sizer_all + - protosizer + - protosizer_all + +The size plugin also generates a test given it is enabled using one of the following extensions: + + - testgen + - testgen_all + +And a benchmark given it is enabled using one of the following extensions: + + - benchgen + - benchgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + option (gogoproto.sizer_all) = true; + + message B { + option (gogoproto.description) = true; + optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true]; + repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false]; + } + +given to the size plugin, will generate the following code: + + func (m *B) Size() (n int) { + var l int + _ = l + l = m.A.Size() + n += 1 + l + sovExample(uint64(l)) + if len(m.G) > 0 { + for _, e := range m.G { + l = e.Size() + n += 1 + l + sovExample(uint64(l)) + } + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n + } + +and the following test code: + + func TestBSize(t *testing5.T) { + popr := math_rand5.New(math_rand5.NewSource(time5.Now().UnixNano())) + p := NewPopulatedB(popr, true) + dAtA, err := github_com_gogo_protobuf_proto2.Marshal(p) + if err != nil { + panic(err) + } + size := p.Size() + if len(dAtA) != size { + t.Fatalf("size %v != marshalled size %v", size, len(dAtA)) + } + } + + func BenchmarkBSize(b *testing5.B) { + popr := math_rand5.New(math_rand5.NewSource(616)) + total := 0 + pops := make([]*B, 1000) + for i := 0; i < 1000; i++ { + pops[i] = NewPopulatedB(popr, false) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + total += pops[i%1000].Size() + } + b.SetBytes(int64(total / b.N)) + } + +The sovExample function is a size of varint function for the example.pb.go file. 
+ +*/ +package size + +import ( + "fmt" + "os" + "strconv" + "strings" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "github.com/gogo/protobuf/vanity" +) + +type size struct { + *generator.Generator + generator.PluginImports + atleastOne bool + localName string + typesPkg generator.Single +} + +func NewSize() *size { + return &size{} +} + +func (p *size) Name() string { + return "size" +} + +func (p *size) Init(g *generator.Generator) { + p.Generator = g +} + +func wireToType(wire string) int { + switch wire { + case "fixed64": + return proto.WireFixed64 + case "fixed32": + return proto.WireFixed32 + case "varint": + return proto.WireVarint + case "bytes": + return proto.WireBytes + case "group": + return proto.WireBytes + case "zigzag32": + return proto.WireVarint + case "zigzag64": + return proto.WireVarint + } + panic("unreachable") +} + +func keySize(fieldNumber int32, wireType int) int { + x := uint32(fieldNumber)<<3 | uint32(wireType) + size := 0 + for size = 0; x > 127; size++ { + x >>= 7 + } + size++ + return size +} + +func (p *size) sizeVarint() { + p.P(` + func sov`, p.localName, `(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n + }`) +} + +func (p *size) sizeZigZag() { + p.P(`func soz`, p.localName, `(x uint64) (n int) { + return sov`, p.localName, `(uint64((x << 1) ^ uint64((int64(x) >> 63)))) + }`) +} + +func (p *size) std(field *descriptor.FieldDescriptorProto, name string) (string, bool) { + if gogoproto.IsStdTime(field) { + if gogoproto.IsNullable(field) { + return p.typesPkg.Use() + `.SizeOfStdTime(*` + name + `)`, true + } else { + return p.typesPkg.Use() + `.SizeOfStdTime(` + name + `)`, true + } + } else if gogoproto.IsStdDuration(field) { + if gogoproto.IsNullable(field) { + return p.typesPkg.Use() + `.SizeOfStdDuration(*` + name + `)`, true + } else { + 
return p.typesPkg.Use() + `.SizeOfStdDuration(` + name + `)`, true + } + } + return "", false +} + +func (p *size) generateField(proto3 bool, file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto, sizeName string) { + fieldname := p.GetOneOfFieldName(message, field) + nullable := gogoproto.IsNullable(field) + repeated := field.IsRepeated() + doNilCheck := gogoproto.NeedsNilCheck(proto3, field) + if repeated { + p.P(`if len(m.`, fieldname, `) > 0 {`) + p.In() + } else if doNilCheck { + p.P(`if m.`, fieldname, ` != nil {`) + p.In() + } + packed := field.IsPacked() || (proto3 && field.IsPacked3()) + _, wire := p.GoType(message, field) + wireType := wireToType(wire) + fieldNumber := field.GetNumber() + if packed { + wireType = proto.WireBytes + } + key := keySize(fieldNumber, wireType) + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE, + descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + if packed { + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)*8))`, `+len(m.`, fieldname, `)*8`) + } else if repeated { + p.P(`n+=`, strconv.Itoa(key+8), `*len(m.`, fieldname, `)`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.P(`n+=`, strconv.Itoa(key+8)) + p.Out() + p.P(`}`) + } else if nullable { + p.P(`n+=`, strconv.Itoa(key+8)) + } else { + p.P(`n+=`, strconv.Itoa(key+8)) + } + case descriptor.FieldDescriptorProto_TYPE_FLOAT, + descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + if packed { + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)*4))`, `+len(m.`, fieldname, `)*4`) + } else if repeated { + p.P(`n+=`, strconv.Itoa(key+4), `*len(m.`, fieldname, `)`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.P(`n+=`, strconv.Itoa(key+4)) + p.Out() + p.P(`}`) + } else if nullable { + p.P(`n+=`, 
strconv.Itoa(key+4)) + } else { + p.P(`n+=`, strconv.Itoa(key+4)) + } + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM, + descriptor.FieldDescriptorProto_TYPE_INT32: + if packed { + p.P(`l = 0`) + p.P(`for _, e := range m.`, fieldname, ` {`) + p.In() + p.P(`l+=sov`, p.localName, `(uint64(e))`) + p.Out() + p.P(`}`) + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(l))+l`) + } else if repeated { + p.P(`for _, e := range m.`, fieldname, ` {`) + p.In() + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(e))`) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(m.`, fieldname, `))`) + p.Out() + p.P(`}`) + } else if nullable { + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(*m.`, fieldname, `))`) + } else { + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(m.`, fieldname, `))`) + } + case descriptor.FieldDescriptorProto_TYPE_BOOL: + if packed { + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(len(m.`, fieldname, `)))`, `+len(m.`, fieldname, `)*1`) + } else if repeated { + p.P(`n+=`, strconv.Itoa(key+1), `*len(m.`, fieldname, `)`) + } else if proto3 { + p.P(`if m.`, fieldname, ` {`) + p.In() + p.P(`n+=`, strconv.Itoa(key+1)) + p.Out() + p.P(`}`) + } else if nullable { + p.P(`n+=`, strconv.Itoa(key+1)) + } else { + p.P(`n+=`, strconv.Itoa(key+1)) + } + case descriptor.FieldDescriptorProto_TYPE_STRING: + if repeated { + p.P(`for _, s := range m.`, fieldname, ` { `) + p.In() + p.P(`l = len(s)`) + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`l=len(m.`, fieldname, `)`) + p.P(`if l > 0 {`) + p.In() + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + p.Out() + p.P(`}`) + } else if nullable { + 
p.P(`l=len(*m.`, fieldname, `)`) + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + } else { + p.P(`l=len(m.`, fieldname, `)`) + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + } + case descriptor.FieldDescriptorProto_TYPE_GROUP: + panic(fmt.Errorf("size does not support group %v", fieldname)) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if p.IsMap(field) { + m := p.GoMapType(nil, field) + _, keywire := p.GoType(nil, m.KeyAliasField) + valuegoTyp, _ := p.GoType(nil, m.ValueField) + valuegoAliasTyp, valuewire := p.GoType(nil, m.ValueAliasField) + _, fieldwire := p.GoType(nil, field) + + nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) + + fieldKeySize := keySize(field.GetNumber(), wireToType(fieldwire)) + keyKeySize := keySize(1, wireToType(keywire)) + valueKeySize := keySize(2, wireToType(valuewire)) + p.P(`for k, v := range m.`, fieldname, ` { `) + p.In() + p.P(`_ = k`) + p.P(`_ = v`) + sum := []string{strconv.Itoa(keyKeySize)} + switch m.KeyField.GetType() { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE, + descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + sum = append(sum, `8`) + case descriptor.FieldDescriptorProto_TYPE_FLOAT, + descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + sum = append(sum, `4`) + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM, + descriptor.FieldDescriptorProto_TYPE_INT32: + sum = append(sum, `sov`+p.localName+`(uint64(k))`) + case descriptor.FieldDescriptorProto_TYPE_BOOL: + sum = append(sum, `1`) + case descriptor.FieldDescriptorProto_TYPE_STRING, + descriptor.FieldDescriptorProto_TYPE_BYTES: + sum = append(sum, `len(k)+sov`+p.localName+`(uint64(len(k)))`) + case 
descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SINT64: + sum = append(sum, `soz`+p.localName+`(uint64(k))`) + } + switch m.ValueField.GetType() { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE, + descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, strconv.Itoa(8)) + case descriptor.FieldDescriptorProto_TYPE_FLOAT, + descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, strconv.Itoa(4)) + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM, + descriptor.FieldDescriptorProto_TYPE_INT32: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `sov`+p.localName+`(uint64(v))`) + case descriptor.FieldDescriptorProto_TYPE_BOOL: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `1`) + case descriptor.FieldDescriptorProto_TYPE_STRING: + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `len(v)+sov`+p.localName+`(uint64(len(v)))`) + case descriptor.FieldDescriptorProto_TYPE_BYTES: + if gogoproto.IsCustomType(field) { + p.P(`l = 0`) + if nullable { + p.P(`if v != nil {`) + p.In() + } + p.P(`l = v.`, sizeName, `()`) + p.P(`l += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(l))`) + if nullable { + p.Out() + p.P(`}`) + } + sum = append(sum, `l`) + } else { + p.P(`l = 0`) + if proto3 { + p.P(`if len(v) > 0 {`) + } else { + p.P(`if v != nil {`) + } + p.In() + p.P(`l = `, strconv.Itoa(valueKeySize), ` + len(v)+sov`+p.localName+`(uint64(len(v)))`) + p.Out() + p.P(`}`) + sum = append(sum, `l`) + } + case descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SINT64: + sum = append(sum, 
strconv.Itoa(valueKeySize)) + sum = append(sum, `soz`+p.localName+`(uint64(v))`) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + stdSizeCall, stdOk := p.std(field, "v") + if nullable { + p.P(`l = 0`) + p.P(`if v != nil {`) + p.In() + if stdOk { + p.P(`l = `, stdSizeCall) + } else if valuegoTyp != valuegoAliasTyp { + p.P(`l = ((`, valuegoTyp, `)(v)).`, sizeName, `()`) + } else { + p.P(`l = v.`, sizeName, `()`) + } + p.P(`l += `, strconv.Itoa(valueKeySize), ` + sov`+p.localName+`(uint64(l))`) + p.Out() + p.P(`}`) + sum = append(sum, `l`) + } else { + if stdOk { + p.P(`l = `, stdSizeCall) + } else if valuegoTyp != valuegoAliasTyp { + p.P(`l = ((*`, valuegoTyp, `)(&v)).`, sizeName, `()`) + } else { + p.P(`l = v.`, sizeName, `()`) + } + sum = append(sum, strconv.Itoa(valueKeySize)) + sum = append(sum, `l+sov`+p.localName+`(uint64(l))`) + } + } + p.P(`mapEntrySize := `, strings.Join(sum, "+")) + p.P(`n+=mapEntrySize+`, fieldKeySize, `+sov`, p.localName, `(uint64(mapEntrySize))`) + p.Out() + p.P(`}`) + } else if repeated { + p.P(`for _, e := range m.`, fieldname, ` { `) + p.In() + stdSizeCall, stdOk := p.std(field, "e") + if stdOk { + p.P(`l=`, stdSizeCall) + } else { + p.P(`l=e.`, sizeName, `()`) + } + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + p.Out() + p.P(`}`) + } else { + stdSizeCall, stdOk := p.std(field, "m."+fieldname) + if stdOk { + p.P(`l=`, stdSizeCall) + } else { + p.P(`l=m.`, fieldname, `.`, sizeName, `()`) + } + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + } + case descriptor.FieldDescriptorProto_TYPE_BYTES: + if !gogoproto.IsCustomType(field) { + if repeated { + p.P(`for _, b := range m.`, fieldname, ` { `) + p.In() + p.P(`l = len(b)`) + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`l=len(m.`, fieldname, `)`) + p.P(`if l > 0 {`) + p.In() + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + p.Out() + p.P(`}`) + 
} else { + p.P(`l=len(m.`, fieldname, `)`) + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + } + } else { + if repeated { + p.P(`for _, e := range m.`, fieldname, ` { `) + p.In() + p.P(`l=e.`, sizeName, `()`) + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + p.Out() + p.P(`}`) + } else { + p.P(`l=m.`, fieldname, `.`, sizeName, `()`) + p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`) + } + } + case descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SINT64: + if packed { + p.P(`l = 0`) + p.P(`for _, e := range m.`, fieldname, ` {`) + p.In() + p.P(`l+=soz`, p.localName, `(uint64(e))`) + p.Out() + p.P(`}`) + p.P(`n+=`, strconv.Itoa(key), `+sov`, p.localName, `(uint64(l))+l`) + } else if repeated { + p.P(`for _, e := range m.`, fieldname, ` {`) + p.In() + p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(e))`) + p.Out() + p.P(`}`) + } else if proto3 { + p.P(`if m.`, fieldname, ` != 0 {`) + p.In() + p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(m.`, fieldname, `))`) + p.Out() + p.P(`}`) + } else if nullable { + p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(*m.`, fieldname, `))`) + } else { + p.P(`n+=`, strconv.Itoa(key), `+soz`, p.localName, `(uint64(m.`, fieldname, `))`) + } + default: + panic("not implemented") + } + if repeated || doNilCheck { + p.Out() + p.P(`}`) + } +} + +func (p *size) Generate(file *generator.FileDescriptor) { + p.PluginImports = generator.NewPluginImports(p.Generator) + p.atleastOne = false + p.localName = generator.FileName(file) + p.typesPkg = p.NewImport("github.com/gogo/protobuf/types") + protoPkg := p.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = p.NewImport("github.com/golang/protobuf/proto") + } + for _, message := range file.Messages() { + sizeName := "" + if gogoproto.IsSizer(file.FileDescriptorProto, message.DescriptorProto) && 
gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { + fmt.Fprintf(os.Stderr, "ERROR: message %v cannot support both sizer and protosizer plugins\n", generator.CamelCase(*message.Name)) + os.Exit(1) + } + if gogoproto.IsSizer(file.FileDescriptorProto, message.DescriptorProto) { + sizeName = "Size" + } else if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { + sizeName = "ProtoSize" + } else { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + p.atleastOne = true + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`) + p.In() + p.P(`var l int`) + p.P(`_ = l`) + oneofs := make(map[string]struct{}) + for _, field := range message.Field { + oneof := field.OneofIndex != nil + if !oneof { + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + p.generateField(proto3, file, message, field, sizeName) + } else { + fieldname := p.GetFieldName(message, field) + if _, ok := oneofs[fieldname]; ok { + continue + } else { + oneofs[fieldname] = struct{}{} + } + p.P(`if m.`, fieldname, ` != nil {`) + p.In() + p.P(`n+=m.`, fieldname, `.`, sizeName, `()`) + p.Out() + p.P(`}`) + } + } + if message.DescriptorProto.HasExtension() { + if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`n += `, protoPkg.Use(), `.SizeOfInternalExtension(m)`) + } else { + p.P(`if m.XXX_extensions != nil {`) + p.In() + p.P(`n+=len(m.XXX_extensions)`) + p.Out() + p.P(`}`) + } + } + if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if m.XXX_unrecognized != nil {`) + p.In() + p.P(`n+=len(m.XXX_unrecognized)`) + p.Out() + p.P(`}`) + } + p.P(`return n`) + p.Out() + p.P(`}`) + p.P() + + //Generate Size methods for oneof fields + m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto) + for _, f := range m.Field { + oneof := f.OneofIndex != nil + if !oneof { + 
continue + } + ccTypeName := p.OneOfTypeName(message, f) + p.P(`func (m *`, ccTypeName, `) `, sizeName, `() (n int) {`) + p.In() + p.P(`var l int`) + p.P(`_ = l`) + vanity.TurnOffNullableForNativeTypes(f) + p.generateField(false, file, message, f, sizeName) + p.P(`return n`) + p.Out() + p.P(`}`) + } + } + + if !p.atleastOne { + return + } + + p.sizeVarint() + p.sizeZigZag() + +} + +func init() { + generator.RegisterPlugin(NewSize()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/size/sizetest.go b/vendor/github.com/gogo/protobuf/plugin/size/sizetest.go new file mode 100644 index 00000000000..1df98730007 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/size/sizetest.go @@ -0,0 +1,134 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package size + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + testingPkg := imports.NewImport("testing") + protoPkg := imports.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = imports.NewImport("github.com/golang/protobuf/proto") + } + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + sizeName := "" + if gogoproto.IsSizer(file.FileDescriptorProto, message.DescriptorProto) { + sizeName = "Size" + } else if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { + sizeName = "ProtoSize" + } else { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + p.P(`func Test`, ccTypeName, sizeName, `(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`seed := `, timePkg.Use(), `.Now().UnixNano()`) + p.P(`popr := `, randPkg.Use(), `.New(`, 
randPkg.Use(), `.NewSource(seed))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, true)`) + p.P(`size2 := `, protoPkg.Use(), `.Size(p)`) + p.P(`dAtA, err := `, protoPkg.Use(), `.Marshal(p)`) + p.P(`if err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + p.P(`size := p.`, sizeName, `()`) + p.P(`if len(dAtA) != size {`) + p.In() + p.P(`t.Errorf("seed = %d, size %v != marshalled size %v", seed, size, len(dAtA))`) + p.Out() + p.P(`}`) + p.P(`if size2 != size {`) + p.In() + p.P(`t.Errorf("seed = %d, size %v != before marshal proto.Size %v", seed, size, size2)`) + p.Out() + p.P(`}`) + p.P(`size3 := `, protoPkg.Use(), `.Size(p)`) + p.P(`if size3 != size {`) + p.In() + p.P(`t.Errorf("seed = %d, size %v != after marshal proto.Size %v", seed, size, size3)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P() + } + + if gogoproto.HasBenchGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + p.P(`func Benchmark`, ccTypeName, sizeName, `(b *`, testingPkg.Use(), `.B) {`) + p.In() + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(616))`) + p.P(`total := 0`) + p.P(`pops := make([]*`, ccTypeName, `, 1000)`) + p.P(`for i := 0; i < 1000; i++ {`) + p.In() + p.P(`pops[i] = NewPopulated`, ccTypeName, `(popr, false)`) + p.Out() + p.P(`}`) + p.P(`b.ResetTimer()`) + p.P(`for i := 0; i < b.N; i++ {`) + p.In() + p.P(`total += pops[i%1000].`, sizeName, `()`) + p.Out() + p.P(`}`) + p.P(`b.SetBytes(int64(total / b.N))`) + p.Out() + p.P(`}`) + p.P() + } + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/stringer/stringer.go b/vendor/github.com/gogo/protobuf/plugin/stringer/stringer.go new file mode 100644 index 00000000000..098a9db7710 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/stringer/stringer.go @@ -0,0 +1,296 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. 
All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The stringer plugin generates a String method for each message. 
+ +It is enabled by the following extensions: + + - stringer + - stringer_all + +The stringer plugin also generates a test given it is enabled using one of the following extensions: + + - testgen + - testgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + option (gogoproto.goproto_stringer_all) = false; + option (gogoproto.stringer_all) = true; + + message A { + optional string Description = 1 [(gogoproto.nullable) = false]; + optional int64 Number = 2 [(gogoproto.nullable) = false]; + optional bytes Id = 3 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uuid", (gogoproto.nullable) = false]; + } + +given to the stringer stringer, will generate the following code: + + func (this *A) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&A{`, + `Description:` + fmt.Sprintf("%v", this.Description) + `,`, + `Number:` + fmt.Sprintf("%v", this.Number) + `,`, + `Id:` + fmt.Sprintf("%v", this.Id) + `,`, + `XXX_unrecognized:` + fmt.Sprintf("%v", this.XXX_unrecognized) + `,`, + `}`, + }, "") + return s + } + +and the following test code: + + func TestAStringer(t *testing4.T) { + popr := math_rand4.New(math_rand4.NewSource(time4.Now().UnixNano())) + p := NewPopulatedA(popr, false) + s1 := p.String() + s2 := fmt1.Sprintf("%v", p) + if s1 != s2 { + t.Fatalf("String want %v got %v", s1, s2) + } + } + +Typically fmt.Printf("%v") will stop to print when it reaches a pointer and +not print their values, while the generated String method will always print all values, recursively. 
+ +*/ +package stringer + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + "strings" +) + +type stringer struct { + *generator.Generator + generator.PluginImports + atleastOne bool + localName string +} + +func NewStringer() *stringer { + return &stringer{} +} + +func (p *stringer) Name() string { + return "stringer" +} + +func (p *stringer) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *stringer) Generate(file *generator.FileDescriptor) { + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + p.PluginImports = generator.NewPluginImports(p.Generator) + p.atleastOne = false + + p.localName = generator.FileName(file) + + fmtPkg := p.NewImport("fmt") + stringsPkg := p.NewImport("strings") + reflectPkg := p.NewImport("reflect") + sortKeysPkg := p.NewImport("github.com/gogo/protobuf/sortkeys") + protoPkg := p.NewImport("github.com/gogo/protobuf/proto") + for _, message := range file.Messages() { + if !gogoproto.IsStringer(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if gogoproto.EnabledGoStringer(file.FileDescriptorProto, message.DescriptorProto) { + panic("old string method needs to be disabled, please use gogoproto.goproto_stringer or gogoproto.goproto_stringer_all and set it to false") + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + p.atleastOne = true + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + p.P(`func (this *`, ccTypeName, `) String() string {`) + p.In() + p.P(`if this == nil {`) + p.In() + p.P(`return "nil"`) + p.Out() + p.P(`}`) + for _, field := range message.Field { + if !p.IsMap(field) { + continue + } + fieldname := p.GetFieldName(message, field) + + m := p.GoMapType(nil, field) + mapgoTyp, keyField, keyAliasField := m.GoType, m.KeyField, m.KeyAliasField + keysName := `keysFor` + fieldname + keygoTyp, _ := p.GoType(nil, keyField) + keygoTyp = strings.Replace(keygoTyp, "*", "", 1) + keygoAliasTyp, _ := 
p.GoType(nil, keyAliasField) + keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1) + keyCapTyp := generator.CamelCase(keygoTyp) + p.P(keysName, ` := make([]`, keygoTyp, `, 0, len(this.`, fieldname, `))`) + p.P(`for k, _ := range this.`, fieldname, ` {`) + p.In() + if keygoAliasTyp == keygoTyp { + p.P(keysName, ` = append(`, keysName, `, k)`) + } else { + p.P(keysName, ` = append(`, keysName, `, `, keygoTyp, `(k))`) + } + p.Out() + p.P(`}`) + p.P(sortKeysPkg.Use(), `.`, keyCapTyp, `s(`, keysName, `)`) + mapName := `mapStringFor` + fieldname + p.P(mapName, ` := "`, mapgoTyp, `{"`) + p.P(`for _, k := range `, keysName, ` {`) + p.In() + if keygoAliasTyp == keygoTyp { + p.P(mapName, ` += fmt.Sprintf("%v: %v,", k, this.`, fieldname, `[k])`) + } else { + p.P(mapName, ` += fmt.Sprintf("%v: %v,", k, this.`, fieldname, `[`, keygoAliasTyp, `(k)])`) + } + p.Out() + p.P(`}`) + p.P(mapName, ` += "}"`) + } + p.P("s := ", stringsPkg.Use(), ".Join([]string{`&", ccTypeName, "{`,") + oneofs := make(map[string]struct{}) + for _, field := range message.Field { + nullable := gogoproto.IsNullable(field) + repeated := field.IsRepeated() + fieldname := p.GetFieldName(message, field) + oneof := field.OneofIndex != nil + if oneof { + if _, ok := oneofs[fieldname]; ok { + continue + } else { + oneofs[fieldname] = struct{}{} + } + p.P("`", fieldname, ":`", ` + `, fmtPkg.Use(), `.Sprintf("%v", this.`, fieldname, ") + `,", "`,") + } else if p.IsMap(field) { + mapName := `mapStringFor` + fieldname + p.P("`", fieldname, ":`", ` + `, mapName, " + `,", "`,") + } else if (field.IsMessage() && !gogoproto.IsCustomType(field)) || p.IsGroup(field) { + desc := p.ObjectNamed(field.GetTypeName()) + msgname := p.TypeName(desc) + msgnames := strings.Split(msgname, ".") + typeName := msgnames[len(msgnames)-1] + if nullable { + p.P("`", fieldname, ":`", ` + `, stringsPkg.Use(), `.Replace(`, fmtPkg.Use(), `.Sprintf("%v", this.`, fieldname, `), "`, typeName, `","`, msgname, `"`, ", 1) + `,", "`,") + } else 
if repeated { + p.P("`", fieldname, ":`", ` + `, stringsPkg.Use(), `.Replace(`, stringsPkg.Use(), `.Replace(`, fmtPkg.Use(), `.Sprintf("%v", this.`, fieldname, `), "`, typeName, `","`, msgname, `"`, ", 1),`&`,``,1) + `,", "`,") + } else { + p.P("`", fieldname, ":`", ` + `, stringsPkg.Use(), `.Replace(`, stringsPkg.Use(), `.Replace(this.`, fieldname, `.String(), "`, typeName, `","`, msgname, `"`, ", 1),`&`,``,1) + `,", "`,") + } + } else { + if nullable && !repeated && !proto3 { + p.P("`", fieldname, ":`", ` + valueToString`, p.localName, `(this.`, fieldname, ") + `,", "`,") + } else { + p.P("`", fieldname, ":`", ` + `, fmtPkg.Use(), `.Sprintf("%v", this.`, fieldname, ") + `,", "`,") + } + } + } + if message.DescriptorProto.HasExtension() { + if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { + p.P("`XXX_InternalExtensions:` + ", protoPkg.Use(), ".StringFromInternalExtension(this) + `,`,") + } else { + p.P("`XXX_extensions:` + ", protoPkg.Use(), ".StringFromExtensionsBytes(this.XXX_extensions) + `,`,") + } + } + if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + p.P("`XXX_unrecognized:` + ", fmtPkg.Use(), `.Sprintf("%v", this.XXX_unrecognized) + `, "`,`,") + } + p.P("`}`,") + p.P(`}`, `,""`, ")") + p.P(`return s`) + p.Out() + p.P(`}`) + + //Generate String methods for oneof fields + for _, field := range message.Field { + oneof := field.OneofIndex != nil + if !oneof { + continue + } + ccTypeName := p.OneOfTypeName(message, field) + p.P(`func (this *`, ccTypeName, `) String() string {`) + p.In() + p.P(`if this == nil {`) + p.In() + p.P(`return "nil"`) + p.Out() + p.P(`}`) + p.P("s := ", stringsPkg.Use(), ".Join([]string{`&", ccTypeName, "{`,") + fieldname := p.GetOneOfFieldName(message, field) + if field.IsMessage() || p.IsGroup(field) { + desc := p.ObjectNamed(field.GetTypeName()) + msgname := p.TypeName(desc) + msgnames := strings.Split(msgname, ".") + typeName := msgnames[len(msgnames)-1] + p.P("`", 
fieldname, ":`", ` + `, stringsPkg.Use(), `.Replace(`, fmtPkg.Use(), `.Sprintf("%v", this.`, fieldname, `), "`, typeName, `","`, msgname, `"`, ", 1) + `,", "`,") + } else { + p.P("`", fieldname, ":`", ` + `, fmtPkg.Use(), `.Sprintf("%v", this.`, fieldname, ") + `,", "`,") + } + p.P("`}`,") + p.P(`}`, `,""`, ")") + p.P(`return s`) + p.Out() + p.P(`}`) + } + } + + if !p.atleastOne { + return + } + + p.P(`func valueToString`, p.localName, `(v interface{}) string {`) + p.In() + p.P(`rv := `, reflectPkg.Use(), `.ValueOf(v)`) + p.P(`if rv.IsNil() {`) + p.In() + p.P(`return "nil"`) + p.Out() + p.P(`}`) + p.P(`pv := `, reflectPkg.Use(), `.Indirect(rv).Interface()`) + p.P(`return `, fmtPkg.Use(), `.Sprintf("*%v", pv)`) + p.Out() + p.P(`}`) + +} + +func init() { + generator.RegisterPlugin(NewStringer()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/stringer/stringertest.go b/vendor/github.com/gogo/protobuf/plugin/stringer/stringertest.go new file mode 100644 index 00000000000..0912a22df63 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/stringer/stringertest.go @@ -0,0 +1,83 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package stringer + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + testingPkg := imports.NewImport("testing") + fmtPkg := imports.NewImport("fmt") + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if !gogoproto.IsStringer(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + p.P(`func Test`, ccTypeName, `Stringer(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(`, timePkg.Use(), `.Now().UnixNano()))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, 
false)`) + p.P(`s1 := p.String()`) + p.P(`s2 := `, fmtPkg.Use(), `.Sprintf("%v", p)`) + p.P(`if s1 != s2 {`) + p.In() + p.P(`t.Fatalf("String want %v got %v", s1, s2)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/testgen/testgen.go b/vendor/github.com/gogo/protobuf/plugin/testgen/testgen.go new file mode 100644 index 00000000000..e0a9287e560 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/testgen/testgen.go @@ -0,0 +1,608 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The testgen plugin generates Test and Benchmark functions for each message. + +Tests are enabled using the following extensions: + + - testgen + - testgen_all + +Benchmarks are enabled using the following extensions: + + - benchgen + - benchgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + option (gogoproto.testgen_all) = true; + option (gogoproto.benchgen_all) = true; + + message A { + optional string Description = 1 [(gogoproto.nullable) = false]; + optional int64 Number = 2 [(gogoproto.nullable) = false]; + optional bytes Id = 3 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uuid", (gogoproto.nullable) = false]; + } + +given to the testgen plugin, will generate the following test code: + + func TestAProto(t *testing.T) { + popr := math_rand.New(math_rand.NewSource(time.Now().UnixNano())) + p := NewPopulatedA(popr, false) + dAtA, err := github_com_gogo_protobuf_proto.Marshal(p) + if err != nil { + panic(err) + } + msg := &A{} + if err := github_com_gogo_protobuf_proto.Unmarshal(dAtA, msg); err != nil { + panic(err) + } + for i := range dAtA { + dAtA[i] = byte(popr.Intn(256)) + } + if err := p.VerboseEqual(msg); err != nil { + t.Fatalf("%#v !VerboseProto %#v, since %v", msg, p, err) + } + if !p.Equal(msg) { + t.Fatalf("%#v !Proto %#v", msg, p) + } + } + + func 
BenchmarkAProtoMarshal(b *testing.B) { + popr := math_rand.New(math_rand.NewSource(616)) + total := 0 + pops := make([]*A, 10000) + for i := 0; i < 10000; i++ { + pops[i] = NewPopulatedA(popr, false) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + dAtA, err := github_com_gogo_protobuf_proto.Marshal(pops[i%10000]) + if err != nil { + panic(err) + } + total += len(dAtA) + } + b.SetBytes(int64(total / b.N)) + } + + func BenchmarkAProtoUnmarshal(b *testing.B) { + popr := math_rand.New(math_rand.NewSource(616)) + total := 0 + datas := make([][]byte, 10000) + for i := 0; i < 10000; i++ { + dAtA, err := github_com_gogo_protobuf_proto.Marshal(NewPopulatedA(popr, false)) + if err != nil { + panic(err) + } + datas[i] = dAtA + } + msg := &A{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + total += len(datas[i%10000]) + if err := github_com_gogo_protobuf_proto.Unmarshal(datas[i%10000], msg); err != nil { + panic(err) + } + } + b.SetBytes(int64(total / b.N)) + } + + + func TestAJSON(t *testing1.T) { + popr := math_rand1.New(math_rand1.NewSource(time1.Now().UnixNano())) + p := NewPopulatedA(popr, true) + jsondata, err := encoding_json.Marshal(p) + if err != nil { + panic(err) + } + msg := &A{} + err = encoding_json.Unmarshal(jsondata, msg) + if err != nil { + panic(err) + } + if err := p.VerboseEqual(msg); err != nil { + t.Fatalf("%#v !VerboseProto %#v, since %v", msg, p, err) + } + if !p.Equal(msg) { + t.Fatalf("%#v !Json Equal %#v", msg, p) + } + } + + func TestAProtoText(t *testing2.T) { + popr := math_rand2.New(math_rand2.NewSource(time2.Now().UnixNano())) + p := NewPopulatedA(popr, true) + dAtA := github_com_gogo_protobuf_proto1.MarshalTextString(p) + msg := &A{} + if err := github_com_gogo_protobuf_proto1.UnmarshalText(dAtA, msg); err != nil { + panic(err) + } + if err := p.VerboseEqual(msg); err != nil { + t.Fatalf("%#v !VerboseProto %#v, since %v", msg, p, err) + } + if !p.Equal(msg) { + t.Fatalf("%#v !Proto %#v", msg, p) + } + } + + func TestAProtoCompactText(t 
*testing2.T) { + popr := math_rand2.New(math_rand2.NewSource(time2.Now().UnixNano())) + p := NewPopulatedA(popr, true) + dAtA := github_com_gogo_protobuf_proto1.CompactTextString(p) + msg := &A{} + if err := github_com_gogo_protobuf_proto1.UnmarshalText(dAtA, msg); err != nil { + panic(err) + } + if err := p.VerboseEqual(msg); err != nil { + t.Fatalf("%#v !VerboseProto %#v, since %v", msg, p, err) + } + if !p.Equal(msg) { + t.Fatalf("%#v !Proto %#v", msg, p) + } + } + +Other registered tests are also generated. +Tests are registered to this test plugin by calling the following function. + + func RegisterTestPlugin(newFunc NewTestPlugin) + +where NewTestPlugin is: + + type NewTestPlugin func(g *generator.Generator) TestPlugin + +and TestPlugin is an interface: + + type TestPlugin interface { + Generate(imports generator.PluginImports, file *generator.FileDescriptor) (used bool) + } + +Plugins that use this interface include: + + - populate + - gostring + - equal + - union + - and more + +Please look at these plugins as examples of how to create your own. +A good idea is to let each plugin generate its own tests. 
+ +*/ +package testgen + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type TestPlugin interface { + Generate(imports generator.PluginImports, file *generator.FileDescriptor) (used bool) +} + +type NewTestPlugin func(g *generator.Generator) TestPlugin + +var testplugins = make([]NewTestPlugin, 0) + +func RegisterTestPlugin(newFunc NewTestPlugin) { + testplugins = append(testplugins, newFunc) +} + +type plugin struct { + *generator.Generator + generator.PluginImports + tests []TestPlugin +} + +func NewPlugin() *plugin { + return &plugin{} +} + +func (p *plugin) Name() string { + return "testgen" +} + +func (p *plugin) Init(g *generator.Generator) { + p.Generator = g + p.tests = make([]TestPlugin, 0, len(testplugins)) + for i := range testplugins { + p.tests = append(p.tests, testplugins[i](g)) + } +} + +func (p *plugin) Generate(file *generator.FileDescriptor) { + p.PluginImports = generator.NewPluginImports(p.Generator) + atLeastOne := false + for i := range p.tests { + used := p.tests[i].Generate(p.PluginImports, file) + if used { + atLeastOne = true + } + } + if atLeastOne { + p.P(`//These tests are generated by github.com/gogo/protobuf/plugin/testgen`) + } +} + +type testProto struct { + *generator.Generator +} + +func newProto(g *generator.Generator) TestPlugin { + return &testProto{g} +} + +func (p *testProto) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + testingPkg := imports.NewImport("testing") + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + protoPkg := imports.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = imports.NewImport("github.com/golang/protobuf/proto") + } + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + 
} + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + + p.P(`func Test`, ccTypeName, `Proto(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`seed := `, timePkg.Use(), `.Now().UnixNano()`) + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(seed))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, false)`) + p.P(`dAtA, err := `, protoPkg.Use(), `.Marshal(p)`) + p.P(`if err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`if err := `, protoPkg.Use(), `.Unmarshal(dAtA, msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + p.P(`littlefuzz := make([]byte, len(dAtA))`) + p.P(`copy(littlefuzz, dAtA)`) + p.P(`for i := range dAtA {`) + p.In() + p.P(`dAtA[i] = byte(popr.Intn(256))`) + p.Out() + p.P(`}`) + if gogoproto.HasVerboseEqual(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if err := p.VerboseEqual(msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !VerboseProto %#v, since %v", seed, msg, p, err)`) + p.Out() + p.P(`}`) + } + p.P(`if !p.Equal(msg) {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !Proto %#v", seed, msg, p)`) + p.Out() + p.P(`}`) + p.P(`if len(littlefuzz) > 0 {`) + p.In() + p.P(`fuzzamount := 100`) + p.P(`for i := 0; i < fuzzamount; i++ {`) + p.In() + p.P(`littlefuzz[popr.Intn(len(littlefuzz))] = byte(popr.Intn(256))`) + p.P(`littlefuzz = append(littlefuzz, byte(popr.Intn(256)))`) + p.Out() + p.P(`}`) + p.P(`// shouldn't panic`) + p.P(`_ = `, protoPkg.Use(), `.Unmarshal(littlefuzz, msg)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P() + } + + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + if gogoproto.IsMarshaler(file.FileDescriptorProto, message.DescriptorProto) || gogoproto.IsUnsafeMarshaler(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`func Test`, ccTypeName, `MarshalTo(t *`, 
testingPkg.Use(), `.T) {`) + p.In() + p.P(`seed := `, timePkg.Use(), `.Now().UnixNano()`) + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(seed))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, false)`) + if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`size := p.ProtoSize()`) + } else { + p.P(`size := p.Size()`) + } + p.P(`dAtA := make([]byte, size)`) + p.P(`for i := range dAtA {`) + p.In() + p.P(`dAtA[i] = byte(popr.Intn(256))`) + p.Out() + p.P(`}`) + p.P(`_, err := p.MarshalTo(dAtA)`) + p.P(`if err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`if err := `, protoPkg.Use(), `.Unmarshal(dAtA, msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + p.P(`for i := range dAtA {`) + p.In() + p.P(`dAtA[i] = byte(popr.Intn(256))`) + p.Out() + p.P(`}`) + if gogoproto.HasVerboseEqual(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if err := p.VerboseEqual(msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !VerboseProto %#v, since %v", seed, msg, p, err)`) + p.Out() + p.P(`}`) + } + p.P(`if !p.Equal(msg) {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !Proto %#v", seed, msg, p)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P() + } + } + + if gogoproto.HasBenchGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + p.P(`func Benchmark`, ccTypeName, `ProtoMarshal(b *`, testingPkg.Use(), `.B) {`) + p.In() + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(616))`) + p.P(`total := 0`) + p.P(`pops := make([]*`, ccTypeName, `, 10000)`) + p.P(`for i := 0; i < 10000; i++ {`) + p.In() + p.P(`pops[i] = NewPopulated`, ccTypeName, `(popr, false)`) + p.Out() + p.P(`}`) + p.P(`b.ResetTimer()`) + p.P(`for i := 0; i < b.N; i++ {`) + p.In() + p.P(`dAtA, err := `, protoPkg.Use(), `.Marshal(pops[i%10000])`) + p.P(`if err != nil {`) + p.In() + 
p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`total += len(dAtA)`) + p.Out() + p.P(`}`) + p.P(`b.SetBytes(int64(total / b.N))`) + p.Out() + p.P(`}`) + p.P() + + p.P(`func Benchmark`, ccTypeName, `ProtoUnmarshal(b *`, testingPkg.Use(), `.B) {`) + p.In() + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(616))`) + p.P(`total := 0`) + p.P(`datas := make([][]byte, 10000)`) + p.P(`for i := 0; i < 10000; i++ {`) + p.In() + p.P(`dAtA, err := `, protoPkg.Use(), `.Marshal(NewPopulated`, ccTypeName, `(popr, false))`) + p.P(`if err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.P(`datas[i] = dAtA`) + p.Out() + p.P(`}`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`b.ResetTimer()`) + p.P(`for i := 0; i < b.N; i++ {`) + p.In() + p.P(`total += len(datas[i%10000])`) + p.P(`if err := `, protoPkg.Use(), `.Unmarshal(datas[i%10000], msg); err != nil {`) + p.In() + p.P(`panic(err)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P(`b.SetBytes(int64(total / b.N))`) + p.Out() + p.P(`}`) + p.P() + } + } + return used +} + +type testJson struct { + *generator.Generator +} + +func newJson(g *generator.Generator) TestPlugin { + return &testJson{g} +} + +func (p *testJson) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + testingPkg := imports.NewImport("testing") + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + jsonPkg := imports.NewImport("github.com/gogo/protobuf/jsonpb") + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + p.P(`func Test`, ccTypeName, `JSON(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`seed := `, timePkg.Use(), `.Now().UnixNano()`) + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(seed))`) + p.P(`p := NewPopulated`, 
ccTypeName, `(popr, true)`) + p.P(`marshaler := `, jsonPkg.Use(), `.Marshaler{}`) + p.P(`jsondata, err := marshaler.MarshalToString(p)`) + p.P(`if err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`err = `, jsonPkg.Use(), `.UnmarshalString(jsondata, msg)`) + p.P(`if err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + if gogoproto.HasVerboseEqual(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if err := p.VerboseEqual(msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !VerboseProto %#v, since %v", seed, msg, p, err)`) + p.Out() + p.P(`}`) + } + p.P(`if !p.Equal(msg) {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !Json Equal %#v", seed, msg, p)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + } + } + return used +} + +type testText struct { + *generator.Generator +} + +func newText(g *generator.Generator) TestPlugin { + return &testText{g} +} + +func (p *testText) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + testingPkg := imports.NewImport("testing") + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + protoPkg := imports.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = imports.NewImport("github.com/golang/protobuf/proto") + } + //fmtPkg := imports.NewImport("fmt") + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + used = true + + p.P(`func Test`, ccTypeName, `ProtoText(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`seed := `, timePkg.Use(), `.Now().UnixNano()`) + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(seed))`) + p.P(`p := 
NewPopulated`, ccTypeName, `(popr, true)`) + p.P(`dAtA := `, protoPkg.Use(), `.MarshalTextString(p)`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`if err := `, protoPkg.Use(), `.UnmarshalText(dAtA, msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + if gogoproto.HasVerboseEqual(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if err := p.VerboseEqual(msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !VerboseProto %#v, since %v", seed, msg, p, err)`) + p.Out() + p.P(`}`) + } + p.P(`if !p.Equal(msg) {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !Proto %#v", seed, msg, p)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P() + + p.P(`func Test`, ccTypeName, `ProtoCompactText(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`seed := `, timePkg.Use(), `.Now().UnixNano()`) + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(seed))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, true)`) + p.P(`dAtA := `, protoPkg.Use(), `.CompactTextString(p)`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`if err := `, protoPkg.Use(), `.UnmarshalText(dAtA, msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, err = %v", seed, err)`) + p.Out() + p.P(`}`) + if gogoproto.HasVerboseEqual(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`if err := p.VerboseEqual(msg); err != nil {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !VerboseProto %#v, since %v", seed, msg, p, err)`) + p.Out() + p.P(`}`) + } + p.P(`if !p.Equal(msg) {`) + p.In() + p.P(`t.Fatalf("seed = %d, %#v !Proto %#v", seed, msg, p)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P() + + } + } + return used +} + +func init() { + RegisterTestPlugin(newProto) + RegisterTestPlugin(newJson) + RegisterTestPlugin(newText) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/union/union.go b/vendor/github.com/gogo/protobuf/plugin/union/union.go new file mode 100644 index 00000000000..72edb2498cd --- /dev/null +++ 
b/vendor/github.com/gogo/protobuf/plugin/union/union.go @@ -0,0 +1,209 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The onlyone plugin generates code for the onlyone extension. +All fields must be nullable and only one of the fields may be set, like a union. +Two methods are generated + + GetValue() interface{} + +and + + SetValue(v interface{}) (set bool) + +These provide easier interaction with a onlyone. + +The onlyone extension is not called union as this causes compile errors in the C++ generated code. 
+There can only be one ;) + +It is enabled by the following extensions: + + - onlyone + - onlyone_all + +The onlyone plugin also generates a test given it is enabled using one of the following extensions: + + - testgen + - testgen_all + +Lets look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + message U { + option (gogoproto.onlyone) = true; + optional A A = 1; + optional B B = 2; + } + +given to the onlyone plugin, will generate code which looks a lot like this: + + func (this *U) GetValue() interface{} { + if this.A != nil { + return this.A + } + if this.B != nil { + return this.B + } + return nil + } + + func (this *U) SetValue(value interface{}) bool { + switch vt := value.(type) { + case *A: + this.A = vt + case *B: + this.B = vt + default: + return false + } + return true + } + +and the following test code: + + func TestUUnion(t *testing.T) { + popr := math_rand.New(math_rand.NewSource(time.Now().UnixNano())) + p := NewPopulatedU(popr) + v := p.GetValue() + msg := &U{} + if !msg.SetValue(v) { + t.Fatalf("Union: Could not set Value") + } + if !p.Equal(msg) { + t.Fatalf("%#v !Union Equal %#v", msg, p) + } + } + +*/ +package union + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type union struct { + *generator.Generator + generator.PluginImports +} + +func NewUnion() *union { + return &union{} +} + +func (p *union) Name() string { + return "union" +} + +func (p *union) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *union) Generate(file *generator.FileDescriptor) { + p.PluginImports = generator.NewPluginImports(p.Generator) + + for _, message := range file.Messages() { + if !gogoproto.IsUnion(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.HasExtension() { + panic("onlyone does not currently support extensions") 
+ } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + p.P(`func (this *`, ccTypeName, `) GetValue() interface{} {`) + p.In() + for _, field := range message.Field { + fieldname := p.GetFieldName(message, field) + if fieldname == "Value" { + panic("cannot have a onlyone message " + ccTypeName + " with a field named Value") + } + p.P(`if this.`, fieldname, ` != nil {`) + p.In() + p.P(`return this.`, fieldname) + p.Out() + p.P(`}`) + } + p.P(`return nil`) + p.Out() + p.P(`}`) + p.P(``) + p.P(`func (this *`, ccTypeName, `) SetValue(value interface{}) bool {`) + p.In() + p.P(`switch vt := value.(type) {`) + p.In() + for _, field := range message.Field { + fieldname := p.GetFieldName(message, field) + goTyp, _ := p.GoType(message, field) + p.P(`case `, goTyp, `:`) + p.In() + p.P(`this.`, fieldname, ` = vt`) + p.Out() + } + p.P(`default:`) + p.In() + for _, field := range message.Field { + fieldname := p.GetFieldName(message, field) + if field.IsMessage() { + goTyp, _ := p.GoType(message, field) + obj := p.ObjectNamed(field.GetTypeName()).(*generator.Descriptor) + + if gogoproto.IsUnion(obj.File(), obj.DescriptorProto) { + p.P(`this.`, fieldname, ` = new(`, generator.GoTypeToName(goTyp), `)`) + p.P(`if set := this.`, fieldname, `.SetValue(value); set {`) + p.In() + p.P(`return true`) + p.Out() + p.P(`}`) + p.P(`this.`, fieldname, ` = nil`) + } + } + } + p.P(`return false`) + p.Out() + p.P(`}`) + p.P(`return true`) + p.Out() + p.P(`}`) + } +} + +func init() { + generator.RegisterPlugin(NewUnion()) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/union/uniontest.go b/vendor/github.com/gogo/protobuf/plugin/union/uniontest.go new file mode 100644 index 00000000000..949cf833850 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/union/uniontest.go @@ -0,0 +1,86 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. 
+// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package union + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/plugin/testgen" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type test struct { + *generator.Generator +} + +func NewTest(g *generator.Generator) testgen.TestPlugin { + return &test{g} +} + +func (p *test) Generate(imports generator.PluginImports, file *generator.FileDescriptor) bool { + used := false + randPkg := imports.NewImport("math/rand") + timePkg := imports.NewImport("time") + testingPkg := imports.NewImport("testing") + for _, message := range file.Messages() { + if !gogoproto.IsUnion(file.FileDescriptorProto, message.DescriptorProto) || + !gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + used = true + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + + p.P(`func Test`, ccTypeName, `OnlyOne(t *`, testingPkg.Use(), `.T) {`) + p.In() + p.P(`popr := `, randPkg.Use(), `.New(`, randPkg.Use(), `.NewSource(`, timePkg.Use(), `.Now().UnixNano()))`) + p.P(`p := NewPopulated`, ccTypeName, `(popr, true)`) + p.P(`v := p.GetValue()`) + p.P(`msg := &`, ccTypeName, `{}`) + p.P(`if !msg.SetValue(v) {`) + p.In() + p.P(`t.Fatalf("OnlyOne: Could not set Value")`) + p.Out() + p.P(`}`) + p.P(`if !p.Equal(msg) {`) + p.In() + p.P(`t.Fatalf("%#v !OnlyOne Equal %#v", msg, p)`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + + } + return used +} + +func init() { + testgen.RegisterTestPlugin(NewTest) +} diff --git a/vendor/github.com/gogo/protobuf/plugin/unmarshal/unmarshal.go b/vendor/github.com/gogo/protobuf/plugin/unmarshal/unmarshal.go new file mode 100644 index 00000000000..b5d9613df1d --- /dev/null +++ b/vendor/github.com/gogo/protobuf/plugin/unmarshal/unmarshal.go @@ -0,0 +1,1349 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. 
+// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +The unmarshal plugin generates a Unmarshal method for each message. +The `Unmarshal([]byte) error` method results in the fact that the message +implements the Unmarshaler interface. +The allows proto.Unmarshal to be faster by calling the generated Unmarshal method rather than using reflect. + +If is enabled by the following extensions: + + - unmarshaler + - unmarshaler_all + +Or the following extensions: + + - unsafe_unmarshaler + - unsafe_unmarshaler_all + +That is if you want to use the unsafe package in your generated code. +The speed up using the unsafe package is not very significant. 
+ +The generation of unmarshalling tests are enabled using one of the following extensions: + + - testgen + - testgen_all + +And benchmarks given it is enabled using one of the following extensions: + + - benchgen + - benchgen_all + +Let us look at: + + github.com/gogo/protobuf/test/example/example.proto + +Btw all the output can be seen at: + + github.com/gogo/protobuf/test/example/* + +The following message: + + option (gogoproto.unmarshaler_all) = true; + + message B { + option (gogoproto.description) = true; + optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true]; + repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false]; + } + +given to the unmarshal plugin, will generate the following code: + + func (m *B) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + switch fieldNum { + case 1: + if wireType != 2 { + return proto.ErrWrongType + } + var msglen int + for shift := uint(0); ; shift += 7 { + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if err := m.A.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return proto.ErrWrongType + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.G = 
append(m.G, github_com_gogo_protobuf_test_custom.Uint128{}) + if err := m.G[len(m.G)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + var sizeOfWire int + for { + sizeOfWire++ + wire >>= 7 + if wire == 0 { + break + } + } + iNdEx -= sizeOfWire + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + return nil + } + +Remember when using this code to call proto.Unmarshal. +This will call m.Reset and invoke the generated Unmarshal method for you. +If you call m.Unmarshal without m.Reset you could be merging protocol buffers. + +*/ +package unmarshal + +import ( + "fmt" + "strconv" + "strings" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +type unmarshal struct { + *generator.Generator + generator.PluginImports + atleastOne bool + ioPkg generator.Single + mathPkg generator.Single + typesPkg generator.Single + binaryPkg generator.Single + localName string +} + +func NewUnmarshal() *unmarshal { + return &unmarshal{} +} + +func (p *unmarshal) Name() string { + return "unmarshal" +} + +func (p *unmarshal) Init(g *generator.Generator) { + p.Generator = g +} + +func (p *unmarshal) decodeVarint(varName string, typName string) { + p.P(`for shift := uint(0); ; shift += 7 {`) + p.In() + p.P(`if shift >= 64 {`) + p.In() + p.P(`return ErrIntOverflow` + p.localName) + p.Out() + p.P(`}`) + p.P(`if iNdEx >= l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + p.P(`b := dAtA[iNdEx]`) + p.P(`iNdEx++`) + p.P(varName, ` |= (`, typName, `(b) & 0x7F) << shift`) + p.P(`if b < 0x80 {`) + p.In() + p.P(`break`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) +} + +func (p 
*unmarshal) decodeFixed32(varName string, typeName string) { + p.P(`if (iNdEx+4) > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + p.P(varName, ` = `, typeName, `(`, p.binaryPkg.Use(), `.LittleEndian.Uint32(dAtA[iNdEx:]))`) + p.P(`iNdEx += 4`) +} + +func (p *unmarshal) decodeFixed64(varName string, typeName string) { + p.P(`if (iNdEx+8) > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + p.P(varName, ` = `, typeName, `(`, p.binaryPkg.Use(), `.LittleEndian.Uint64(dAtA[iNdEx:]))`) + p.P(`iNdEx += 8`) +} + +func (p *unmarshal) declareMapField(varName string, nullable bool, customType bool, field *descriptor.FieldDescriptorProto) { + switch field.GetType() { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + p.P(`var `, varName, ` float64`) + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + p.P(`var `, varName, ` float32`) + case descriptor.FieldDescriptorProto_TYPE_INT64: + p.P(`var `, varName, ` int64`) + case descriptor.FieldDescriptorProto_TYPE_UINT64: + p.P(`var `, varName, ` uint64`) + case descriptor.FieldDescriptorProto_TYPE_INT32: + p.P(`var `, varName, ` int32`) + case descriptor.FieldDescriptorProto_TYPE_FIXED64: + p.P(`var `, varName, ` uint64`) + case descriptor.FieldDescriptorProto_TYPE_FIXED32: + p.P(`var `, varName, ` uint32`) + case descriptor.FieldDescriptorProto_TYPE_BOOL: + p.P(`var `, varName, ` bool`) + case descriptor.FieldDescriptorProto_TYPE_STRING: + cast, _ := p.GoType(nil, field) + cast = strings.Replace(cast, "*", "", 1) + p.P(`var `, varName, ` `, cast) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if gogoproto.IsStdTime(field) { + p.P(varName, ` := new(time.Time)`) + } else if gogoproto.IsStdDuration(field) { + p.P(varName, ` := new(time.Duration)`) + } else { + desc := p.ObjectNamed(field.GetTypeName()) + msgname := p.TypeName(desc) + if nullable { + p.P(`var `, varName, ` *`, msgname) + } else { + p.P(varName, ` := &`, msgname, `{}`) + } + 
} + case descriptor.FieldDescriptorProto_TYPE_BYTES: + if customType { + _, ctyp, err := generator.GetCustomType(field) + if err != nil { + panic(err) + } + p.P(`var `, varName, `1 `, ctyp) + p.P(`var `, varName, ` = &`, varName, `1`) + } else { + p.P(varName, ` := []byte{}`) + } + case descriptor.FieldDescriptorProto_TYPE_UINT32: + p.P(`var `, varName, ` uint32`) + case descriptor.FieldDescriptorProto_TYPE_ENUM: + typName := p.TypeName(p.ObjectNamed(field.GetTypeName())) + p.P(`var `, varName, ` `, typName) + case descriptor.FieldDescriptorProto_TYPE_SFIXED32: + p.P(`var `, varName, ` int32`) + case descriptor.FieldDescriptorProto_TYPE_SFIXED64: + p.P(`var `, varName, ` int64`) + case descriptor.FieldDescriptorProto_TYPE_SINT32: + p.P(`var `, varName, ` int32`) + case descriptor.FieldDescriptorProto_TYPE_SINT64: + p.P(`var `, varName, ` int64`) + } +} + +func (p *unmarshal) mapField(varName string, customType bool, field *descriptor.FieldDescriptorProto) { + switch field.GetType() { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + p.P(`var `, varName, `temp uint64`) + p.decodeFixed64(varName+"temp", "uint64") + p.P(varName, ` = `, p.mathPkg.Use(), `.Float64frombits(`, varName, `temp)`) + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + p.P(`var `, varName, `temp uint32`) + p.decodeFixed32(varName+"temp", "uint32") + p.P(varName, ` = `, p.mathPkg.Use(), `.Float32frombits(`, varName, `temp)`) + case descriptor.FieldDescriptorProto_TYPE_INT64: + p.decodeVarint(varName, "int64") + case descriptor.FieldDescriptorProto_TYPE_UINT64: + p.decodeVarint(varName, "uint64") + case descriptor.FieldDescriptorProto_TYPE_INT32: + p.decodeVarint(varName, "int32") + case descriptor.FieldDescriptorProto_TYPE_FIXED64: + p.decodeFixed64(varName, "uint64") + case descriptor.FieldDescriptorProto_TYPE_FIXED32: + p.decodeFixed32(varName, "uint32") + case descriptor.FieldDescriptorProto_TYPE_BOOL: + p.P(`var `, varName, `temp int`) + p.decodeVarint(varName+"temp", "int") + 
p.P(varName, ` = bool(`, varName, `temp != 0)`) + case descriptor.FieldDescriptorProto_TYPE_STRING: + p.P(`var stringLen`, varName, ` uint64`) + p.decodeVarint("stringLen"+varName, "uint64") + p.P(`intStringLen`, varName, ` := int(stringLen`, varName, `)`) + p.P(`if intStringLen`, varName, ` < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`postStringIndex`, varName, ` := iNdEx + intStringLen`, varName) + p.P(`if postStringIndex`, varName, ` > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + cast, _ := p.GoType(nil, field) + cast = strings.Replace(cast, "*", "", 1) + p.P(varName, ` = `, cast, `(dAtA[iNdEx:postStringIndex`, varName, `])`) + p.P(`iNdEx = postStringIndex`, varName) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + p.P(`var mapmsglen int`) + p.decodeVarint("mapmsglen", "int") + p.P(`if mapmsglen < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`postmsgIndex := iNdEx + mapmsglen`) + p.P(`if mapmsglen < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`if postmsgIndex > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + buf := `dAtA[iNdEx:postmsgIndex]` + if gogoproto.IsStdTime(field) { + p.P(`if err := `, p.typesPkg.Use(), `.StdTimeUnmarshal(`, varName, `, `, buf, `); err != nil {`) + } else if gogoproto.IsStdDuration(field) { + p.P(`if err := `, p.typesPkg.Use(), `.StdDurationUnmarshal(`, varName, `, `, buf, `); err != nil {`) + } else { + desc := p.ObjectNamed(field.GetTypeName()) + msgname := p.TypeName(desc) + p.P(varName, ` = &`, msgname, `{}`) + p.P(`if err := `, varName, `.Unmarshal(`, buf, `); err != nil {`) + } + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + p.P(`iNdEx = postmsgIndex`) + case descriptor.FieldDescriptorProto_TYPE_BYTES: + p.P(`var mapbyteLen uint64`) + p.decodeVarint("mapbyteLen", "uint64") + p.P(`intMapbyteLen := 
int(mapbyteLen)`) + p.P(`if intMapbyteLen < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`postbytesIndex := iNdEx + intMapbyteLen`) + p.P(`if postbytesIndex > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + if customType { + p.P(`if err := `, varName, `.Unmarshal(dAtA[iNdEx:postbytesIndex]); err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + } else { + p.P(varName, ` = make([]byte, mapbyteLen)`) + p.P(`copy(`, varName, `, dAtA[iNdEx:postbytesIndex])`) + } + p.P(`iNdEx = postbytesIndex`) + case descriptor.FieldDescriptorProto_TYPE_UINT32: + p.decodeVarint(varName, "uint32") + case descriptor.FieldDescriptorProto_TYPE_ENUM: + typName := p.TypeName(p.ObjectNamed(field.GetTypeName())) + p.decodeVarint(varName, typName) + case descriptor.FieldDescriptorProto_TYPE_SFIXED32: + p.decodeFixed32(varName, "int32") + case descriptor.FieldDescriptorProto_TYPE_SFIXED64: + p.decodeFixed64(varName, "int64") + case descriptor.FieldDescriptorProto_TYPE_SINT32: + p.P(`var `, varName, `temp int32`) + p.decodeVarint(varName+"temp", "int32") + p.P(varName, `temp = int32((uint32(`, varName, `temp) >> 1) ^ uint32(((`, varName, `temp&1)<<31)>>31))`) + p.P(varName, ` = int32(`, varName, `temp)`) + case descriptor.FieldDescriptorProto_TYPE_SINT64: + p.P(`var `, varName, `temp uint64`) + p.decodeVarint(varName+"temp", "uint64") + p.P(varName, `temp = (`, varName, `temp >> 1) ^ uint64((int64(`, varName, `temp&1)<<63)>>63)`) + p.P(varName, ` = int64(`, varName, `temp)`) + } +} + +func (p *unmarshal) noStarOrSliceType(msg *generator.Descriptor, field *descriptor.FieldDescriptorProto) string { + typ, _ := p.GoType(msg, field) + if typ[0] == '*' { + return typ[1:] + } + if typ[0] == '[' && typ[1] == ']' { + return typ[2:] + } + return typ +} + +func (p *unmarshal) field(file *generator.FileDescriptor, msg *generator.Descriptor, field *descriptor.FieldDescriptorProto, fieldname string, proto3 
bool) { + repeated := field.IsRepeated() + nullable := gogoproto.IsNullable(field) + typ := p.noStarOrSliceType(msg, field) + oneof := field.OneofIndex != nil + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + p.P(`var v uint64`) + p.decodeFixed64("v", "uint64") + if oneof { + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{`, typ, "(", p.mathPkg.Use(), `.Float64frombits(v))}`) + } else if repeated { + p.P(`v2 := `, typ, "(", p.mathPkg.Use(), `.Float64frombits(v))`) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v2)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = `, typ, "(", p.mathPkg.Use(), `.Float64frombits(v))`) + } else { + p.P(`v2 := `, typ, "(", p.mathPkg.Use(), `.Float64frombits(v))`) + p.P(`m.`, fieldname, ` = &v2`) + } + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + p.P(`var v uint32`) + p.decodeFixed32("v", "uint32") + if oneof { + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{`, typ, "(", p.mathPkg.Use(), `.Float32frombits(v))}`) + } else if repeated { + p.P(`v2 := `, typ, "(", p.mathPkg.Use(), `.Float32frombits(v))`) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v2)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = `, typ, "(", p.mathPkg.Use(), `.Float32frombits(v))`) + } else { + p.P(`v2 := `, typ, "(", p.mathPkg.Use(), `.Float32frombits(v))`) + p.P(`m.`, fieldname, ` = &v2`) + } + case descriptor.FieldDescriptorProto_TYPE_INT64: + if oneof { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeVarint("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_UINT64: + if oneof { + p.P(`var v `, 
typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeVarint("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_INT32: + if oneof { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeVarint("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_FIXED64: + if oneof { + p.P(`var v `, typ) + p.decodeFixed64("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeFixed64("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeFixed64("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeFixed64("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_FIXED32: + if oneof { + p.P(`var v `, typ) + p.decodeFixed32("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeFixed32("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeFixed32("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeFixed32("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case 
descriptor.FieldDescriptorProto_TYPE_BOOL: + p.P(`var v int`) + p.decodeVarint("v", "int") + if oneof { + p.P(`b := `, typ, `(v != 0)`) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{b}`) + } else if repeated { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(v != 0))`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = `, typ, `(v != 0)`) + } else { + p.P(`b := `, typ, `(v != 0)`) + p.P(`m.`, fieldname, ` = &b`) + } + case descriptor.FieldDescriptorProto_TYPE_STRING: + p.P(`var stringLen uint64`) + p.decodeVarint("stringLen", "uint64") + p.P(`intStringLen := int(stringLen)`) + p.P(`if intStringLen < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`postIndex := iNdEx + intStringLen`) + p.P(`if postIndex > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + if oneof { + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{`, typ, `(dAtA[iNdEx:postIndex])}`) + } else if repeated { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(dAtA[iNdEx:postIndex]))`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = `, typ, `(dAtA[iNdEx:postIndex])`) + } else { + p.P(`s := `, typ, `(dAtA[iNdEx:postIndex])`) + p.P(`m.`, fieldname, ` = &s`) + } + p.P(`iNdEx = postIndex`) + case descriptor.FieldDescriptorProto_TYPE_GROUP: + panic(fmt.Errorf("unmarshaler does not support group %v", fieldname)) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + desc := p.ObjectNamed(field.GetTypeName()) + msgname := p.TypeName(desc) + p.P(`var msglen int`) + p.decodeVarint("msglen", "int") + p.P(`if msglen < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`postIndex := iNdEx + msglen`) + p.P(`if postIndex > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + if oneof { + buf := `dAtA[iNdEx:postIndex]` + if gogoproto.IsStdTime(field) { + if nullable { + p.P(`v := 
new(time.Time)`) + p.P(`if err := `, p.typesPkg.Use(), `.StdTimeUnmarshal(v, `, buf, `); err != nil {`) + } else { + p.P(`v := time.Time{}`) + p.P(`if err := `, p.typesPkg.Use(), `.StdTimeUnmarshal(&v, `, buf, `); err != nil {`) + } + } else if gogoproto.IsStdDuration(field) { + if nullable { + p.P(`v := new(time.Duration)`) + p.P(`if err := `, p.typesPkg.Use(), `.StdDurationUnmarshal(v, `, buf, `); err != nil {`) + } else { + p.P(`v := time.Duration(0)`) + p.P(`if err := `, p.typesPkg.Use(), `.StdDurationUnmarshal(&v, `, buf, `); err != nil {`) + } + } else { + p.P(`v := &`, msgname, `{}`) + p.P(`if err := v.Unmarshal(`, buf, `); err != nil {`) + } + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if p.IsMap(field) { + m := p.GoMapType(nil, field) + + keygoTyp, _ := p.GoType(nil, m.KeyField) + keygoAliasTyp, _ := p.GoType(nil, m.KeyAliasField) + // keys may not be pointers + keygoTyp = strings.Replace(keygoTyp, "*", "", 1) + keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1) + + valuegoTyp, _ := p.GoType(nil, m.ValueField) + valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField) + + // if the map type is an alias and key or values are aliases (type Foo map[Bar]Baz), + // we need to explicitly record their use here. 
+ if gogoproto.IsCastKey(field) { + p.RecordTypeUse(m.KeyAliasField.GetTypeName()) + } + if gogoproto.IsCastValue(field) { + p.RecordTypeUse(m.ValueAliasField.GetTypeName()) + } + + nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) + if gogoproto.IsStdTime(field) || gogoproto.IsStdDuration(field) { + valuegoTyp = valuegoAliasTyp + } + + p.P(`if m.`, fieldname, ` == nil {`) + p.In() + p.P(`m.`, fieldname, ` = make(`, m.GoType, `)`) + p.Out() + p.P(`}`) + + p.declareMapField("mapkey", false, false, m.KeyAliasField) + p.declareMapField("mapvalue", nullable, gogoproto.IsCustomType(field), m.ValueAliasField) + p.P(`for iNdEx < postIndex {`) + p.In() + + p.P(`entryPreIndex := iNdEx`) + p.P(`var wire uint64`) + p.decodeVarint("wire", "uint64") + p.P(`fieldNum := int32(wire >> 3)`) + + p.P(`if fieldNum == 1 {`) + p.In() + p.mapField("mapkey", false, m.KeyAliasField) + p.Out() + p.P(`} else if fieldNum == 2 {`) + p.In() + p.mapField("mapvalue", gogoproto.IsCustomType(field), m.ValueAliasField) + p.Out() + p.P(`} else {`) + p.In() + p.P(`iNdEx = entryPreIndex`) + p.P(`skippy, err := skip`, p.localName, `(dAtA[iNdEx:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + p.P(`if skippy < 0 {`) + p.In() + p.P(`return ErrInvalidLength`, p.localName) + p.Out() + p.P(`}`) + p.P(`if (iNdEx + skippy) > postIndex {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + p.P(`iNdEx += skippy`) + p.Out() + p.P(`}`) + + p.Out() + p.P(`}`) + + s := `m.` + fieldname + if keygoTyp == keygoAliasTyp { + s += `[mapkey]` + } else { + s += `[` + keygoAliasTyp + `(mapkey)]` + } + + v := `mapvalue` + if (m.ValueField.IsMessage() || gogoproto.IsCustomType(field)) && !nullable { + v = `*` + v + } + if valuegoTyp != valuegoAliasTyp { + v = `((` + valuegoAliasTyp + `)(` + v + `))` + } + + p.P(s, ` = `, v) + } else if repeated { + if gogoproto.IsStdTime(field) { + if nullable { + 
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, new(time.Time))`) + } else { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, time.Time{})`) + } + } else if gogoproto.IsStdDuration(field) { + if nullable { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, new(time.Duration))`) + } else { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, time.Duration(0))`) + } + } else if nullable && !gogoproto.IsCustomType(field) { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, &`, msgname, `{})`) + } else { + goType, _ := p.GoType(nil, field) + // remove the slice from the type, i.e. []*T -> *T + goType = goType[2:] + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, goType, `{})`) + } + varName := `m.` + fieldname + `[len(m.` + fieldname + `)-1]` + buf := `dAtA[iNdEx:postIndex]` + if gogoproto.IsStdTime(field) { + if nullable { + p.P(`if err := `, p.typesPkg.Use(), `.StdTimeUnmarshal(`, varName, `,`, buf, `); err != nil {`) + } else { + p.P(`if err := `, p.typesPkg.Use(), `.StdTimeUnmarshal(&(`, varName, `),`, buf, `); err != nil {`) + } + } else if gogoproto.IsStdDuration(field) { + if nullable { + p.P(`if err := `, p.typesPkg.Use(), `.StdDurationUnmarshal(`, varName, `,`, buf, `); err != nil {`) + } else { + p.P(`if err := `, p.typesPkg.Use(), `.StdDurationUnmarshal(&(`, varName, `),`, buf, `); err != nil {`) + } + } else { + p.P(`if err := `, varName, `.Unmarshal(`, buf, `); err != nil {`) + } + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + } else if nullable { + p.P(`if m.`, fieldname, ` == nil {`) + p.In() + if gogoproto.IsStdTime(field) { + p.P(`m.`, fieldname, ` = new(time.Time)`) + } else if gogoproto.IsStdDuration(field) { + p.P(`m.`, fieldname, ` = new(time.Duration)`) + } else { + goType, _ := p.GoType(nil, field) + // remove the star from the type + p.P(`m.`, fieldname, ` = &`, goType[1:], `{}`) + } + p.Out() + p.P(`}`) + if gogoproto.IsStdTime(field) { + p.P(`if err := `, p.typesPkg.Use(), `.StdTimeUnmarshal(m.`, fieldname, `, 
dAtA[iNdEx:postIndex]); err != nil {`) + } else if gogoproto.IsStdDuration(field) { + p.P(`if err := `, p.typesPkg.Use(), `.StdDurationUnmarshal(m.`, fieldname, `, dAtA[iNdEx:postIndex]); err != nil {`) + } else { + p.P(`if err := m.`, fieldname, `.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`) + } + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + } else { + if gogoproto.IsStdTime(field) { + p.P(`if err := `, p.typesPkg.Use(), `.StdTimeUnmarshal(&m.`, fieldname, `, dAtA[iNdEx:postIndex]); err != nil {`) + } else if gogoproto.IsStdDuration(field) { + p.P(`if err := `, p.typesPkg.Use(), `.StdDurationUnmarshal(&m.`, fieldname, `, dAtA[iNdEx:postIndex]); err != nil {`) + } else { + p.P(`if err := m.`, fieldname, `.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`) + } + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + } + p.P(`iNdEx = postIndex`) + + case descriptor.FieldDescriptorProto_TYPE_BYTES: + p.P(`var byteLen int`) + p.decodeVarint("byteLen", "int") + p.P(`if byteLen < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`postIndex := iNdEx + byteLen`) + p.P(`if postIndex > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + if !gogoproto.IsCustomType(field) { + if oneof { + p.P(`v := make([]byte, postIndex-iNdEx)`) + p.P(`copy(v, dAtA[iNdEx:postIndex])`) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, make([]byte, postIndex-iNdEx))`) + p.P(`copy(m.`, fieldname, `[len(m.`, fieldname, `)-1], dAtA[iNdEx:postIndex])`) + } else { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `[:0] , dAtA[iNdEx:postIndex]...)`) + p.P(`if m.`, fieldname, ` == nil {`) + p.In() + p.P(`m.`, fieldname, ` = []byte{}`) + p.Out() + p.P(`}`) + } + } else { + _, ctyp, err := generator.GetCustomType(field) + if err != nil { + panic(err) + } + if oneof { + p.P(`var vv `, ctyp) + p.P(`v := &vv`) + p.P(`if err := 
v.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{*v}`) + } else if repeated { + p.P(`var v `, ctyp) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + p.P(`if err := m.`, fieldname, `[len(m.`, fieldname, `)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + } else if nullable { + p.P(`var v `, ctyp) + p.P(`m.`, fieldname, ` = &v`) + p.P(`if err := m.`, fieldname, `.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + } else { + p.P(`if err := m.`, fieldname, `.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + } + } + p.P(`iNdEx = postIndex`) + case descriptor.FieldDescriptorProto_TYPE_UINT32: + if oneof { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeVarint("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_ENUM: + typName := p.TypeName(p.ObjectNamed(field.GetTypeName())) + if oneof { + p.P(`var v `, typName) + p.decodeVarint("v", typName) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typName) + p.decodeVarint("v", typName) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeVarint("m."+fieldname, typName) + } else { + p.P(`var v `, typName) + p.decodeVarint("v", typName) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_SFIXED32: + if oneof { + 
p.P(`var v `, typ) + p.decodeFixed32("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeFixed32("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeFixed32("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeFixed32("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_SFIXED64: + if oneof { + p.P(`var v `, typ) + p.decodeFixed64("v", typ) + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`var v `, typ) + p.decodeFixed64("v", typ) + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = 0`) + p.decodeFixed64("m."+fieldname, typ) + } else { + p.P(`var v `, typ) + p.decodeFixed64("v", typ) + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_SINT32: + p.P(`var v `, typ) + p.decodeVarint("v", typ) + p.P(`v = `, typ, `((uint32(v) >> 1) ^ uint32(((v&1)<<31)>>31))`) + if oneof { + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{v}`) + } else if repeated { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = v`) + } else { + p.P(`m.`, fieldname, ` = &v`) + } + case descriptor.FieldDescriptorProto_TYPE_SINT64: + p.P(`var v uint64`) + p.decodeVarint("v", "uint64") + p.P(`v = (v >> 1) ^ uint64((int64(v&1)<<63)>>63)`) + if oneof { + p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{`, typ, `(v)}`) + } else if repeated { + p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(v))`) + } else if proto3 || !nullable { + p.P(`m.`, fieldname, ` = `, typ, `(v)`) + } else { + p.P(`v2 := `, typ, `(v)`) + p.P(`m.`, fieldname, ` = &v2`) + } + default: + panic("not implemented") + } +} + +func (p *unmarshal) Generate(file *generator.FileDescriptor) 
{ + proto3 := gogoproto.IsProto3(file.FileDescriptorProto) + p.PluginImports = generator.NewPluginImports(p.Generator) + p.atleastOne = false + p.localName = generator.FileName(file) + + p.ioPkg = p.NewImport("io") + p.mathPkg = p.NewImport("math") + p.typesPkg = p.NewImport("github.com/gogo/protobuf/types") + p.binaryPkg = p.NewImport("encoding/binary") + fmtPkg := p.NewImport("fmt") + protoPkg := p.NewImport("github.com/gogo/protobuf/proto") + if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + protoPkg = p.NewImport("github.com/golang/protobuf/proto") + } + + for _, message := range file.Messages() { + ccTypeName := generator.CamelCaseSlice(message.TypeName()) + if !gogoproto.IsUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) && + !gogoproto.IsUnsafeUnmarshaler(file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if message.DescriptorProto.GetOptions().GetMapEntry() { + continue + } + p.atleastOne = true + + // build a map required field_id -> bitmask offset + rfMap := make(map[int32]uint) + rfNextId := uint(0) + for _, field := range message.Field { + if field.IsRequired() { + rfMap[field.GetNumber()] = rfNextId + rfNextId++ + } + } + rfCount := len(rfMap) + + p.P(`func (m *`, ccTypeName, `) Unmarshal(dAtA []byte) error {`) + p.In() + if rfCount > 0 { + p.P(`var hasFields [`, strconv.Itoa(1+(rfCount-1)/64), `]uint64`) + } + p.P(`l := len(dAtA)`) + p.P(`iNdEx := 0`) + p.P(`for iNdEx < l {`) + p.In() + p.P(`preIndex := iNdEx`) + p.P(`var wire uint64`) + p.decodeVarint("wire", "uint64") + p.P(`fieldNum := int32(wire >> 3)`) + if len(message.Field) > 0 || !message.IsGroup() { + p.P(`wireType := int(wire & 0x7)`) + } + if !message.IsGroup() { + p.P(`if wireType == `, strconv.Itoa(proto.WireEndGroup), ` {`) + p.In() + p.P(`return `, fmtPkg.Use(), `.Errorf("proto: `+message.GetName()+`: wiretype end group for non-group")`) + p.Out() + p.P(`}`) + } + p.P(`if fieldNum <= 0 {`) + p.In() + p.P(`return `, fmtPkg.Use(), 
`.Errorf("proto: `+message.GetName()+`: illegal tag %d (wire type %d)", fieldNum, wire)`) + p.Out() + p.P(`}`) + p.P(`switch fieldNum {`) + p.In() + for _, field := range message.Field { + fieldname := p.GetFieldName(message, field) + errFieldname := fieldname + if field.OneofIndex != nil { + errFieldname = p.GetOneOfFieldName(message, field) + } + possiblyPacked := field.IsScalar() && field.IsRepeated() + p.P(`case `, strconv.Itoa(int(field.GetNumber())), `:`) + p.In() + wireType := field.WireType() + if possiblyPacked { + p.P(`if wireType == `, strconv.Itoa(wireType), `{`) + p.In() + p.field(file, message, field, fieldname, false) + p.Out() + p.P(`} else if wireType == `, strconv.Itoa(proto.WireBytes), `{`) + p.In() + p.P(`var packedLen int`) + p.decodeVarint("packedLen", "int") + p.P(`if packedLen < 0 {`) + p.In() + p.P(`return ErrInvalidLength` + p.localName) + p.Out() + p.P(`}`) + p.P(`postIndex := iNdEx + packedLen`) + p.P(`if postIndex > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + p.P(`for iNdEx < postIndex {`) + p.In() + p.field(file, message, field, fieldname, false) + p.Out() + p.P(`}`) + p.Out() + p.P(`} else {`) + p.In() + p.P(`return ` + fmtPkg.Use() + `.Errorf("proto: wrong wireType = %d for field ` + errFieldname + `", wireType)`) + p.Out() + p.P(`}`) + } else { + p.P(`if wireType != `, strconv.Itoa(wireType), `{`) + p.In() + p.P(`return ` + fmtPkg.Use() + `.Errorf("proto: wrong wireType = %d for field ` + errFieldname + `", wireType)`) + p.Out() + p.P(`}`) + p.field(file, message, field, fieldname, proto3) + } + + if field.IsRequired() { + fieldBit, ok := rfMap[field.GetNumber()] + if !ok { + panic("field is required, but no bit registered") + } + p.P(`hasFields[`, strconv.Itoa(int(fieldBit/64)), `] |= uint64(`, fmt.Sprintf("0x%08x", 1<<(fieldBit%64)), `)`) + } + } + p.Out() + p.P(`default:`) + p.In() + if message.DescriptorProto.HasExtension() { + c := []string{} + for _, erange := range 
message.GetExtensionRange() { + c = append(c, `((fieldNum >= `+strconv.Itoa(int(erange.GetStart()))+") && (fieldNum<"+strconv.Itoa(int(erange.GetEnd()))+`))`) + } + p.P(`if `, strings.Join(c, "||"), `{`) + p.In() + p.P(`var sizeOfWire int`) + p.P(`for {`) + p.In() + p.P(`sizeOfWire++`) + p.P(`wire >>= 7`) + p.P(`if wire == 0 {`) + p.In() + p.P(`break`) + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + p.P(`iNdEx-=sizeOfWire`) + p.P(`skippy, err := skip`, p.localName+`(dAtA[iNdEx:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + p.P(`if skippy < 0 {`) + p.In() + p.P(`return ErrInvalidLength`, p.localName) + p.Out() + p.P(`}`) + p.P(`if (iNdEx + skippy) > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + p.P(protoPkg.Use(), `.AppendExtension(m, int32(fieldNum), dAtA[iNdEx:iNdEx+skippy])`) + p.P(`iNdEx += skippy`) + p.Out() + p.P(`} else {`) + p.In() + } + p.P(`iNdEx=preIndex`) + p.P(`skippy, err := skip`, p.localName, `(dAtA[iNdEx:])`) + p.P(`if err != nil {`) + p.In() + p.P(`return err`) + p.Out() + p.P(`}`) + p.P(`if skippy < 0 {`) + p.In() + p.P(`return ErrInvalidLength`, p.localName) + p.Out() + p.P(`}`) + p.P(`if (iNdEx + skippy) > l {`) + p.In() + p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { + p.P(`m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...)`) + } + p.P(`iNdEx += skippy`) + p.Out() + if message.DescriptorProto.HasExtension() { + p.Out() + p.P(`}`) + } + p.Out() + p.P(`}`) + p.Out() + p.P(`}`) + + for _, field := range message.Field { + if !field.IsRequired() { + continue + } + + fieldBit, ok := rfMap[field.GetNumber()] + if !ok { + panic("field is required, but no bit registered") + } + + p.P(`if hasFields[`, strconv.Itoa(int(fieldBit/64)), `] & uint64(`, fmt.Sprintf("0x%08x", 1<<(fieldBit%64)), `) == 0 {`) + p.In() + if 
!gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + p.P(`return new(`, protoPkg.Use(), `.RequiredNotSetError)`) + } else { + p.P(`return `, protoPkg.Use(), `.NewRequiredNotSetError("`, field.GetName(), `")`) + } + p.Out() + p.P(`}`) + } + p.P() + p.P(`if iNdEx > l {`) + p.In() + p.P(`return ` + p.ioPkg.Use() + `.ErrUnexpectedEOF`) + p.Out() + p.P(`}`) + p.P(`return nil`) + p.Out() + p.P(`}`) + } + if !p.atleastOne { + return + } + + p.P(`func skip` + p.localName + `(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow` + p.localName + ` + } + if iNdEx >= l { + return 0, ` + p.ioPkg.Use() + `.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow` + p.localName + ` + } + if iNdEx >= l { + return 0, ` + p.ioPkg.Use() + `.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow` + p.localName + ` + } + if iNdEx >= l { + return 0, ` + p.ioPkg.Use() + `.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLength` + p.localName + ` + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow` + p.localName + ` + } + if iNdEx >= l { + return 0, ` + p.ioPkg.Use() + `.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := 
int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skip` + p.localName + `(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, ` + fmtPkg.Use() + `.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") + } + + var ( + ErrInvalidLength` + p.localName + ` = ` + fmtPkg.Use() + `.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflow` + p.localName + ` = ` + fmtPkg.Use() + `.Errorf("proto: integer overflow") + ) + `) +} + +func init() { + generator.RegisterPlugin(NewUnmarshal()) +} diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor.go new file mode 100644 index 00000000000..a85bf1984c6 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor.go @@ -0,0 +1,118 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2016 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package descriptor provides functions for obtaining protocol buffer +// descriptors for generated Go types. +// +// These functions cannot go in package proto because they depend on the +// generated protobuf descriptor messages, which themselves depend on proto. +package descriptor + +import ( + "bytes" + "compress/gzip" + "fmt" + "io/ioutil" + + "github.com/gogo/protobuf/proto" +) + +// extractFile extracts a FileDescriptorProto from a gzip'd buffer. +func extractFile(gz []byte) (*FileDescriptorProto, error) { + r, err := gzip.NewReader(bytes.NewReader(gz)) + if err != nil { + return nil, fmt.Errorf("failed to open gzip reader: %v", err) + } + defer r.Close() + + b, err := ioutil.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("failed to uncompress descriptor: %v", err) + } + + fd := new(FileDescriptorProto) + if err := proto.Unmarshal(b, fd); err != nil { + return nil, fmt.Errorf("malformed FileDescriptorProto: %v", err) + } + + return fd, nil +} + +// Message is a proto.Message with a method to return its descriptor. +// +// Message types generated by the protocol compiler always satisfy +// the Message interface. 
+type Message interface { + proto.Message + Descriptor() ([]byte, []int) +} + +// ForMessage returns a FileDescriptorProto and a DescriptorProto from within it +// describing the given message. +func ForMessage(msg Message) (fd *FileDescriptorProto, md *DescriptorProto) { + gz, path := msg.Descriptor() + fd, err := extractFile(gz) + if err != nil { + panic(fmt.Sprintf("invalid FileDescriptorProto for %T: %v", msg, err)) + } + + md = fd.MessageType[path[0]] + for _, i := range path[1:] { + md = md.NestedType[i] + } + return fd, md +} + +// Is this field a scalar numeric type? +func (field *FieldDescriptorProto) IsScalar() bool { + if field.Type == nil { + return false + } + switch *field.Type { + case FieldDescriptorProto_TYPE_DOUBLE, + FieldDescriptorProto_TYPE_FLOAT, + FieldDescriptorProto_TYPE_INT64, + FieldDescriptorProto_TYPE_UINT64, + FieldDescriptorProto_TYPE_INT32, + FieldDescriptorProto_TYPE_FIXED64, + FieldDescriptorProto_TYPE_FIXED32, + FieldDescriptorProto_TYPE_BOOL, + FieldDescriptorProto_TYPE_UINT32, + FieldDescriptorProto_TYPE_ENUM, + FieldDescriptorProto_TYPE_SFIXED32, + FieldDescriptorProto_TYPE_SFIXED64, + FieldDescriptorProto_TYPE_SINT32, + FieldDescriptorProto_TYPE_SINT64: + return true + default: + return false + } +} diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor.pb.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor.pb.go new file mode 100644 index 00000000000..4174cbd9f3d --- /dev/null +++ b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor.pb.go @@ -0,0 +1,2280 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: descriptor.proto + +/* +Package descriptor is a generated protocol buffer package. 
+ +It is generated from these files: + descriptor.proto + +It has these top-level messages: + FileDescriptorSet + FileDescriptorProto + DescriptorProto + ExtensionRangeOptions + FieldDescriptorProto + OneofDescriptorProto + EnumDescriptorProto + EnumValueDescriptorProto + ServiceDescriptorProto + MethodDescriptorProto + FileOptions + MessageOptions + FieldOptions + OneofOptions + EnumOptions + EnumValueOptions + ServiceOptions + MethodOptions + UninterpretedOption + SourceCodeInfo + GeneratedCodeInfo +*/ +package descriptor + +import proto "github.com/gogo/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +type FieldDescriptorProto_Type int32 + +const ( + // 0 is reserved for errors. + // Order is weird for historical reasons. + FieldDescriptorProto_TYPE_DOUBLE FieldDescriptorProto_Type = 1 + FieldDescriptorProto_TYPE_FLOAT FieldDescriptorProto_Type = 2 + // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT64 if + // negative values are likely. + FieldDescriptorProto_TYPE_INT64 FieldDescriptorProto_Type = 3 + FieldDescriptorProto_TYPE_UINT64 FieldDescriptorProto_Type = 4 + // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT32 if + // negative values are likely. 
+ FieldDescriptorProto_TYPE_INT32 FieldDescriptorProto_Type = 5 + FieldDescriptorProto_TYPE_FIXED64 FieldDescriptorProto_Type = 6 + FieldDescriptorProto_TYPE_FIXED32 FieldDescriptorProto_Type = 7 + FieldDescriptorProto_TYPE_BOOL FieldDescriptorProto_Type = 8 + FieldDescriptorProto_TYPE_STRING FieldDescriptorProto_Type = 9 + // Tag-delimited aggregate. + // Group type is deprecated and not supported in proto3. However, Proto3 + // implementations should still be able to parse the group wire format and + // treat group fields as unknown fields. + FieldDescriptorProto_TYPE_GROUP FieldDescriptorProto_Type = 10 + FieldDescriptorProto_TYPE_MESSAGE FieldDescriptorProto_Type = 11 + // New in version 2. + FieldDescriptorProto_TYPE_BYTES FieldDescriptorProto_Type = 12 + FieldDescriptorProto_TYPE_UINT32 FieldDescriptorProto_Type = 13 + FieldDescriptorProto_TYPE_ENUM FieldDescriptorProto_Type = 14 + FieldDescriptorProto_TYPE_SFIXED32 FieldDescriptorProto_Type = 15 + FieldDescriptorProto_TYPE_SFIXED64 FieldDescriptorProto_Type = 16 + FieldDescriptorProto_TYPE_SINT32 FieldDescriptorProto_Type = 17 + FieldDescriptorProto_TYPE_SINT64 FieldDescriptorProto_Type = 18 +) + +var FieldDescriptorProto_Type_name = map[int32]string{ + 1: "TYPE_DOUBLE", + 2: "TYPE_FLOAT", + 3: "TYPE_INT64", + 4: "TYPE_UINT64", + 5: "TYPE_INT32", + 6: "TYPE_FIXED64", + 7: "TYPE_FIXED32", + 8: "TYPE_BOOL", + 9: "TYPE_STRING", + 10: "TYPE_GROUP", + 11: "TYPE_MESSAGE", + 12: "TYPE_BYTES", + 13: "TYPE_UINT32", + 14: "TYPE_ENUM", + 15: "TYPE_SFIXED32", + 16: "TYPE_SFIXED64", + 17: "TYPE_SINT32", + 18: "TYPE_SINT64", +} +var FieldDescriptorProto_Type_value = map[string]int32{ + "TYPE_DOUBLE": 1, + "TYPE_FLOAT": 2, + "TYPE_INT64": 3, + "TYPE_UINT64": 4, + "TYPE_INT32": 5, + "TYPE_FIXED64": 6, + "TYPE_FIXED32": 7, + "TYPE_BOOL": 8, + "TYPE_STRING": 9, + "TYPE_GROUP": 10, + "TYPE_MESSAGE": 11, + "TYPE_BYTES": 12, + "TYPE_UINT32": 13, + "TYPE_ENUM": 14, + "TYPE_SFIXED32": 15, + "TYPE_SFIXED64": 16, + "TYPE_SINT32": 
17, + "TYPE_SINT64": 18, +} + +func (x FieldDescriptorProto_Type) Enum() *FieldDescriptorProto_Type { + p := new(FieldDescriptorProto_Type) + *p = x + return p +} +func (x FieldDescriptorProto_Type) String() string { + return proto.EnumName(FieldDescriptorProto_Type_name, int32(x)) +} +func (x *FieldDescriptorProto_Type) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FieldDescriptorProto_Type_value, data, "FieldDescriptorProto_Type") + if err != nil { + return err + } + *x = FieldDescriptorProto_Type(value) + return nil +} +func (FieldDescriptorProto_Type) EnumDescriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{4, 0} +} + +type FieldDescriptorProto_Label int32 + +const ( + // 0 is reserved for errors + FieldDescriptorProto_LABEL_OPTIONAL FieldDescriptorProto_Label = 1 + FieldDescriptorProto_LABEL_REQUIRED FieldDescriptorProto_Label = 2 + FieldDescriptorProto_LABEL_REPEATED FieldDescriptorProto_Label = 3 +) + +var FieldDescriptorProto_Label_name = map[int32]string{ + 1: "LABEL_OPTIONAL", + 2: "LABEL_REQUIRED", + 3: "LABEL_REPEATED", +} +var FieldDescriptorProto_Label_value = map[string]int32{ + "LABEL_OPTIONAL": 1, + "LABEL_REQUIRED": 2, + "LABEL_REPEATED": 3, +} + +func (x FieldDescriptorProto_Label) Enum() *FieldDescriptorProto_Label { + p := new(FieldDescriptorProto_Label) + *p = x + return p +} +func (x FieldDescriptorProto_Label) String() string { + return proto.EnumName(FieldDescriptorProto_Label_name, int32(x)) +} +func (x *FieldDescriptorProto_Label) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FieldDescriptorProto_Label_value, data, "FieldDescriptorProto_Label") + if err != nil { + return err + } + *x = FieldDescriptorProto_Label(value) + return nil +} +func (FieldDescriptorProto_Label) EnumDescriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{4, 1} +} + +// Generated classes can be optimized for speed or code size. 
+type FileOptions_OptimizeMode int32 + +const ( + FileOptions_SPEED FileOptions_OptimizeMode = 1 + // etc. + FileOptions_CODE_SIZE FileOptions_OptimizeMode = 2 + FileOptions_LITE_RUNTIME FileOptions_OptimizeMode = 3 +) + +var FileOptions_OptimizeMode_name = map[int32]string{ + 1: "SPEED", + 2: "CODE_SIZE", + 3: "LITE_RUNTIME", +} +var FileOptions_OptimizeMode_value = map[string]int32{ + "SPEED": 1, + "CODE_SIZE": 2, + "LITE_RUNTIME": 3, +} + +func (x FileOptions_OptimizeMode) Enum() *FileOptions_OptimizeMode { + p := new(FileOptions_OptimizeMode) + *p = x + return p +} +func (x FileOptions_OptimizeMode) String() string { + return proto.EnumName(FileOptions_OptimizeMode_name, int32(x)) +} +func (x *FileOptions_OptimizeMode) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FileOptions_OptimizeMode_value, data, "FileOptions_OptimizeMode") + if err != nil { + return err + } + *x = FileOptions_OptimizeMode(value) + return nil +} +func (FileOptions_OptimizeMode) EnumDescriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{10, 0} +} + +type FieldOptions_CType int32 + +const ( + // Default mode. 
+ FieldOptions_STRING FieldOptions_CType = 0 + FieldOptions_CORD FieldOptions_CType = 1 + FieldOptions_STRING_PIECE FieldOptions_CType = 2 +) + +var FieldOptions_CType_name = map[int32]string{ + 0: "STRING", + 1: "CORD", + 2: "STRING_PIECE", +} +var FieldOptions_CType_value = map[string]int32{ + "STRING": 0, + "CORD": 1, + "STRING_PIECE": 2, +} + +func (x FieldOptions_CType) Enum() *FieldOptions_CType { + p := new(FieldOptions_CType) + *p = x + return p +} +func (x FieldOptions_CType) String() string { + return proto.EnumName(FieldOptions_CType_name, int32(x)) +} +func (x *FieldOptions_CType) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FieldOptions_CType_value, data, "FieldOptions_CType") + if err != nil { + return err + } + *x = FieldOptions_CType(value) + return nil +} +func (FieldOptions_CType) EnumDescriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{12, 0} +} + +type FieldOptions_JSType int32 + +const ( + // Use the default type. + FieldOptions_JS_NORMAL FieldOptions_JSType = 0 + // Use JavaScript strings. + FieldOptions_JS_STRING FieldOptions_JSType = 1 + // Use JavaScript numbers. 
+ FieldOptions_JS_NUMBER FieldOptions_JSType = 2 +) + +var FieldOptions_JSType_name = map[int32]string{ + 0: "JS_NORMAL", + 1: "JS_STRING", + 2: "JS_NUMBER", +} +var FieldOptions_JSType_value = map[string]int32{ + "JS_NORMAL": 0, + "JS_STRING": 1, + "JS_NUMBER": 2, +} + +func (x FieldOptions_JSType) Enum() *FieldOptions_JSType { + p := new(FieldOptions_JSType) + *p = x + return p +} +func (x FieldOptions_JSType) String() string { + return proto.EnumName(FieldOptions_JSType_name, int32(x)) +} +func (x *FieldOptions_JSType) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FieldOptions_JSType_value, data, "FieldOptions_JSType") + if err != nil { + return err + } + *x = FieldOptions_JSType(value) + return nil +} +func (FieldOptions_JSType) EnumDescriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{12, 1} +} + +// Is this method side-effect-free (or safe in HTTP parlance), or idempotent, +// or neither? HTTP based RPC implementation may choose GET verb for safe +// methods, and PUT verb for idempotent methods instead of the default POST. 
+type MethodOptions_IdempotencyLevel int32 + +const ( + MethodOptions_IDEMPOTENCY_UNKNOWN MethodOptions_IdempotencyLevel = 0 + MethodOptions_NO_SIDE_EFFECTS MethodOptions_IdempotencyLevel = 1 + MethodOptions_IDEMPOTENT MethodOptions_IdempotencyLevel = 2 +) + +var MethodOptions_IdempotencyLevel_name = map[int32]string{ + 0: "IDEMPOTENCY_UNKNOWN", + 1: "NO_SIDE_EFFECTS", + 2: "IDEMPOTENT", +} +var MethodOptions_IdempotencyLevel_value = map[string]int32{ + "IDEMPOTENCY_UNKNOWN": 0, + "NO_SIDE_EFFECTS": 1, + "IDEMPOTENT": 2, +} + +func (x MethodOptions_IdempotencyLevel) Enum() *MethodOptions_IdempotencyLevel { + p := new(MethodOptions_IdempotencyLevel) + *p = x + return p +} +func (x MethodOptions_IdempotencyLevel) String() string { + return proto.EnumName(MethodOptions_IdempotencyLevel_name, int32(x)) +} +func (x *MethodOptions_IdempotencyLevel) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(MethodOptions_IdempotencyLevel_value, data, "MethodOptions_IdempotencyLevel") + if err != nil { + return err + } + *x = MethodOptions_IdempotencyLevel(value) + return nil +} +func (MethodOptions_IdempotencyLevel) EnumDescriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{17, 0} +} + +// The protocol compiler can output a FileDescriptorSet containing the .proto +// files it parses. +type FileDescriptorSet struct { + File []*FileDescriptorProto `protobuf:"bytes,1,rep,name=file" json:"file,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *FileDescriptorSet) Reset() { *m = FileDescriptorSet{} } +func (m *FileDescriptorSet) String() string { return proto.CompactTextString(m) } +func (*FileDescriptorSet) ProtoMessage() {} +func (*FileDescriptorSet) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{0} } + +func (m *FileDescriptorSet) GetFile() []*FileDescriptorProto { + if m != nil { + return m.File + } + return nil +} + +// Describes a complete .proto file. 
+type FileDescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Package *string `protobuf:"bytes,2,opt,name=package" json:"package,omitempty"` + // Names of files imported by this file. + Dependency []string `protobuf:"bytes,3,rep,name=dependency" json:"dependency,omitempty"` + // Indexes of the public imported files in the dependency list above. + PublicDependency []int32 `protobuf:"varint,10,rep,name=public_dependency,json=publicDependency" json:"public_dependency,omitempty"` + // Indexes of the weak imported files in the dependency list. + // For Google-internal migration only. Do not use. + WeakDependency []int32 `protobuf:"varint,11,rep,name=weak_dependency,json=weakDependency" json:"weak_dependency,omitempty"` + // All top-level definitions in this file. + MessageType []*DescriptorProto `protobuf:"bytes,4,rep,name=message_type,json=messageType" json:"message_type,omitempty"` + EnumType []*EnumDescriptorProto `protobuf:"bytes,5,rep,name=enum_type,json=enumType" json:"enum_type,omitempty"` + Service []*ServiceDescriptorProto `protobuf:"bytes,6,rep,name=service" json:"service,omitempty"` + Extension []*FieldDescriptorProto `protobuf:"bytes,7,rep,name=extension" json:"extension,omitempty"` + Options *FileOptions `protobuf:"bytes,8,opt,name=options" json:"options,omitempty"` + // This field contains optional information about the original source code. + // You may safely remove this entire field without harming runtime + // functionality of the descriptors -- the information is needed only by + // development tools. + SourceCodeInfo *SourceCodeInfo `protobuf:"bytes,9,opt,name=source_code_info,json=sourceCodeInfo" json:"source_code_info,omitempty"` + // The syntax of the proto file. + // The supported values are "proto2" and "proto3". 
+ Syntax *string `protobuf:"bytes,12,opt,name=syntax" json:"syntax,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *FileDescriptorProto) Reset() { *m = FileDescriptorProto{} } +func (m *FileDescriptorProto) String() string { return proto.CompactTextString(m) } +func (*FileDescriptorProto) ProtoMessage() {} +func (*FileDescriptorProto) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{1} } + +func (m *FileDescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *FileDescriptorProto) GetPackage() string { + if m != nil && m.Package != nil { + return *m.Package + } + return "" +} + +func (m *FileDescriptorProto) GetDependency() []string { + if m != nil { + return m.Dependency + } + return nil +} + +func (m *FileDescriptorProto) GetPublicDependency() []int32 { + if m != nil { + return m.PublicDependency + } + return nil +} + +func (m *FileDescriptorProto) GetWeakDependency() []int32 { + if m != nil { + return m.WeakDependency + } + return nil +} + +func (m *FileDescriptorProto) GetMessageType() []*DescriptorProto { + if m != nil { + return m.MessageType + } + return nil +} + +func (m *FileDescriptorProto) GetEnumType() []*EnumDescriptorProto { + if m != nil { + return m.EnumType + } + return nil +} + +func (m *FileDescriptorProto) GetService() []*ServiceDescriptorProto { + if m != nil { + return m.Service + } + return nil +} + +func (m *FileDescriptorProto) GetExtension() []*FieldDescriptorProto { + if m != nil { + return m.Extension + } + return nil +} + +func (m *FileDescriptorProto) GetOptions() *FileOptions { + if m != nil { + return m.Options + } + return nil +} + +func (m *FileDescriptorProto) GetSourceCodeInfo() *SourceCodeInfo { + if m != nil { + return m.SourceCodeInfo + } + return nil +} + +func (m *FileDescriptorProto) GetSyntax() string { + if m != nil && m.Syntax != nil { + return *m.Syntax + } + return "" +} + +// Describes a message type. 
+type DescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Field []*FieldDescriptorProto `protobuf:"bytes,2,rep,name=field" json:"field,omitempty"` + Extension []*FieldDescriptorProto `protobuf:"bytes,6,rep,name=extension" json:"extension,omitempty"` + NestedType []*DescriptorProto `protobuf:"bytes,3,rep,name=nested_type,json=nestedType" json:"nested_type,omitempty"` + EnumType []*EnumDescriptorProto `protobuf:"bytes,4,rep,name=enum_type,json=enumType" json:"enum_type,omitempty"` + ExtensionRange []*DescriptorProto_ExtensionRange `protobuf:"bytes,5,rep,name=extension_range,json=extensionRange" json:"extension_range,omitempty"` + OneofDecl []*OneofDescriptorProto `protobuf:"bytes,8,rep,name=oneof_decl,json=oneofDecl" json:"oneof_decl,omitempty"` + Options *MessageOptions `protobuf:"bytes,7,opt,name=options" json:"options,omitempty"` + ReservedRange []*DescriptorProto_ReservedRange `protobuf:"bytes,9,rep,name=reserved_range,json=reservedRange" json:"reserved_range,omitempty"` + // Reserved field names, which may not be used by fields in the same message. + // A given name may only be reserved once. 
+ ReservedName []string `protobuf:"bytes,10,rep,name=reserved_name,json=reservedName" json:"reserved_name,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *DescriptorProto) Reset() { *m = DescriptorProto{} } +func (m *DescriptorProto) String() string { return proto.CompactTextString(m) } +func (*DescriptorProto) ProtoMessage() {} +func (*DescriptorProto) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{2} } + +func (m *DescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *DescriptorProto) GetField() []*FieldDescriptorProto { + if m != nil { + return m.Field + } + return nil +} + +func (m *DescriptorProto) GetExtension() []*FieldDescriptorProto { + if m != nil { + return m.Extension + } + return nil +} + +func (m *DescriptorProto) GetNestedType() []*DescriptorProto { + if m != nil { + return m.NestedType + } + return nil +} + +func (m *DescriptorProto) GetEnumType() []*EnumDescriptorProto { + if m != nil { + return m.EnumType + } + return nil +} + +func (m *DescriptorProto) GetExtensionRange() []*DescriptorProto_ExtensionRange { + if m != nil { + return m.ExtensionRange + } + return nil +} + +func (m *DescriptorProto) GetOneofDecl() []*OneofDescriptorProto { + if m != nil { + return m.OneofDecl + } + return nil +} + +func (m *DescriptorProto) GetOptions() *MessageOptions { + if m != nil { + return m.Options + } + return nil +} + +func (m *DescriptorProto) GetReservedRange() []*DescriptorProto_ReservedRange { + if m != nil { + return m.ReservedRange + } + return nil +} + +func (m *DescriptorProto) GetReservedName() []string { + if m != nil { + return m.ReservedName + } + return nil +} + +type DescriptorProto_ExtensionRange struct { + Start *int32 `protobuf:"varint,1,opt,name=start" json:"start,omitempty"` + End *int32 `protobuf:"varint,2,opt,name=end" json:"end,omitempty"` + Options *ExtensionRangeOptions `protobuf:"bytes,3,opt,name=options" json:"options,omitempty"` + 
XXX_unrecognized []byte `json:"-"` +} + +func (m *DescriptorProto_ExtensionRange) Reset() { *m = DescriptorProto_ExtensionRange{} } +func (m *DescriptorProto_ExtensionRange) String() string { return proto.CompactTextString(m) } +func (*DescriptorProto_ExtensionRange) ProtoMessage() {} +func (*DescriptorProto_ExtensionRange) Descriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{2, 0} +} + +func (m *DescriptorProto_ExtensionRange) GetStart() int32 { + if m != nil && m.Start != nil { + return *m.Start + } + return 0 +} + +func (m *DescriptorProto_ExtensionRange) GetEnd() int32 { + if m != nil && m.End != nil { + return *m.End + } + return 0 +} + +func (m *DescriptorProto_ExtensionRange) GetOptions() *ExtensionRangeOptions { + if m != nil { + return m.Options + } + return nil +} + +// Range of reserved tag numbers. Reserved tag numbers may not be used by +// fields or extension ranges in the same message. Reserved ranges may +// not overlap. +type DescriptorProto_ReservedRange struct { + Start *int32 `protobuf:"varint,1,opt,name=start" json:"start,omitempty"` + End *int32 `protobuf:"varint,2,opt,name=end" json:"end,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *DescriptorProto_ReservedRange) Reset() { *m = DescriptorProto_ReservedRange{} } +func (m *DescriptorProto_ReservedRange) String() string { return proto.CompactTextString(m) } +func (*DescriptorProto_ReservedRange) ProtoMessage() {} +func (*DescriptorProto_ReservedRange) Descriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{2, 1} +} + +func (m *DescriptorProto_ReservedRange) GetStart() int32 { + if m != nil && m.Start != nil { + return *m.Start + } + return 0 +} + +func (m *DescriptorProto_ReservedRange) GetEnd() int32 { + if m != nil && m.End != nil { + return *m.End + } + return 0 +} + +type ExtensionRangeOptions struct { + // The parser stores options it doesn't recognize here. See above. 
+ UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *ExtensionRangeOptions) Reset() { *m = ExtensionRangeOptions{} } +func (m *ExtensionRangeOptions) String() string { return proto.CompactTextString(m) } +func (*ExtensionRangeOptions) ProtoMessage() {} +func (*ExtensionRangeOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{3} } + +var extRange_ExtensionRangeOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*ExtensionRangeOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_ExtensionRangeOptions +} + +func (m *ExtensionRangeOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +// Describes a field within a message. +type FieldDescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Number *int32 `protobuf:"varint,3,opt,name=number" json:"number,omitempty"` + Label *FieldDescriptorProto_Label `protobuf:"varint,4,opt,name=label,enum=google.protobuf.FieldDescriptorProto_Label" json:"label,omitempty"` + // If type_name is set, this need not be set. If both this and type_name + // are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP. + Type *FieldDescriptorProto_Type `protobuf:"varint,5,opt,name=type,enum=google.protobuf.FieldDescriptorProto_Type" json:"type,omitempty"` + // For message and enum types, this is the name of the type. If the name + // starts with a '.', it is fully-qualified. Otherwise, C++-like scoping + // rules are used to find the type (i.e. first the nested types within this + // message are searched, then within the parent, on up to the root + // namespace). 
+ TypeName *string `protobuf:"bytes,6,opt,name=type_name,json=typeName" json:"type_name,omitempty"` + // For extensions, this is the name of the type being extended. It is + // resolved in the same manner as type_name. + Extendee *string `protobuf:"bytes,2,opt,name=extendee" json:"extendee,omitempty"` + // For numeric types, contains the original text representation of the value. + // For booleans, "true" or "false". + // For strings, contains the default text contents (not escaped in any way). + // For bytes, contains the C escaped value. All bytes >= 128 are escaped. + // TODO(kenton): Base-64 encode? + DefaultValue *string `protobuf:"bytes,7,opt,name=default_value,json=defaultValue" json:"default_value,omitempty"` + // If set, gives the index of a oneof in the containing type's oneof_decl + // list. This field is a member of that oneof. + OneofIndex *int32 `protobuf:"varint,9,opt,name=oneof_index,json=oneofIndex" json:"oneof_index,omitempty"` + // JSON name of this field. The value is set by protocol compiler. If the + // user has set a "json_name" option on this field, that option's value + // will be used. Otherwise, it's deduced from the field's name by converting + // it to camelCase. 
+ JsonName *string `protobuf:"bytes,10,opt,name=json_name,json=jsonName" json:"json_name,omitempty"` + Options *FieldOptions `protobuf:"bytes,8,opt,name=options" json:"options,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *FieldDescriptorProto) Reset() { *m = FieldDescriptorProto{} } +func (m *FieldDescriptorProto) String() string { return proto.CompactTextString(m) } +func (*FieldDescriptorProto) ProtoMessage() {} +func (*FieldDescriptorProto) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{4} } + +func (m *FieldDescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *FieldDescriptorProto) GetNumber() int32 { + if m != nil && m.Number != nil { + return *m.Number + } + return 0 +} + +func (m *FieldDescriptorProto) GetLabel() FieldDescriptorProto_Label { + if m != nil && m.Label != nil { + return *m.Label + } + return FieldDescriptorProto_LABEL_OPTIONAL +} + +func (m *FieldDescriptorProto) GetType() FieldDescriptorProto_Type { + if m != nil && m.Type != nil { + return *m.Type + } + return FieldDescriptorProto_TYPE_DOUBLE +} + +func (m *FieldDescriptorProto) GetTypeName() string { + if m != nil && m.TypeName != nil { + return *m.TypeName + } + return "" +} + +func (m *FieldDescriptorProto) GetExtendee() string { + if m != nil && m.Extendee != nil { + return *m.Extendee + } + return "" +} + +func (m *FieldDescriptorProto) GetDefaultValue() string { + if m != nil && m.DefaultValue != nil { + return *m.DefaultValue + } + return "" +} + +func (m *FieldDescriptorProto) GetOneofIndex() int32 { + if m != nil && m.OneofIndex != nil { + return *m.OneofIndex + } + return 0 +} + +func (m *FieldDescriptorProto) GetJsonName() string { + if m != nil && m.JsonName != nil { + return *m.JsonName + } + return "" +} + +func (m *FieldDescriptorProto) GetOptions() *FieldOptions { + if m != nil { + return m.Options + } + return nil +} + +// Describes a oneof. 
+type OneofDescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Options *OneofOptions `protobuf:"bytes,2,opt,name=options" json:"options,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *OneofDescriptorProto) Reset() { *m = OneofDescriptorProto{} } +func (m *OneofDescriptorProto) String() string { return proto.CompactTextString(m) } +func (*OneofDescriptorProto) ProtoMessage() {} +func (*OneofDescriptorProto) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{5} } + +func (m *OneofDescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *OneofDescriptorProto) GetOptions() *OneofOptions { + if m != nil { + return m.Options + } + return nil +} + +// Describes an enum type. +type EnumDescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Value []*EnumValueDescriptorProto `protobuf:"bytes,2,rep,name=value" json:"value,omitempty"` + Options *EnumOptions `protobuf:"bytes,3,opt,name=options" json:"options,omitempty"` + // Range of reserved numeric values. Reserved numeric values may not be used + // by enum values in the same enum declaration. Reserved ranges may not + // overlap. + ReservedRange []*EnumDescriptorProto_EnumReservedRange `protobuf:"bytes,4,rep,name=reserved_range,json=reservedRange" json:"reserved_range,omitempty"` + // Reserved enum value names, which may not be reused. A given name may only + // be reserved once. 
+ ReservedName []string `protobuf:"bytes,5,rep,name=reserved_name,json=reservedName" json:"reserved_name,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *EnumDescriptorProto) Reset() { *m = EnumDescriptorProto{} } +func (m *EnumDescriptorProto) String() string { return proto.CompactTextString(m) } +func (*EnumDescriptorProto) ProtoMessage() {} +func (*EnumDescriptorProto) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{6} } + +func (m *EnumDescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *EnumDescriptorProto) GetValue() []*EnumValueDescriptorProto { + if m != nil { + return m.Value + } + return nil +} + +func (m *EnumDescriptorProto) GetOptions() *EnumOptions { + if m != nil { + return m.Options + } + return nil +} + +func (m *EnumDescriptorProto) GetReservedRange() []*EnumDescriptorProto_EnumReservedRange { + if m != nil { + return m.ReservedRange + } + return nil +} + +func (m *EnumDescriptorProto) GetReservedName() []string { + if m != nil { + return m.ReservedName + } + return nil +} + +// Range of reserved numeric values. Reserved values may not be used by +// entries in the same enum. Reserved ranges may not overlap. +// +// Note that this is distinct from DescriptorProto.ReservedRange in that it +// is inclusive such that it can appropriately represent the entire int32 +// domain. 
+type EnumDescriptorProto_EnumReservedRange struct { + Start *int32 `protobuf:"varint,1,opt,name=start" json:"start,omitempty"` + End *int32 `protobuf:"varint,2,opt,name=end" json:"end,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *EnumDescriptorProto_EnumReservedRange) Reset() { *m = EnumDescriptorProto_EnumReservedRange{} } +func (m *EnumDescriptorProto_EnumReservedRange) String() string { return proto.CompactTextString(m) } +func (*EnumDescriptorProto_EnumReservedRange) ProtoMessage() {} +func (*EnumDescriptorProto_EnumReservedRange) Descriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{6, 0} +} + +func (m *EnumDescriptorProto_EnumReservedRange) GetStart() int32 { + if m != nil && m.Start != nil { + return *m.Start + } + return 0 +} + +func (m *EnumDescriptorProto_EnumReservedRange) GetEnd() int32 { + if m != nil && m.End != nil { + return *m.End + } + return 0 +} + +// Describes a value within an enum. +type EnumValueDescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Number *int32 `protobuf:"varint,2,opt,name=number" json:"number,omitempty"` + Options *EnumValueOptions `protobuf:"bytes,3,opt,name=options" json:"options,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *EnumValueDescriptorProto) Reset() { *m = EnumValueDescriptorProto{} } +func (m *EnumValueDescriptorProto) String() string { return proto.CompactTextString(m) } +func (*EnumValueDescriptorProto) ProtoMessage() {} +func (*EnumValueDescriptorProto) Descriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{7} +} + +func (m *EnumValueDescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *EnumValueDescriptorProto) GetNumber() int32 { + if m != nil && m.Number != nil { + return *m.Number + } + return 0 +} + +func (m *EnumValueDescriptorProto) GetOptions() *EnumValueOptions { + if m != nil { + return m.Options + } + return nil +} + +// 
Describes a service. +type ServiceDescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Method []*MethodDescriptorProto `protobuf:"bytes,2,rep,name=method" json:"method,omitempty"` + Options *ServiceOptions `protobuf:"bytes,3,opt,name=options" json:"options,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *ServiceDescriptorProto) Reset() { *m = ServiceDescriptorProto{} } +func (m *ServiceDescriptorProto) String() string { return proto.CompactTextString(m) } +func (*ServiceDescriptorProto) ProtoMessage() {} +func (*ServiceDescriptorProto) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{8} } + +func (m *ServiceDescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *ServiceDescriptorProto) GetMethod() []*MethodDescriptorProto { + if m != nil { + return m.Method + } + return nil +} + +func (m *ServiceDescriptorProto) GetOptions() *ServiceOptions { + if m != nil { + return m.Options + } + return nil +} + +// Describes a method of a service. +type MethodDescriptorProto struct { + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + // Input and output type names. These are resolved in the same way as + // FieldDescriptorProto.type_name, but must refer to a message type. 
+ InputType *string `protobuf:"bytes,2,opt,name=input_type,json=inputType" json:"input_type,omitempty"` + OutputType *string `protobuf:"bytes,3,opt,name=output_type,json=outputType" json:"output_type,omitempty"` + Options *MethodOptions `protobuf:"bytes,4,opt,name=options" json:"options,omitempty"` + // Identifies if client streams multiple client messages + ClientStreaming *bool `protobuf:"varint,5,opt,name=client_streaming,json=clientStreaming,def=0" json:"client_streaming,omitempty"` + // Identifies if server streams multiple server messages + ServerStreaming *bool `protobuf:"varint,6,opt,name=server_streaming,json=serverStreaming,def=0" json:"server_streaming,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MethodDescriptorProto) Reset() { *m = MethodDescriptorProto{} } +func (m *MethodDescriptorProto) String() string { return proto.CompactTextString(m) } +func (*MethodDescriptorProto) ProtoMessage() {} +func (*MethodDescriptorProto) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{9} } + +const Default_MethodDescriptorProto_ClientStreaming bool = false +const Default_MethodDescriptorProto_ServerStreaming bool = false + +func (m *MethodDescriptorProto) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *MethodDescriptorProto) GetInputType() string { + if m != nil && m.InputType != nil { + return *m.InputType + } + return "" +} + +func (m *MethodDescriptorProto) GetOutputType() string { + if m != nil && m.OutputType != nil { + return *m.OutputType + } + return "" +} + +func (m *MethodDescriptorProto) GetOptions() *MethodOptions { + if m != nil { + return m.Options + } + return nil +} + +func (m *MethodDescriptorProto) GetClientStreaming() bool { + if m != nil && m.ClientStreaming != nil { + return *m.ClientStreaming + } + return Default_MethodDescriptorProto_ClientStreaming +} + +func (m *MethodDescriptorProto) GetServerStreaming() bool { + if m != nil && m.ServerStreaming != nil 
{ + return *m.ServerStreaming + } + return Default_MethodDescriptorProto_ServerStreaming +} + +type FileOptions struct { + // Sets the Java package where classes generated from this .proto will be + // placed. By default, the proto package is used, but this is often + // inappropriate because proto packages do not normally start with backwards + // domain names. + JavaPackage *string `protobuf:"bytes,1,opt,name=java_package,json=javaPackage" json:"java_package,omitempty"` + // If set, all the classes from the .proto file are wrapped in a single + // outer class with the given name. This applies to both Proto1 + // (equivalent to the old "--one_java_file" option) and Proto2 (where + // a .proto always translates to a single class, but you may want to + // explicitly choose the class name). + JavaOuterClassname *string `protobuf:"bytes,8,opt,name=java_outer_classname,json=javaOuterClassname" json:"java_outer_classname,omitempty"` + // If set true, then the Java code generator will generate a separate .java + // file for each top-level message, enum, and service defined in the .proto + // file. Thus, these types will *not* be nested inside the outer class + // named by java_outer_classname. However, the outer class will still be + // generated to contain the file's getDescriptor() method as well as any + // top-level extensions defined in the file. + JavaMultipleFiles *bool `protobuf:"varint,10,opt,name=java_multiple_files,json=javaMultipleFiles,def=0" json:"java_multiple_files,omitempty"` + // This option does nothing. + JavaGenerateEqualsAndHash *bool `protobuf:"varint,20,opt,name=java_generate_equals_and_hash,json=javaGenerateEqualsAndHash" json:"java_generate_equals_and_hash,omitempty"` + // If set true, then the Java2 code generator will generate code that + // throws an exception whenever an attempt is made to assign a non-UTF-8 + // byte sequence to a string field. + // Message reflection will do the same. 
+ // However, an extension field still accepts non-UTF-8 byte sequences. + // This option has no effect on when used with the lite runtime. + JavaStringCheckUtf8 *bool `protobuf:"varint,27,opt,name=java_string_check_utf8,json=javaStringCheckUtf8,def=0" json:"java_string_check_utf8,omitempty"` + OptimizeFor *FileOptions_OptimizeMode `protobuf:"varint,9,opt,name=optimize_for,json=optimizeFor,enum=google.protobuf.FileOptions_OptimizeMode,def=1" json:"optimize_for,omitempty"` + // Sets the Go package where structs generated from this .proto will be + // placed. If omitted, the Go package will be derived from the following: + // - The basename of the package import path, if provided. + // - Otherwise, the package statement in the .proto file, if present. + // - Otherwise, the basename of the .proto file, without extension. + GoPackage *string `protobuf:"bytes,11,opt,name=go_package,json=goPackage" json:"go_package,omitempty"` + // Should generic services be generated in each language? "Generic" services + // are not specific to any particular RPC system. They are generated by the + // main code generators in each language (without additional plugins). + // Generic services were the only kind of service generation supported by + // early versions of google.protobuf. + // + // Generic services are now considered deprecated in favor of using plugins + // that generate code specific to your particular RPC system. Therefore, + // these default to false. Old code which depends on generic services should + // explicitly set them to true. 
+ CcGenericServices *bool `protobuf:"varint,16,opt,name=cc_generic_services,json=ccGenericServices,def=0" json:"cc_generic_services,omitempty"` + JavaGenericServices *bool `protobuf:"varint,17,opt,name=java_generic_services,json=javaGenericServices,def=0" json:"java_generic_services,omitempty"` + PyGenericServices *bool `protobuf:"varint,18,opt,name=py_generic_services,json=pyGenericServices,def=0" json:"py_generic_services,omitempty"` + PhpGenericServices *bool `protobuf:"varint,42,opt,name=php_generic_services,json=phpGenericServices,def=0" json:"php_generic_services,omitempty"` + // Is this file deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for everything in the file, or it will be completely ignored; in the very + // least, this is a formalization for deprecating files. + Deprecated *bool `protobuf:"varint,23,opt,name=deprecated,def=0" json:"deprecated,omitempty"` + // Enables the use of arenas for the proto messages in this file. This applies + // only to generated classes for C++. + CcEnableArenas *bool `protobuf:"varint,31,opt,name=cc_enable_arenas,json=ccEnableArenas,def=0" json:"cc_enable_arenas,omitempty"` + // Sets the objective c class prefix which is prepended to all objective c + // generated classes from this .proto. There is no default. + ObjcClassPrefix *string `protobuf:"bytes,36,opt,name=objc_class_prefix,json=objcClassPrefix" json:"objc_class_prefix,omitempty"` + // Namespace for generated classes; defaults to the package. + CsharpNamespace *string `protobuf:"bytes,37,opt,name=csharp_namespace,json=csharpNamespace" json:"csharp_namespace,omitempty"` + // By default Swift generators will take the proto package and CamelCase it + // replacing '.' with underscore and use that to prefix the types/symbols + // defined. When this options is provided, they will use this value instead + // to prefix the types/symbols defined. 
+ SwiftPrefix *string `protobuf:"bytes,39,opt,name=swift_prefix,json=swiftPrefix" json:"swift_prefix,omitempty"` + // Sets the php class prefix which is prepended to all php generated classes + // from this .proto. Default is empty. + PhpClassPrefix *string `protobuf:"bytes,40,opt,name=php_class_prefix,json=phpClassPrefix" json:"php_class_prefix,omitempty"` + // Use this option to change the namespace of php generated classes. Default + // is empty. When this option is empty, the package name will be used for + // determining the namespace. + PhpNamespace *string `protobuf:"bytes,41,opt,name=php_namespace,json=phpNamespace" json:"php_namespace,omitempty"` + // The parser stores options it doesn't recognize here. See above. + UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *FileOptions) Reset() { *m = FileOptions{} } +func (m *FileOptions) String() string { return proto.CompactTextString(m) } +func (*FileOptions) ProtoMessage() {} +func (*FileOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{10} } + +var extRange_FileOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*FileOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_FileOptions +} + +const Default_FileOptions_JavaMultipleFiles bool = false +const Default_FileOptions_JavaStringCheckUtf8 bool = false +const Default_FileOptions_OptimizeFor FileOptions_OptimizeMode = FileOptions_SPEED +const Default_FileOptions_CcGenericServices bool = false +const Default_FileOptions_JavaGenericServices bool = false +const Default_FileOptions_PyGenericServices bool = false +const Default_FileOptions_PhpGenericServices bool = false +const Default_FileOptions_Deprecated bool = false +const Default_FileOptions_CcEnableArenas bool = false + +func (m 
*FileOptions) GetJavaPackage() string { + if m != nil && m.JavaPackage != nil { + return *m.JavaPackage + } + return "" +} + +func (m *FileOptions) GetJavaOuterClassname() string { + if m != nil && m.JavaOuterClassname != nil { + return *m.JavaOuterClassname + } + return "" +} + +func (m *FileOptions) GetJavaMultipleFiles() bool { + if m != nil && m.JavaMultipleFiles != nil { + return *m.JavaMultipleFiles + } + return Default_FileOptions_JavaMultipleFiles +} + +func (m *FileOptions) GetJavaGenerateEqualsAndHash() bool { + if m != nil && m.JavaGenerateEqualsAndHash != nil { + return *m.JavaGenerateEqualsAndHash + } + return false +} + +func (m *FileOptions) GetJavaStringCheckUtf8() bool { + if m != nil && m.JavaStringCheckUtf8 != nil { + return *m.JavaStringCheckUtf8 + } + return Default_FileOptions_JavaStringCheckUtf8 +} + +func (m *FileOptions) GetOptimizeFor() FileOptions_OptimizeMode { + if m != nil && m.OptimizeFor != nil { + return *m.OptimizeFor + } + return Default_FileOptions_OptimizeFor +} + +func (m *FileOptions) GetGoPackage() string { + if m != nil && m.GoPackage != nil { + return *m.GoPackage + } + return "" +} + +func (m *FileOptions) GetCcGenericServices() bool { + if m != nil && m.CcGenericServices != nil { + return *m.CcGenericServices + } + return Default_FileOptions_CcGenericServices +} + +func (m *FileOptions) GetJavaGenericServices() bool { + if m != nil && m.JavaGenericServices != nil { + return *m.JavaGenericServices + } + return Default_FileOptions_JavaGenericServices +} + +func (m *FileOptions) GetPyGenericServices() bool { + if m != nil && m.PyGenericServices != nil { + return *m.PyGenericServices + } + return Default_FileOptions_PyGenericServices +} + +func (m *FileOptions) GetPhpGenericServices() bool { + if m != nil && m.PhpGenericServices != nil { + return *m.PhpGenericServices + } + return Default_FileOptions_PhpGenericServices +} + +func (m *FileOptions) GetDeprecated() bool { + if m != nil && m.Deprecated != nil { + return 
*m.Deprecated + } + return Default_FileOptions_Deprecated +} + +func (m *FileOptions) GetCcEnableArenas() bool { + if m != nil && m.CcEnableArenas != nil { + return *m.CcEnableArenas + } + return Default_FileOptions_CcEnableArenas +} + +func (m *FileOptions) GetObjcClassPrefix() string { + if m != nil && m.ObjcClassPrefix != nil { + return *m.ObjcClassPrefix + } + return "" +} + +func (m *FileOptions) GetCsharpNamespace() string { + if m != nil && m.CsharpNamespace != nil { + return *m.CsharpNamespace + } + return "" +} + +func (m *FileOptions) GetSwiftPrefix() string { + if m != nil && m.SwiftPrefix != nil { + return *m.SwiftPrefix + } + return "" +} + +func (m *FileOptions) GetPhpClassPrefix() string { + if m != nil && m.PhpClassPrefix != nil { + return *m.PhpClassPrefix + } + return "" +} + +func (m *FileOptions) GetPhpNamespace() string { + if m != nil && m.PhpNamespace != nil { + return *m.PhpNamespace + } + return "" +} + +func (m *FileOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +type MessageOptions struct { + // Set true to use the old proto1 MessageSet wire format for extensions. + // This is provided for backwards-compatibility with the MessageSet wire + // format. You should not use this for any other reason: It's less + // efficient, has fewer features, and is more complicated. + // + // The message must be defined exactly as follows: + // message Foo { + // option message_set_wire_format = true; + // extensions 4 to max; + // } + // Note that the message cannot have any defined fields; MessageSets only + // have extensions. + // + // All extensions of your type must be singular messages; e.g. they cannot + // be int32s, enums, or repeated messages. + // + // Because this is an option, the above two restrictions are not enforced by + // the protocol compiler. 
+ MessageSetWireFormat *bool `protobuf:"varint,1,opt,name=message_set_wire_format,json=messageSetWireFormat,def=0" json:"message_set_wire_format,omitempty"` + // Disables the generation of the standard "descriptor()" accessor, which can + // conflict with a field of the same name. This is meant to make migration + // from proto1 easier; new code should avoid fields named "descriptor". + NoStandardDescriptorAccessor *bool `protobuf:"varint,2,opt,name=no_standard_descriptor_accessor,json=noStandardDescriptorAccessor,def=0" json:"no_standard_descriptor_accessor,omitempty"` + // Is this message deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the message, or it will be completely ignored; in the very least, + // this is a formalization for deprecating messages. + Deprecated *bool `protobuf:"varint,3,opt,name=deprecated,def=0" json:"deprecated,omitempty"` + // Whether the message is an automatically generated map entry type for the + // maps field. + // + // For maps fields: + // map map_field = 1; + // The parsed descriptor looks like: + // message MapFieldEntry { + // option map_entry = true; + // optional KeyType key = 1; + // optional ValueType value = 2; + // } + // repeated MapFieldEntry map_field = 1; + // + // Implementations may choose not to generate the map_entry=true message, but + // use a native map in the target language to hold the keys and values. + // The reflection APIs in such implementions still need to work as + // if the field is a repeated message field. + // + // NOTE: Do not set the option in .proto files. Always use the maps syntax + // instead. The option should only be implicitly set by the proto compiler + // parser. + MapEntry *bool `protobuf:"varint,7,opt,name=map_entry,json=mapEntry" json:"map_entry,omitempty"` + // The parser stores options it doesn't recognize here. See above. 
+ UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MessageOptions) Reset() { *m = MessageOptions{} } +func (m *MessageOptions) String() string { return proto.CompactTextString(m) } +func (*MessageOptions) ProtoMessage() {} +func (*MessageOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{11} } + +var extRange_MessageOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*MessageOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_MessageOptions +} + +const Default_MessageOptions_MessageSetWireFormat bool = false +const Default_MessageOptions_NoStandardDescriptorAccessor bool = false +const Default_MessageOptions_Deprecated bool = false + +func (m *MessageOptions) GetMessageSetWireFormat() bool { + if m != nil && m.MessageSetWireFormat != nil { + return *m.MessageSetWireFormat + } + return Default_MessageOptions_MessageSetWireFormat +} + +func (m *MessageOptions) GetNoStandardDescriptorAccessor() bool { + if m != nil && m.NoStandardDescriptorAccessor != nil { + return *m.NoStandardDescriptorAccessor + } + return Default_MessageOptions_NoStandardDescriptorAccessor +} + +func (m *MessageOptions) GetDeprecated() bool { + if m != nil && m.Deprecated != nil { + return *m.Deprecated + } + return Default_MessageOptions_Deprecated +} + +func (m *MessageOptions) GetMapEntry() bool { + if m != nil && m.MapEntry != nil { + return *m.MapEntry + } + return false +} + +func (m *MessageOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +type FieldOptions struct { + // The ctype option instructs the C++ code generator to use a different + // representation of the field than it normally would. See the specific + // options below. 
This option is not yet implemented in the open source + // release -- sorry, we'll try to include it in a future version! + Ctype *FieldOptions_CType `protobuf:"varint,1,opt,name=ctype,enum=google.protobuf.FieldOptions_CType,def=0" json:"ctype,omitempty"` + // The packed option can be enabled for repeated primitive fields to enable + // a more efficient representation on the wire. Rather than repeatedly + // writing the tag and type for each element, the entire array is encoded as + // a single length-delimited blob. In proto3, only explicit setting it to + // false will avoid using packed encoding. + Packed *bool `protobuf:"varint,2,opt,name=packed" json:"packed,omitempty"` + // The jstype option determines the JavaScript type used for values of the + // field. The option is permitted only for 64 bit integral and fixed types + // (int64, uint64, sint64, fixed64, sfixed64). A field with jstype JS_STRING + // is represented as JavaScript string, which avoids loss of precision that + // can happen when a large value is converted to a floating point JavaScript. + // Specifying JS_NUMBER for the jstype causes the generated JavaScript code to + // use the JavaScript "number" type. The behavior of the default option + // JS_NORMAL is implementation dependent. + // + // This option is an enum to permit additional types to be added, e.g. + // goog.math.Integer. + Jstype *FieldOptions_JSType `protobuf:"varint,6,opt,name=jstype,enum=google.protobuf.FieldOptions_JSType,def=0" json:"jstype,omitempty"` + // Should this field be parsed lazily? Lazy applies only to message-type + // fields. It means that when the outer message is initially parsed, the + // inner message's contents will not be parsed but instead stored in encoded + // form. The inner message will actually be parsed when it is first accessed. + // + // This is only a hint. Implementations are free to choose whether to use + // eager or lazy parsing regardless of the value of this option. 
However, + // setting this option true suggests that the protocol author believes that + // using lazy parsing on this field is worth the additional bookkeeping + // overhead typically needed to implement it. + // + // This option does not affect the public interface of any generated code; + // all method signatures remain the same. Furthermore, thread-safety of the + // interface is not affected by this option; const methods remain safe to + // call from multiple threads concurrently, while non-const methods continue + // to require exclusive access. + // + // + // Note that implementations may choose not to check required fields within + // a lazy sub-message. That is, calling IsInitialized() on the outer message + // may return true even if the inner message has missing required fields. + // This is necessary because otherwise the inner message would have to be + // parsed in order to perform the check, defeating the purpose of lazy + // parsing. An implementation which chooses not to check required fields + // must be consistent about it. That is, for any particular sub-message, the + // implementation must either *always* check its required fields, or *never* + // check its required fields, regardless of whether or not the message has + // been parsed. + Lazy *bool `protobuf:"varint,5,opt,name=lazy,def=0" json:"lazy,omitempty"` + // Is this field deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for accessors, or it will be completely ignored; in the very least, this + // is a formalization for deprecating fields. + Deprecated *bool `protobuf:"varint,3,opt,name=deprecated,def=0" json:"deprecated,omitempty"` + // For Google-internal migration only. Do not use. + Weak *bool `protobuf:"varint,10,opt,name=weak,def=0" json:"weak,omitempty"` + // The parser stores options it doesn't recognize here. See above. 
+ UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *FieldOptions) Reset() { *m = FieldOptions{} } +func (m *FieldOptions) String() string { return proto.CompactTextString(m) } +func (*FieldOptions) ProtoMessage() {} +func (*FieldOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{12} } + +var extRange_FieldOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*FieldOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_FieldOptions +} + +const Default_FieldOptions_Ctype FieldOptions_CType = FieldOptions_STRING +const Default_FieldOptions_Jstype FieldOptions_JSType = FieldOptions_JS_NORMAL +const Default_FieldOptions_Lazy bool = false +const Default_FieldOptions_Deprecated bool = false +const Default_FieldOptions_Weak bool = false + +func (m *FieldOptions) GetCtype() FieldOptions_CType { + if m != nil && m.Ctype != nil { + return *m.Ctype + } + return Default_FieldOptions_Ctype +} + +func (m *FieldOptions) GetPacked() bool { + if m != nil && m.Packed != nil { + return *m.Packed + } + return false +} + +func (m *FieldOptions) GetJstype() FieldOptions_JSType { + if m != nil && m.Jstype != nil { + return *m.Jstype + } + return Default_FieldOptions_Jstype +} + +func (m *FieldOptions) GetLazy() bool { + if m != nil && m.Lazy != nil { + return *m.Lazy + } + return Default_FieldOptions_Lazy +} + +func (m *FieldOptions) GetDeprecated() bool { + if m != nil && m.Deprecated != nil { + return *m.Deprecated + } + return Default_FieldOptions_Deprecated +} + +func (m *FieldOptions) GetWeak() bool { + if m != nil && m.Weak != nil { + return *m.Weak + } + return Default_FieldOptions_Weak +} + +func (m *FieldOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + 
} + return nil +} + +type OneofOptions struct { + // The parser stores options it doesn't recognize here. See above. + UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *OneofOptions) Reset() { *m = OneofOptions{} } +func (m *OneofOptions) String() string { return proto.CompactTextString(m) } +func (*OneofOptions) ProtoMessage() {} +func (*OneofOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{13} } + +var extRange_OneofOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*OneofOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_OneofOptions +} + +func (m *OneofOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +type EnumOptions struct { + // Set this option to true to allow mapping different tag names to the same + // value. + AllowAlias *bool `protobuf:"varint,2,opt,name=allow_alias,json=allowAlias" json:"allow_alias,omitempty"` + // Is this enum deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the enum, or it will be completely ignored; in the very least, this + // is a formalization for deprecating enums. + Deprecated *bool `protobuf:"varint,3,opt,name=deprecated,def=0" json:"deprecated,omitempty"` + // The parser stores options it doesn't recognize here. See above. 
+ UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *EnumOptions) Reset() { *m = EnumOptions{} } +func (m *EnumOptions) String() string { return proto.CompactTextString(m) } +func (*EnumOptions) ProtoMessage() {} +func (*EnumOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{14} } + +var extRange_EnumOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*EnumOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_EnumOptions +} + +const Default_EnumOptions_Deprecated bool = false + +func (m *EnumOptions) GetAllowAlias() bool { + if m != nil && m.AllowAlias != nil { + return *m.AllowAlias + } + return false +} + +func (m *EnumOptions) GetDeprecated() bool { + if m != nil && m.Deprecated != nil { + return *m.Deprecated + } + return Default_EnumOptions_Deprecated +} + +func (m *EnumOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +type EnumValueOptions struct { + // Is this enum value deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the enum value, or it will be completely ignored; in the very least, + // this is a formalization for deprecating enum values. + Deprecated *bool `protobuf:"varint,1,opt,name=deprecated,def=0" json:"deprecated,omitempty"` + // The parser stores options it doesn't recognize here. See above. 
+ UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *EnumValueOptions) Reset() { *m = EnumValueOptions{} } +func (m *EnumValueOptions) String() string { return proto.CompactTextString(m) } +func (*EnumValueOptions) ProtoMessage() {} +func (*EnumValueOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{15} } + +var extRange_EnumValueOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*EnumValueOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_EnumValueOptions +} + +const Default_EnumValueOptions_Deprecated bool = false + +func (m *EnumValueOptions) GetDeprecated() bool { + if m != nil && m.Deprecated != nil { + return *m.Deprecated + } + return Default_EnumValueOptions_Deprecated +} + +func (m *EnumValueOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +type ServiceOptions struct { + // Is this service deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the service, or it will be completely ignored; in the very least, + // this is a formalization for deprecating services. + Deprecated *bool `protobuf:"varint,33,opt,name=deprecated,def=0" json:"deprecated,omitempty"` + // The parser stores options it doesn't recognize here. See above. 
+ UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *ServiceOptions) Reset() { *m = ServiceOptions{} } +func (m *ServiceOptions) String() string { return proto.CompactTextString(m) } +func (*ServiceOptions) ProtoMessage() {} +func (*ServiceOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{16} } + +var extRange_ServiceOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*ServiceOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_ServiceOptions +} + +const Default_ServiceOptions_Deprecated bool = false + +func (m *ServiceOptions) GetDeprecated() bool { + if m != nil && m.Deprecated != nil { + return *m.Deprecated + } + return Default_ServiceOptions_Deprecated +} + +func (m *ServiceOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +type MethodOptions struct { + // Is this method deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the method, or it will be completely ignored; in the very least, + // this is a formalization for deprecating methods. + Deprecated *bool `protobuf:"varint,33,opt,name=deprecated,def=0" json:"deprecated,omitempty"` + IdempotencyLevel *MethodOptions_IdempotencyLevel `protobuf:"varint,34,opt,name=idempotency_level,json=idempotencyLevel,enum=google.protobuf.MethodOptions_IdempotencyLevel,def=0" json:"idempotency_level,omitempty"` + // The parser stores options it doesn't recognize here. See above. 
+ UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` + proto.XXX_InternalExtensions `json:"-"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MethodOptions) Reset() { *m = MethodOptions{} } +func (m *MethodOptions) String() string { return proto.CompactTextString(m) } +func (*MethodOptions) ProtoMessage() {} +func (*MethodOptions) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{17} } + +var extRange_MethodOptions = []proto.ExtensionRange{ + {Start: 1000, End: 536870911}, +} + +func (*MethodOptions) ExtensionRangeArray() []proto.ExtensionRange { + return extRange_MethodOptions +} + +const Default_MethodOptions_Deprecated bool = false +const Default_MethodOptions_IdempotencyLevel MethodOptions_IdempotencyLevel = MethodOptions_IDEMPOTENCY_UNKNOWN + +func (m *MethodOptions) GetDeprecated() bool { + if m != nil && m.Deprecated != nil { + return *m.Deprecated + } + return Default_MethodOptions_Deprecated +} + +func (m *MethodOptions) GetIdempotencyLevel() MethodOptions_IdempotencyLevel { + if m != nil && m.IdempotencyLevel != nil { + return *m.IdempotencyLevel + } + return Default_MethodOptions_IdempotencyLevel +} + +func (m *MethodOptions) GetUninterpretedOption() []*UninterpretedOption { + if m != nil { + return m.UninterpretedOption + } + return nil +} + +// A message representing a option the parser does not recognize. This only +// appears in options protos created by the compiler::Parser class. +// DescriptorPool resolves these when building Descriptor objects. Therefore, +// options protos in descriptor objects (e.g. returned by Descriptor::options(), +// or produced by Descriptor::CopyTo()) will never have UninterpretedOptions +// in them. 
+type UninterpretedOption struct { + Name []*UninterpretedOption_NamePart `protobuf:"bytes,2,rep,name=name" json:"name,omitempty"` + // The value of the uninterpreted option, in whatever type the tokenizer + // identified it as during parsing. Exactly one of these should be set. + IdentifierValue *string `protobuf:"bytes,3,opt,name=identifier_value,json=identifierValue" json:"identifier_value,omitempty"` + PositiveIntValue *uint64 `protobuf:"varint,4,opt,name=positive_int_value,json=positiveIntValue" json:"positive_int_value,omitempty"` + NegativeIntValue *int64 `protobuf:"varint,5,opt,name=negative_int_value,json=negativeIntValue" json:"negative_int_value,omitempty"` + DoubleValue *float64 `protobuf:"fixed64,6,opt,name=double_value,json=doubleValue" json:"double_value,omitempty"` + StringValue []byte `protobuf:"bytes,7,opt,name=string_value,json=stringValue" json:"string_value,omitempty"` + AggregateValue *string `protobuf:"bytes,8,opt,name=aggregate_value,json=aggregateValue" json:"aggregate_value,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *UninterpretedOption) Reset() { *m = UninterpretedOption{} } +func (m *UninterpretedOption) String() string { return proto.CompactTextString(m) } +func (*UninterpretedOption) ProtoMessage() {} +func (*UninterpretedOption) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{18} } + +func (m *UninterpretedOption) GetName() []*UninterpretedOption_NamePart { + if m != nil { + return m.Name + } + return nil +} + +func (m *UninterpretedOption) GetIdentifierValue() string { + if m != nil && m.IdentifierValue != nil { + return *m.IdentifierValue + } + return "" +} + +func (m *UninterpretedOption) GetPositiveIntValue() uint64 { + if m != nil && m.PositiveIntValue != nil { + return *m.PositiveIntValue + } + return 0 +} + +func (m *UninterpretedOption) GetNegativeIntValue() int64 { + if m != nil && m.NegativeIntValue != nil { + return *m.NegativeIntValue + } + return 0 +} + +func (m 
*UninterpretedOption) GetDoubleValue() float64 { + if m != nil && m.DoubleValue != nil { + return *m.DoubleValue + } + return 0 +} + +func (m *UninterpretedOption) GetStringValue() []byte { + if m != nil { + return m.StringValue + } + return nil +} + +func (m *UninterpretedOption) GetAggregateValue() string { + if m != nil && m.AggregateValue != nil { + return *m.AggregateValue + } + return "" +} + +// The name of the uninterpreted option. Each string represents a segment in +// a dot-separated name. is_extension is true iff a segment represents an +// extension (denoted with parentheses in options specs in .proto files). +// E.g.,{ ["foo", false], ["bar.baz", true], ["qux", false] } represents +// "foo.(bar.baz).qux". +type UninterpretedOption_NamePart struct { + NamePart *string `protobuf:"bytes,1,req,name=name_part,json=namePart" json:"name_part,omitempty"` + IsExtension *bool `protobuf:"varint,2,req,name=is_extension,json=isExtension" json:"is_extension,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *UninterpretedOption_NamePart) Reset() { *m = UninterpretedOption_NamePart{} } +func (m *UninterpretedOption_NamePart) String() string { return proto.CompactTextString(m) } +func (*UninterpretedOption_NamePart) ProtoMessage() {} +func (*UninterpretedOption_NamePart) Descriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{18, 0} +} + +func (m *UninterpretedOption_NamePart) GetNamePart() string { + if m != nil && m.NamePart != nil { + return *m.NamePart + } + return "" +} + +func (m *UninterpretedOption_NamePart) GetIsExtension() bool { + if m != nil && m.IsExtension != nil { + return *m.IsExtension + } + return false +} + +// Encapsulates information about the original source file from which a +// FileDescriptorProto was generated. +type SourceCodeInfo struct { + // A Location identifies a piece of source code in a .proto file which + // corresponds to a particular definition. 
This information is intended + // to be useful to IDEs, code indexers, documentation generators, and similar + // tools. + // + // For example, say we have a file like: + // message Foo { + // optional string foo = 1; + // } + // Let's look at just the field definition: + // optional string foo = 1; + // ^ ^^ ^^ ^ ^^^ + // a bc de f ghi + // We have the following locations: + // span path represents + // [a,i) [ 4, 0, 2, 0 ] The whole field definition. + // [a,b) [ 4, 0, 2, 0, 4 ] The label (optional). + // [c,d) [ 4, 0, 2, 0, 5 ] The type (string). + // [e,f) [ 4, 0, 2, 0, 1 ] The name (foo). + // [g,h) [ 4, 0, 2, 0, 3 ] The number (1). + // + // Notes: + // - A location may refer to a repeated field itself (i.e. not to any + // particular index within it). This is used whenever a set of elements are + // logically enclosed in a single code segment. For example, an entire + // extend block (possibly containing multiple extension definitions) will + // have an outer location whose path refers to the "extensions" repeated + // field without an index. + // - Multiple locations may have the same path. This happens when a single + // logical declaration is spread out across multiple places. The most + // obvious example is the "extend" block again -- there may be multiple + // extend blocks in the same scope, each of which will have the same path. + // - A location's span is not always a subset of its parent's span. For + // example, the "extendee" of an extension declaration appears at the + // beginning of the "extend" block and is shared by all extensions within + // the block. + // - Just because a location's span is a subset of some other location's span + // does not mean that it is a descendent. For example, a "group" defines + // both a type and a field in a single declaration. Thus, the locations + // corresponding to the type and field and their components will overlap. 
+ // - Code which tries to interpret locations should probably be designed to + // ignore those that it doesn't understand, as more types of locations could + // be recorded in the future. + Location []*SourceCodeInfo_Location `protobuf:"bytes,1,rep,name=location" json:"location,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SourceCodeInfo) Reset() { *m = SourceCodeInfo{} } +func (m *SourceCodeInfo) String() string { return proto.CompactTextString(m) } +func (*SourceCodeInfo) ProtoMessage() {} +func (*SourceCodeInfo) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{19} } + +func (m *SourceCodeInfo) GetLocation() []*SourceCodeInfo_Location { + if m != nil { + return m.Location + } + return nil +} + +type SourceCodeInfo_Location struct { + // Identifies which part of the FileDescriptorProto was defined at this + // location. + // + // Each element is a field number or an index. They form a path from + // the root FileDescriptorProto to the place where the definition. For + // example, this path: + // [ 4, 3, 2, 7, 1 ] + // refers to: + // file.message_type(3) // 4, 3 + // .field(7) // 2, 7 + // .name() // 1 + // This is because FileDescriptorProto.message_type has field number 4: + // repeated DescriptorProto message_type = 4; + // and DescriptorProto.field has field number 2: + // repeated FieldDescriptorProto field = 2; + // and FieldDescriptorProto.name has field number 1: + // optional string name = 1; + // + // Thus, the above path gives the location of a field name. If we removed + // the last element: + // [ 4, 3, 2, 7 ] + // this path refers to the whole field declaration (from the beginning + // of the label to the terminating semicolon). + Path []int32 `protobuf:"varint,1,rep,packed,name=path" json:"path,omitempty"` + // Always has exactly three or four elements: start line, start column, + // end line (optional, otherwise assumed same as start line), end column. + // These are packed into a single field for efficiency. 
Note that line + // and column numbers are zero-based -- typically you will want to add + // 1 to each before displaying to a user. + Span []int32 `protobuf:"varint,2,rep,packed,name=span" json:"span,omitempty"` + // If this SourceCodeInfo represents a complete declaration, these are any + // comments appearing before and after the declaration which appear to be + // attached to the declaration. + // + // A series of line comments appearing on consecutive lines, with no other + // tokens appearing on those lines, will be treated as a single comment. + // + // leading_detached_comments will keep paragraphs of comments that appear + // before (but not connected to) the current element. Each paragraph, + // separated by empty lines, will be one comment element in the repeated + // field. + // + // Only the comment content is provided; comment markers (e.g. //) are + // stripped out. For block comments, leading whitespace and an asterisk + // will be stripped from the beginning of each line other than the first. + // Newlines are included in the output. + // + // Examples: + // + // optional int32 foo = 1; // Comment attached to foo. + // // Comment attached to bar. + // optional int32 bar = 2; + // + // optional string baz = 3; + // // Comment attached to baz. + // // Another line attached to baz. + // + // // Comment attached to qux. + // // + // // Another line attached to qux. + // optional double qux = 4; + // + // // Detached comment for corge. This is not leading or trailing comments + // // to qux or corge because there are blank lines separating it from + // // both. + // + // // Detached comment for corge paragraph 2. + // + // optional string corge = 5; + // /* Block comment attached + // * to corge. Leading asterisks + // * will be removed. */ + // /* Block comment attached to + // * grault. */ + // optional int32 grault = 6; + // + // // ignored detached comments. 
+ LeadingComments *string `protobuf:"bytes,3,opt,name=leading_comments,json=leadingComments" json:"leading_comments,omitempty"` + TrailingComments *string `protobuf:"bytes,4,opt,name=trailing_comments,json=trailingComments" json:"trailing_comments,omitempty"` + LeadingDetachedComments []string `protobuf:"bytes,6,rep,name=leading_detached_comments,json=leadingDetachedComments" json:"leading_detached_comments,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SourceCodeInfo_Location) Reset() { *m = SourceCodeInfo_Location{} } +func (m *SourceCodeInfo_Location) String() string { return proto.CompactTextString(m) } +func (*SourceCodeInfo_Location) ProtoMessage() {} +func (*SourceCodeInfo_Location) Descriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{19, 0} +} + +func (m *SourceCodeInfo_Location) GetPath() []int32 { + if m != nil { + return m.Path + } + return nil +} + +func (m *SourceCodeInfo_Location) GetSpan() []int32 { + if m != nil { + return m.Span + } + return nil +} + +func (m *SourceCodeInfo_Location) GetLeadingComments() string { + if m != nil && m.LeadingComments != nil { + return *m.LeadingComments + } + return "" +} + +func (m *SourceCodeInfo_Location) GetTrailingComments() string { + if m != nil && m.TrailingComments != nil { + return *m.TrailingComments + } + return "" +} + +func (m *SourceCodeInfo_Location) GetLeadingDetachedComments() []string { + if m != nil { + return m.LeadingDetachedComments + } + return nil +} + +// Describes the relationship between generated code and its original source +// file. A GeneratedCodeInfo message is associated with only one generated +// source file, but may contain references to different source .proto files. +type GeneratedCodeInfo struct { + // An Annotation connects some span of text in generated code to an element + // of its generating .proto file. 
+ Annotation []*GeneratedCodeInfo_Annotation `protobuf:"bytes,1,rep,name=annotation" json:"annotation,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GeneratedCodeInfo) Reset() { *m = GeneratedCodeInfo{} } +func (m *GeneratedCodeInfo) String() string { return proto.CompactTextString(m) } +func (*GeneratedCodeInfo) ProtoMessage() {} +func (*GeneratedCodeInfo) Descriptor() ([]byte, []int) { return fileDescriptorDescriptor, []int{20} } + +func (m *GeneratedCodeInfo) GetAnnotation() []*GeneratedCodeInfo_Annotation { + if m != nil { + return m.Annotation + } + return nil +} + +type GeneratedCodeInfo_Annotation struct { + // Identifies the element in the original source .proto file. This field + // is formatted the same as SourceCodeInfo.Location.path. + Path []int32 `protobuf:"varint,1,rep,packed,name=path" json:"path,omitempty"` + // Identifies the filesystem path to the original source .proto. + SourceFile *string `protobuf:"bytes,2,opt,name=source_file,json=sourceFile" json:"source_file,omitempty"` + // Identifies the starting offset in bytes in the generated code + // that relates to the identified object. + Begin *int32 `protobuf:"varint,3,opt,name=begin" json:"begin,omitempty"` + // Identifies the ending offset in bytes in the generated code that + // relates to the identified offset. The end offset should be one past + // the last relevant byte (so the length of the text = end - begin). 
+ End *int32 `protobuf:"varint,4,opt,name=end" json:"end,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *GeneratedCodeInfo_Annotation) Reset() { *m = GeneratedCodeInfo_Annotation{} } +func (m *GeneratedCodeInfo_Annotation) String() string { return proto.CompactTextString(m) } +func (*GeneratedCodeInfo_Annotation) ProtoMessage() {} +func (*GeneratedCodeInfo_Annotation) Descriptor() ([]byte, []int) { + return fileDescriptorDescriptor, []int{20, 0} +} + +func (m *GeneratedCodeInfo_Annotation) GetPath() []int32 { + if m != nil { + return m.Path + } + return nil +} + +func (m *GeneratedCodeInfo_Annotation) GetSourceFile() string { + if m != nil && m.SourceFile != nil { + return *m.SourceFile + } + return "" +} + +func (m *GeneratedCodeInfo_Annotation) GetBegin() int32 { + if m != nil && m.Begin != nil { + return *m.Begin + } + return 0 +} + +func (m *GeneratedCodeInfo_Annotation) GetEnd() int32 { + if m != nil && m.End != nil { + return *m.End + } + return 0 +} + +func init() { + proto.RegisterType((*FileDescriptorSet)(nil), "google.protobuf.FileDescriptorSet") + proto.RegisterType((*FileDescriptorProto)(nil), "google.protobuf.FileDescriptorProto") + proto.RegisterType((*DescriptorProto)(nil), "google.protobuf.DescriptorProto") + proto.RegisterType((*DescriptorProto_ExtensionRange)(nil), "google.protobuf.DescriptorProto.ExtensionRange") + proto.RegisterType((*DescriptorProto_ReservedRange)(nil), "google.protobuf.DescriptorProto.ReservedRange") + proto.RegisterType((*ExtensionRangeOptions)(nil), "google.protobuf.ExtensionRangeOptions") + proto.RegisterType((*FieldDescriptorProto)(nil), "google.protobuf.FieldDescriptorProto") + proto.RegisterType((*OneofDescriptorProto)(nil), "google.protobuf.OneofDescriptorProto") + proto.RegisterType((*EnumDescriptorProto)(nil), "google.protobuf.EnumDescriptorProto") + proto.RegisterType((*EnumDescriptorProto_EnumReservedRange)(nil), "google.protobuf.EnumDescriptorProto.EnumReservedRange") + 
proto.RegisterType((*EnumValueDescriptorProto)(nil), "google.protobuf.EnumValueDescriptorProto") + proto.RegisterType((*ServiceDescriptorProto)(nil), "google.protobuf.ServiceDescriptorProto") + proto.RegisterType((*MethodDescriptorProto)(nil), "google.protobuf.MethodDescriptorProto") + proto.RegisterType((*FileOptions)(nil), "google.protobuf.FileOptions") + proto.RegisterType((*MessageOptions)(nil), "google.protobuf.MessageOptions") + proto.RegisterType((*FieldOptions)(nil), "google.protobuf.FieldOptions") + proto.RegisterType((*OneofOptions)(nil), "google.protobuf.OneofOptions") + proto.RegisterType((*EnumOptions)(nil), "google.protobuf.EnumOptions") + proto.RegisterType((*EnumValueOptions)(nil), "google.protobuf.EnumValueOptions") + proto.RegisterType((*ServiceOptions)(nil), "google.protobuf.ServiceOptions") + proto.RegisterType((*MethodOptions)(nil), "google.protobuf.MethodOptions") + proto.RegisterType((*UninterpretedOption)(nil), "google.protobuf.UninterpretedOption") + proto.RegisterType((*UninterpretedOption_NamePart)(nil), "google.protobuf.UninterpretedOption.NamePart") + proto.RegisterType((*SourceCodeInfo)(nil), "google.protobuf.SourceCodeInfo") + proto.RegisterType((*SourceCodeInfo_Location)(nil), "google.protobuf.SourceCodeInfo.Location") + proto.RegisterType((*GeneratedCodeInfo)(nil), "google.protobuf.GeneratedCodeInfo") + proto.RegisterType((*GeneratedCodeInfo_Annotation)(nil), "google.protobuf.GeneratedCodeInfo.Annotation") + proto.RegisterEnum("google.protobuf.FieldDescriptorProto_Type", FieldDescriptorProto_Type_name, FieldDescriptorProto_Type_value) + proto.RegisterEnum("google.protobuf.FieldDescriptorProto_Label", FieldDescriptorProto_Label_name, FieldDescriptorProto_Label_value) + proto.RegisterEnum("google.protobuf.FileOptions_OptimizeMode", FileOptions_OptimizeMode_name, FileOptions_OptimizeMode_value) + proto.RegisterEnum("google.protobuf.FieldOptions_CType", FieldOptions_CType_name, FieldOptions_CType_value) + 
proto.RegisterEnum("google.protobuf.FieldOptions_JSType", FieldOptions_JSType_name, FieldOptions_JSType_value) + proto.RegisterEnum("google.protobuf.MethodOptions_IdempotencyLevel", MethodOptions_IdempotencyLevel_name, MethodOptions_IdempotencyLevel_value) +} + +func init() { proto.RegisterFile("descriptor.proto", fileDescriptorDescriptor) } + +var fileDescriptorDescriptor = []byte{ + // 2487 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xc4, 0x59, 0xcd, 0x6f, 0xdb, 0xc8, + 0x15, 0x5f, 0x7d, 0x5a, 0x7a, 0x92, 0xe5, 0xf1, 0xd8, 0x9b, 0x30, 0xde, 0x8f, 0x38, 0xda, 0x8f, + 0x38, 0x49, 0xab, 0x2c, 0x9c, 0xc4, 0xc9, 0x3a, 0xc5, 0xb6, 0xb2, 0xc4, 0x78, 0x95, 0xca, 0x92, + 0x4a, 0xc9, 0xdd, 0x64, 0x8b, 0x82, 0x18, 0x93, 0x23, 0x89, 0x09, 0x45, 0x72, 0x49, 0x2a, 0x89, + 0x83, 0x1e, 0x02, 0xf4, 0xd4, 0xff, 0xa0, 0x28, 0x8a, 0x1e, 0x7a, 0x59, 0xa0, 0xd7, 0x02, 0x05, + 0xda, 0x7b, 0xaf, 0x05, 0x7a, 0xef, 0xa1, 0x40, 0x0b, 0xb4, 0x7f, 0x42, 0x8f, 0xc5, 0xcc, 0x90, + 0x14, 0xf5, 0x95, 0x78, 0x17, 0x48, 0xf6, 0x64, 0xcf, 0xef, 0xfd, 0xde, 0xe3, 0x9b, 0x37, 0x6f, + 0xde, 0xbc, 0x19, 0x01, 0xd2, 0xa9, 0xa7, 0xb9, 0x86, 0xe3, 0xdb, 0x6e, 0xc5, 0x71, 0x6d, 0xdf, + 0xc6, 0x6b, 0x03, 0xdb, 0x1e, 0x98, 0x54, 0x8c, 0x4e, 0xc6, 0xfd, 0xf2, 0x11, 0xac, 0xdf, 0x33, + 0x4c, 0x5a, 0x8f, 0x88, 0x5d, 0xea, 0xe3, 0x3b, 0x90, 0xee, 0x1b, 0x26, 0x95, 0x12, 0xdb, 0xa9, + 0x9d, 0xc2, 0xee, 0x87, 0x95, 0x19, 0xa5, 0xca, 0xb4, 0x46, 0x87, 0xc1, 0x0a, 0xd7, 0x28, 0xff, + 0x3b, 0x0d, 0x1b, 0x0b, 0xa4, 0x18, 0x43, 0xda, 0x22, 0x23, 0x66, 0x31, 0xb1, 0x93, 0x57, 0xf8, + 0xff, 0x58, 0x82, 0x15, 0x87, 0x68, 0x8f, 0xc9, 0x80, 0x4a, 0x49, 0x0e, 0x87, 0x43, 0xfc, 0x3e, + 0x80, 0x4e, 0x1d, 0x6a, 0xe9, 0xd4, 0xd2, 0x4e, 0xa5, 0xd4, 0x76, 0x6a, 0x27, 0xaf, 0xc4, 0x10, + 0x7c, 0x0d, 0xd6, 0x9d, 0xf1, 0x89, 0x69, 0x68, 0x6a, 0x8c, 0x06, 0xdb, 0xa9, 0x9d, 0x8c, 0x82, + 0x84, 0xa0, 0x3e, 0x21, 0x5f, 0x86, 0xb5, 0xa7, 0x94, 0x3c, 0x8e, 0x53, 0x0b, 0x9c, 0x5a, 
0x62, + 0x70, 0x8c, 0x58, 0x83, 0xe2, 0x88, 0x7a, 0x1e, 0x19, 0x50, 0xd5, 0x3f, 0x75, 0xa8, 0x94, 0xe6, + 0xb3, 0xdf, 0x9e, 0x9b, 0xfd, 0xec, 0xcc, 0x0b, 0x81, 0x56, 0xef, 0xd4, 0xa1, 0xb8, 0x0a, 0x79, + 0x6a, 0x8d, 0x47, 0xc2, 0x42, 0x66, 0x49, 0xfc, 0x64, 0x6b, 0x3c, 0x9a, 0xb5, 0x92, 0x63, 0x6a, + 0x81, 0x89, 0x15, 0x8f, 0xba, 0x4f, 0x0c, 0x8d, 0x4a, 0x59, 0x6e, 0xe0, 0xf2, 0x9c, 0x81, 0xae, + 0x90, 0xcf, 0xda, 0x08, 0xf5, 0x70, 0x0d, 0xf2, 0xf4, 0x99, 0x4f, 0x2d, 0xcf, 0xb0, 0x2d, 0x69, + 0x85, 0x1b, 0xf9, 0x68, 0xc1, 0x2a, 0x52, 0x53, 0x9f, 0x35, 0x31, 0xd1, 0xc3, 0x7b, 0xb0, 0x62, + 0x3b, 0xbe, 0x61, 0x5b, 0x9e, 0x94, 0xdb, 0x4e, 0xec, 0x14, 0x76, 0xdf, 0x5d, 0x98, 0x08, 0x6d, + 0xc1, 0x51, 0x42, 0x32, 0x6e, 0x00, 0xf2, 0xec, 0xb1, 0xab, 0x51, 0x55, 0xb3, 0x75, 0xaa, 0x1a, + 0x56, 0xdf, 0x96, 0xf2, 0xdc, 0xc0, 0xc5, 0xf9, 0x89, 0x70, 0x62, 0xcd, 0xd6, 0x69, 0xc3, 0xea, + 0xdb, 0x4a, 0xc9, 0x9b, 0x1a, 0xe3, 0x73, 0x90, 0xf5, 0x4e, 0x2d, 0x9f, 0x3c, 0x93, 0x8a, 0x3c, + 0x43, 0x82, 0x51, 0xf9, 0xcf, 0x59, 0x58, 0x3b, 0x4b, 0x8a, 0xdd, 0x85, 0x4c, 0x9f, 0xcd, 0x52, + 0x4a, 0x7e, 0x93, 0x18, 0x08, 0x9d, 0xe9, 0x20, 0x66, 0xbf, 0x65, 0x10, 0xab, 0x50, 0xb0, 0xa8, + 0xe7, 0x53, 0x5d, 0x64, 0x44, 0xea, 0x8c, 0x39, 0x05, 0x42, 0x69, 0x3e, 0xa5, 0xd2, 0xdf, 0x2a, + 0xa5, 0x1e, 0xc0, 0x5a, 0xe4, 0x92, 0xea, 0x12, 0x6b, 0x10, 0xe6, 0xe6, 0xf5, 0x57, 0x79, 0x52, + 0x91, 0x43, 0x3d, 0x85, 0xa9, 0x29, 0x25, 0x3a, 0x35, 0xc6, 0x75, 0x00, 0xdb, 0xa2, 0x76, 0x5f, + 0xd5, 0xa9, 0x66, 0x4a, 0xb9, 0x25, 0x51, 0x6a, 0x33, 0xca, 0x5c, 0x94, 0x6c, 0x81, 0x6a, 0x26, + 0xfe, 0x74, 0x92, 0x6a, 0x2b, 0x4b, 0x32, 0xe5, 0x48, 0x6c, 0xb2, 0xb9, 0x6c, 0x3b, 0x86, 0x92, + 0x4b, 0x59, 0xde, 0x53, 0x3d, 0x98, 0x59, 0x9e, 0x3b, 0x51, 0x79, 0xe5, 0xcc, 0x94, 0x40, 0x4d, + 0x4c, 0x6c, 0xd5, 0x8d, 0x0f, 0xf1, 0x07, 0x10, 0x01, 0x2a, 0x4f, 0x2b, 0xe0, 0x55, 0xa8, 0x18, + 0x82, 0x2d, 0x32, 0xa2, 0x5b, 0xcf, 0xa1, 0x34, 0x1d, 0x1e, 0xbc, 0x09, 0x19, 0xcf, 0x27, 0xae, + 0xcf, 0xb3, 0x30, 0xa3, 0x88, 
0x01, 0x46, 0x90, 0xa2, 0x96, 0xce, 0xab, 0x5c, 0x46, 0x61, 0xff, + 0xe2, 0x1f, 0x4d, 0x26, 0x9c, 0xe2, 0x13, 0xfe, 0x78, 0x7e, 0x45, 0xa7, 0x2c, 0xcf, 0xce, 0x7b, + 0xeb, 0x36, 0xac, 0x4e, 0x4d, 0xe0, 0xac, 0x9f, 0x2e, 0xff, 0x02, 0xde, 0x5e, 0x68, 0x1a, 0x3f, + 0x80, 0xcd, 0xb1, 0x65, 0x58, 0x3e, 0x75, 0x1d, 0x97, 0xb2, 0x8c, 0x15, 0x9f, 0x92, 0xfe, 0xb3, + 0xb2, 0x24, 0xe7, 0x8e, 0xe3, 0x6c, 0x61, 0x45, 0xd9, 0x18, 0xcf, 0x83, 0x57, 0xf3, 0xb9, 0xff, + 0xae, 0xa0, 0x17, 0x2f, 0x5e, 0xbc, 0x48, 0x96, 0x7f, 0x9d, 0x85, 0xcd, 0x45, 0x7b, 0x66, 0xe1, + 0xf6, 0x3d, 0x07, 0x59, 0x6b, 0x3c, 0x3a, 0xa1, 0x2e, 0x0f, 0x52, 0x46, 0x09, 0x46, 0xb8, 0x0a, + 0x19, 0x93, 0x9c, 0x50, 0x53, 0x4a, 0x6f, 0x27, 0x76, 0x4a, 0xbb, 0xd7, 0xce, 0xb4, 0x2b, 0x2b, + 0x4d, 0xa6, 0xa2, 0x08, 0x4d, 0xfc, 0x19, 0xa4, 0x83, 0x12, 0xcd, 0x2c, 0x5c, 0x3d, 0x9b, 0x05, + 0xb6, 0x97, 0x14, 0xae, 0x87, 0xdf, 0x81, 0x3c, 0xfb, 0x2b, 0x72, 0x23, 0xcb, 0x7d, 0xce, 0x31, + 0x80, 0xe5, 0x05, 0xde, 0x82, 0x1c, 0xdf, 0x26, 0x3a, 0x0d, 0x8f, 0xb6, 0x68, 0xcc, 0x12, 0x4b, + 0xa7, 0x7d, 0x32, 0x36, 0x7d, 0xf5, 0x09, 0x31, 0xc7, 0x94, 0x27, 0x7c, 0x5e, 0x29, 0x06, 0xe0, + 0x4f, 0x19, 0x86, 0x2f, 0x42, 0x41, 0xec, 0x2a, 0xc3, 0xd2, 0xe9, 0x33, 0x5e, 0x3d, 0x33, 0x8a, + 0xd8, 0x68, 0x0d, 0x86, 0xb0, 0xcf, 0x3f, 0xf2, 0x6c, 0x2b, 0x4c, 0x4d, 0xfe, 0x09, 0x06, 0xf0, + 0xcf, 0xdf, 0x9e, 0x2d, 0xdc, 0xef, 0x2d, 0x9e, 0xde, 0x6c, 0x4e, 0x95, 0xff, 0x94, 0x84, 0x34, + 0xaf, 0x17, 0x6b, 0x50, 0xe8, 0x3d, 0xec, 0xc8, 0x6a, 0xbd, 0x7d, 0x7c, 0xd0, 0x94, 0x51, 0x02, + 0x97, 0x00, 0x38, 0x70, 0xaf, 0xd9, 0xae, 0xf6, 0x50, 0x32, 0x1a, 0x37, 0x5a, 0xbd, 0xbd, 0x9b, + 0x28, 0x15, 0x29, 0x1c, 0x0b, 0x20, 0x1d, 0x27, 0xdc, 0xd8, 0x45, 0x19, 0x8c, 0xa0, 0x28, 0x0c, + 0x34, 0x1e, 0xc8, 0xf5, 0xbd, 0x9b, 0x28, 0x3b, 0x8d, 0xdc, 0xd8, 0x45, 0x2b, 0x78, 0x15, 0xf2, + 0x1c, 0x39, 0x68, 0xb7, 0x9b, 0x28, 0x17, 0xd9, 0xec, 0xf6, 0x94, 0x46, 0xeb, 0x10, 0xe5, 0x23, + 0x9b, 0x87, 0x4a, 0xfb, 0xb8, 0x83, 0x20, 0xb2, 0x70, 0x24, 0x77, 
0xbb, 0xd5, 0x43, 0x19, 0x15, + 0x22, 0xc6, 0xc1, 0xc3, 0x9e, 0xdc, 0x45, 0xc5, 0x29, 0xb7, 0x6e, 0xec, 0xa2, 0xd5, 0xe8, 0x13, + 0x72, 0xeb, 0xf8, 0x08, 0x95, 0xf0, 0x3a, 0xac, 0x8a, 0x4f, 0x84, 0x4e, 0xac, 0xcd, 0x40, 0x7b, + 0x37, 0x11, 0x9a, 0x38, 0x22, 0xac, 0xac, 0x4f, 0x01, 0x7b, 0x37, 0x11, 0x2e, 0xd7, 0x20, 0xc3, + 0xb3, 0x0b, 0x63, 0x28, 0x35, 0xab, 0x07, 0x72, 0x53, 0x6d, 0x77, 0x7a, 0x8d, 0x76, 0xab, 0xda, + 0x44, 0x89, 0x09, 0xa6, 0xc8, 0x3f, 0x39, 0x6e, 0x28, 0x72, 0x1d, 0x25, 0xe3, 0x58, 0x47, 0xae, + 0xf6, 0xe4, 0x3a, 0x4a, 0x95, 0x35, 0xd8, 0x5c, 0x54, 0x27, 0x17, 0xee, 0x8c, 0xd8, 0x12, 0x27, + 0x97, 0x2c, 0x31, 0xb7, 0x35, 0xb7, 0xc4, 0xff, 0x4a, 0xc2, 0xc6, 0x82, 0xb3, 0x62, 0xe1, 0x47, + 0x7e, 0x08, 0x19, 0x91, 0xa2, 0xe2, 0xf4, 0xbc, 0xb2, 0xf0, 0xd0, 0xe1, 0x09, 0x3b, 0x77, 0x82, + 0x72, 0xbd, 0x78, 0x07, 0x91, 0x5a, 0xd2, 0x41, 0x30, 0x13, 0x73, 0x35, 0xfd, 0xe7, 0x73, 0x35, + 0x5d, 0x1c, 0x7b, 0x7b, 0x67, 0x39, 0xf6, 0x38, 0xf6, 0xcd, 0x6a, 0x7b, 0x66, 0x41, 0x6d, 0xbf, + 0x0b, 0xeb, 0x73, 0x86, 0xce, 0x5c, 0x63, 0x7f, 0x99, 0x00, 0x69, 0x59, 0x70, 0x5e, 0x51, 0xe9, + 0x92, 0x53, 0x95, 0xee, 0xee, 0x6c, 0x04, 0x2f, 0x2d, 0x5f, 0x84, 0xb9, 0xb5, 0xfe, 0x3a, 0x01, + 0xe7, 0x16, 0x77, 0x8a, 0x0b, 0x7d, 0xf8, 0x0c, 0xb2, 0x23, 0xea, 0x0f, 0xed, 0xb0, 0x5b, 0xfa, + 0x78, 0xc1, 0x19, 0xcc, 0xc4, 0xb3, 0x8b, 0x1d, 0x68, 0xc5, 0x0f, 0xf1, 0xd4, 0xb2, 0x76, 0x4f, + 0x78, 0x33, 0xe7, 0xe9, 0xaf, 0x92, 0xf0, 0xf6, 0x42, 0xe3, 0x0b, 0x1d, 0x7d, 0x0f, 0xc0, 0xb0, + 0x9c, 0xb1, 0x2f, 0x3a, 0x22, 0x51, 0x60, 0xf3, 0x1c, 0xe1, 0xc5, 0x8b, 0x15, 0xcf, 0xb1, 0x1f, + 0xc9, 0x53, 0x5c, 0x0e, 0x02, 0xe2, 0x84, 0x3b, 0x13, 0x47, 0xd3, 0xdc, 0xd1, 0xf7, 0x97, 0xcc, + 0x74, 0x2e, 0x31, 0x3f, 0x01, 0xa4, 0x99, 0x06, 0xb5, 0x7c, 0xd5, 0xf3, 0x5d, 0x4a, 0x46, 0x86, + 0x35, 0xe0, 0x27, 0x48, 0x6e, 0x3f, 0xd3, 0x27, 0xa6, 0x47, 0x95, 0x35, 0x21, 0xee, 0x86, 0x52, + 0xa6, 0xc1, 0x13, 0xc8, 0x8d, 0x69, 0x64, 0xa7, 0x34, 0x84, 0x38, 0xd2, 0x28, 0xff, 0x31, 0x07, + 0x85, 
0x58, 0x5f, 0x8d, 0x2f, 0x41, 0xf1, 0x11, 0x79, 0x42, 0xd4, 0xf0, 0xae, 0x24, 0x22, 0x51, + 0x60, 0x58, 0x27, 0xb8, 0x2f, 0x7d, 0x02, 0x9b, 0x9c, 0x62, 0x8f, 0x7d, 0xea, 0xaa, 0x9a, 0x49, + 0x3c, 0x8f, 0x07, 0x2d, 0xc7, 0xa9, 0x98, 0xc9, 0xda, 0x4c, 0x54, 0x0b, 0x25, 0xf8, 0x16, 0x6c, + 0x70, 0x8d, 0xd1, 0xd8, 0xf4, 0x0d, 0xc7, 0xa4, 0x2a, 0xbb, 0xbd, 0x79, 0xfc, 0x24, 0x89, 0x3c, + 0x5b, 0x67, 0x8c, 0xa3, 0x80, 0xc0, 0x3c, 0xf2, 0x70, 0x1d, 0xde, 0xe3, 0x6a, 0x03, 0x6a, 0x51, + 0x97, 0xf8, 0x54, 0xa5, 0x5f, 0x8d, 0x89, 0xe9, 0xa9, 0xc4, 0xd2, 0xd5, 0x21, 0xf1, 0x86, 0xd2, + 0x26, 0x33, 0x70, 0x90, 0x94, 0x12, 0xca, 0x05, 0x46, 0x3c, 0x0c, 0x78, 0x32, 0xa7, 0x55, 0x2d, + 0xfd, 0x73, 0xe2, 0x0d, 0xf1, 0x3e, 0x9c, 0xe3, 0x56, 0x3c, 0xdf, 0x35, 0xac, 0x81, 0xaa, 0x0d, + 0xa9, 0xf6, 0x58, 0x1d, 0xfb, 0xfd, 0x3b, 0xd2, 0x3b, 0xf1, 0xef, 0x73, 0x0f, 0xbb, 0x9c, 0x53, + 0x63, 0x94, 0x63, 0xbf, 0x7f, 0x07, 0x77, 0xa1, 0xc8, 0x16, 0x63, 0x64, 0x3c, 0xa7, 0x6a, 0xdf, + 0x76, 0xf9, 0xd1, 0x58, 0x5a, 0x50, 0x9a, 0x62, 0x11, 0xac, 0xb4, 0x03, 0x85, 0x23, 0x5b, 0xa7, + 0xfb, 0x99, 0x6e, 0x47, 0x96, 0xeb, 0x4a, 0x21, 0xb4, 0x72, 0xcf, 0x76, 0x59, 0x42, 0x0d, 0xec, + 0x28, 0xc0, 0x05, 0x91, 0x50, 0x03, 0x3b, 0x0c, 0xef, 0x2d, 0xd8, 0xd0, 0x34, 0x31, 0x67, 0x43, + 0x53, 0x83, 0x3b, 0x96, 0x27, 0xa1, 0xa9, 0x60, 0x69, 0xda, 0xa1, 0x20, 0x04, 0x39, 0xee, 0xe1, + 0x4f, 0xe1, 0xed, 0x49, 0xb0, 0xe2, 0x8a, 0xeb, 0x73, 0xb3, 0x9c, 0x55, 0xbd, 0x05, 0x1b, 0xce, + 0xe9, 0xbc, 0x22, 0x9e, 0xfa, 0xa2, 0x73, 0x3a, 0xab, 0x76, 0x1b, 0x36, 0x9d, 0xa1, 0x33, 0xaf, + 0x77, 0x35, 0xae, 0x87, 0x9d, 0xa1, 0x33, 0xab, 0xf8, 0x11, 0xbf, 0x70, 0xbb, 0x54, 0x23, 0x3e, + 0xd5, 0xa5, 0xf3, 0x71, 0x7a, 0x4c, 0x80, 0xaf, 0x03, 0xd2, 0x34, 0x95, 0x5a, 0xe4, 0xc4, 0xa4, + 0x2a, 0x71, 0xa9, 0x45, 0x3c, 0xe9, 0x62, 0x9c, 0x5c, 0xd2, 0x34, 0x99, 0x4b, 0xab, 0x5c, 0x88, + 0xaf, 0xc2, 0xba, 0x7d, 0xf2, 0x48, 0x13, 0x29, 0xa9, 0x3a, 0x2e, 0xed, 0x1b, 0xcf, 0xa4, 0x0f, + 0x79, 0x7c, 0xd7, 0x98, 0x80, 0x27, 0x64, 
0x87, 0xc3, 0xf8, 0x0a, 0x20, 0xcd, 0x1b, 0x12, 0xd7, + 0xe1, 0x35, 0xd9, 0x73, 0x88, 0x46, 0xa5, 0x8f, 0x04, 0x55, 0xe0, 0xad, 0x10, 0x66, 0x5b, 0xc2, + 0x7b, 0x6a, 0xf4, 0xfd, 0xd0, 0xe2, 0x65, 0xb1, 0x25, 0x38, 0x16, 0x58, 0xdb, 0x01, 0xc4, 0x42, + 0x31, 0xf5, 0xe1, 0x1d, 0x4e, 0x2b, 0x39, 0x43, 0x27, 0xfe, 0xdd, 0x0f, 0x60, 0x95, 0x31, 0x27, + 0x1f, 0xbd, 0x22, 0x1a, 0x32, 0x67, 0x18, 0xfb, 0xe2, 0x6b, 0xeb, 0x8d, 0xcb, 0xfb, 0x50, 0x8c, + 0xe7, 0x27, 0xce, 0x83, 0xc8, 0x50, 0x94, 0x60, 0xcd, 0x4a, 0xad, 0x5d, 0x67, 0x6d, 0xc6, 0x97, + 0x32, 0x4a, 0xb2, 0x76, 0xa7, 0xd9, 0xe8, 0xc9, 0xaa, 0x72, 0xdc, 0xea, 0x35, 0x8e, 0x64, 0x94, + 0x8a, 0xf7, 0xd5, 0x7f, 0x4d, 0x42, 0x69, 0xfa, 0x8a, 0x84, 0x7f, 0x00, 0xe7, 0xc3, 0xf7, 0x0c, + 0x8f, 0xfa, 0xea, 0x53, 0xc3, 0xe5, 0x5b, 0x66, 0x44, 0xc4, 0xf1, 0x15, 0x2d, 0xda, 0x66, 0xc0, + 0xea, 0x52, 0xff, 0x0b, 0xc3, 0x65, 0x1b, 0x62, 0x44, 0x7c, 0xdc, 0x84, 0x8b, 0x96, 0xad, 0x7a, + 0x3e, 0xb1, 0x74, 0xe2, 0xea, 0xea, 0xe4, 0x25, 0x49, 0x25, 0x9a, 0x46, 0x3d, 0xcf, 0x16, 0x47, + 0x55, 0x64, 0xe5, 0x5d, 0xcb, 0xee, 0x06, 0xe4, 0x49, 0x0d, 0xaf, 0x06, 0xd4, 0x99, 0x04, 0x4b, + 0x2d, 0x4b, 0xb0, 0x77, 0x20, 0x3f, 0x22, 0x8e, 0x4a, 0x2d, 0xdf, 0x3d, 0xe5, 0x8d, 0x71, 0x4e, + 0xc9, 0x8d, 0x88, 0x23, 0xb3, 0xf1, 0x9b, 0xb9, 0x9f, 0xfc, 0x23, 0x05, 0xc5, 0x78, 0x73, 0xcc, + 0xee, 0x1a, 0x1a, 0x3f, 0x47, 0x12, 0xbc, 0xd2, 0x7c, 0xf0, 0xd2, 0x56, 0xba, 0x52, 0x63, 0x07, + 0xcc, 0x7e, 0x56, 0xb4, 0xac, 0x8a, 0xd0, 0x64, 0x87, 0x3b, 0xab, 0x2d, 0x54, 0xb4, 0x08, 0x39, + 0x25, 0x18, 0xe1, 0x43, 0xc8, 0x3e, 0xf2, 0xb8, 0xed, 0x2c, 0xb7, 0xfd, 0xe1, 0xcb, 0x6d, 0xdf, + 0xef, 0x72, 0xe3, 0xf9, 0xfb, 0x5d, 0xb5, 0xd5, 0x56, 0x8e, 0xaa, 0x4d, 0x25, 0x50, 0xc7, 0x17, + 0x20, 0x6d, 0x92, 0xe7, 0xa7, 0xd3, 0x47, 0x11, 0x87, 0xce, 0x1a, 0xf8, 0x0b, 0x90, 0x7e, 0x4a, + 0xc9, 0xe3, 0xe9, 0x03, 0x80, 0x43, 0xaf, 0x31, 0xf5, 0xaf, 0x43, 0x86, 0xc7, 0x0b, 0x03, 0x04, + 0x11, 0x43, 0x6f, 0xe1, 0x1c, 0xa4, 0x6b, 0x6d, 0x85, 0xa5, 0x3f, 0x82, 0xa2, 
0x40, 0xd5, 0x4e, + 0x43, 0xae, 0xc9, 0x28, 0x59, 0xbe, 0x05, 0x59, 0x11, 0x04, 0xb6, 0x35, 0xa2, 0x30, 0xa0, 0xb7, + 0x82, 0x61, 0x60, 0x23, 0x11, 0x4a, 0x8f, 0x8f, 0x0e, 0x64, 0x05, 0x25, 0xe3, 0xcb, 0xeb, 0x41, + 0x31, 0xde, 0x17, 0xbf, 0x99, 0x9c, 0xfa, 0x4b, 0x02, 0x0a, 0xb1, 0x3e, 0x97, 0x35, 0x28, 0xc4, + 0x34, 0xed, 0xa7, 0x2a, 0x31, 0x0d, 0xe2, 0x05, 0x49, 0x01, 0x1c, 0xaa, 0x32, 0xe4, 0xac, 0x8b, + 0xf6, 0x46, 0x9c, 0xff, 0x5d, 0x02, 0xd0, 0x6c, 0x8b, 0x39, 0xe3, 0x60, 0xe2, 0x3b, 0x75, 0xf0, + 0xb7, 0x09, 0x28, 0x4d, 0xf7, 0x95, 0x33, 0xee, 0x5d, 0xfa, 0x4e, 0xdd, 0xfb, 0x67, 0x12, 0x56, + 0xa7, 0xba, 0xc9, 0xb3, 0x7a, 0xf7, 0x15, 0xac, 0x1b, 0x3a, 0x1d, 0x39, 0xb6, 0x4f, 0x2d, 0xed, + 0x54, 0x35, 0xe9, 0x13, 0x6a, 0x4a, 0x65, 0x5e, 0x28, 0xae, 0xbf, 0xbc, 0x5f, 0xad, 0x34, 0x26, + 0x7a, 0x4d, 0xa6, 0xb6, 0xbf, 0xd1, 0xa8, 0xcb, 0x47, 0x9d, 0x76, 0x4f, 0x6e, 0xd5, 0x1e, 0xaa, + 0xc7, 0xad, 0x1f, 0xb7, 0xda, 0x5f, 0xb4, 0x14, 0x64, 0xcc, 0xd0, 0x5e, 0xe3, 0x56, 0xef, 0x00, + 0x9a, 0x75, 0x0a, 0x9f, 0x87, 0x45, 0x6e, 0xa1, 0xb7, 0xf0, 0x06, 0xac, 0xb5, 0xda, 0x6a, 0xb7, + 0x51, 0x97, 0x55, 0xf9, 0xde, 0x3d, 0xb9, 0xd6, 0xeb, 0x8a, 0x17, 0x88, 0x88, 0xdd, 0x9b, 0xde, + 0xd4, 0xbf, 0x49, 0xc1, 0xc6, 0x02, 0x4f, 0x70, 0x35, 0xb8, 0x3b, 0x88, 0xeb, 0xcc, 0xf7, 0xcf, + 0xe2, 0x7d, 0x85, 0x1d, 0xf9, 0x1d, 0xe2, 0xfa, 0xc1, 0x55, 0xe3, 0x0a, 0xb0, 0x28, 0x59, 0xbe, + 0xd1, 0x37, 0xa8, 0x1b, 0x3c, 0xd8, 0x88, 0x0b, 0xc5, 0xda, 0x04, 0x17, 0x6f, 0x36, 0xdf, 0x03, + 0xec, 0xd8, 0x9e, 0xe1, 0x1b, 0x4f, 0xa8, 0x6a, 0x58, 0xe1, 0xeb, 0x0e, 0xbb, 0x60, 0xa4, 0x15, + 0x14, 0x4a, 0x1a, 0x96, 0x1f, 0xb1, 0x2d, 0x3a, 0x20, 0x33, 0x6c, 0x56, 0xc0, 0x53, 0x0a, 0x0a, + 0x25, 0x11, 0xfb, 0x12, 0x14, 0x75, 0x7b, 0xcc, 0xba, 0x2e, 0xc1, 0x63, 0xe7, 0x45, 0x42, 0x29, + 0x08, 0x2c, 0xa2, 0x04, 0xfd, 0xf4, 0xe4, 0x59, 0xa9, 0xa8, 0x14, 0x04, 0x26, 0x28, 0x97, 0x61, + 0x8d, 0x0c, 0x06, 0x2e, 0x33, 0x1e, 0x1a, 0x12, 0x37, 0x84, 0x52, 0x04, 0x73, 0xe2, 0xd6, 0x7d, + 0xc8, 0x85, 0x71, 
0x60, 0x47, 0x32, 0x8b, 0x84, 0xea, 0x88, 0x6b, 0x6f, 0x72, 0x27, 0xaf, 0xe4, + 0xac, 0x50, 0x78, 0x09, 0x8a, 0x86, 0xa7, 0x4e, 0x5e, 0xc9, 0x93, 0xdb, 0xc9, 0x9d, 0x9c, 0x52, + 0x30, 0xbc, 0xe8, 0x85, 0xb1, 0xfc, 0x75, 0x12, 0x4a, 0xd3, 0xaf, 0xfc, 0xb8, 0x0e, 0x39, 0xd3, + 0xd6, 0x08, 0x4f, 0x2d, 0xf1, 0x13, 0xd3, 0xce, 0x2b, 0x7e, 0x18, 0xa8, 0x34, 0x03, 0xbe, 0x12, + 0x69, 0x6e, 0xfd, 0x2d, 0x01, 0xb9, 0x10, 0xc6, 0xe7, 0x20, 0xed, 0x10, 0x7f, 0xc8, 0xcd, 0x65, + 0x0e, 0x92, 0x28, 0xa1, 0xf0, 0x31, 0xc3, 0x3d, 0x87, 0x58, 0x3c, 0x05, 0x02, 0x9c, 0x8d, 0xd9, + 0xba, 0x9a, 0x94, 0xe8, 0xfc, 0xfa, 0x61, 0x8f, 0x46, 0xd4, 0xf2, 0xbd, 0x70, 0x5d, 0x03, 0xbc, + 0x16, 0xc0, 0xf8, 0x1a, 0xac, 0xfb, 0x2e, 0x31, 0xcc, 0x29, 0x6e, 0x9a, 0x73, 0x51, 0x28, 0x88, + 0xc8, 0xfb, 0x70, 0x21, 0xb4, 0xab, 0x53, 0x9f, 0x68, 0x43, 0xaa, 0x4f, 0x94, 0xb2, 0xfc, 0x99, + 0xe1, 0x7c, 0x40, 0xa8, 0x07, 0xf2, 0x50, 0xb7, 0xfc, 0xf7, 0x04, 0xac, 0x87, 0x17, 0x26, 0x3d, + 0x0a, 0xd6, 0x11, 0x00, 0xb1, 0x2c, 0xdb, 0x8f, 0x87, 0x6b, 0x3e, 0x95, 0xe7, 0xf4, 0x2a, 0xd5, + 0x48, 0x49, 0x89, 0x19, 0xd8, 0x1a, 0x01, 0x4c, 0x24, 0x4b, 0xc3, 0x76, 0x11, 0x0a, 0xc1, 0x4f, + 0x38, 0xfc, 0x77, 0x40, 0x71, 0xc5, 0x06, 0x01, 0xb1, 0x9b, 0x15, 0xde, 0x84, 0xcc, 0x09, 0x1d, + 0x18, 0x56, 0xf0, 0x30, 0x2b, 0x06, 0xe1, 0x43, 0x48, 0x3a, 0x7a, 0x08, 0x39, 0xf8, 0x19, 0x6c, + 0x68, 0xf6, 0x68, 0xd6, 0xdd, 0x03, 0x34, 0x73, 0xcd, 0xf7, 0x3e, 0x4f, 0x7c, 0x09, 0x93, 0x16, + 0xf3, 0x7f, 0x89, 0xc4, 0xef, 0x93, 0xa9, 0xc3, 0xce, 0xc1, 0x1f, 0x92, 0x5b, 0x87, 0x42, 0xb5, + 0x13, 0xce, 0x54, 0xa1, 0x7d, 0x93, 0x6a, 0xcc, 0xfb, 0xff, 0x07, 0x00, 0x00, 0xff, 0xff, 0xa3, + 0x58, 0x22, 0x30, 0xdf, 0x1c, 0x00, 0x00, +} diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor_gostring.gen.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor_gostring.gen.go new file mode 100644 index 00000000000..3b95a775754 --- /dev/null +++ 
b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/descriptor_gostring.gen.go @@ -0,0 +1,772 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: descriptor.proto + +/* +Package descriptor is a generated protocol buffer package. + +It is generated from these files: + descriptor.proto + +It has these top-level messages: + FileDescriptorSet + FileDescriptorProto + DescriptorProto + ExtensionRangeOptions + FieldDescriptorProto + OneofDescriptorProto + EnumDescriptorProto + EnumValueDescriptorProto + ServiceDescriptorProto + MethodDescriptorProto + FileOptions + MessageOptions + FieldOptions + OneofOptions + EnumOptions + EnumValueOptions + ServiceOptions + MethodOptions + UninterpretedOption + SourceCodeInfo + GeneratedCodeInfo +*/ +package descriptor + +import fmt "fmt" +import strings "strings" +import proto "github.com/gogo/protobuf/proto" +import sort "sort" +import strconv "strconv" +import reflect "reflect" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. 
+var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +func (this *FileDescriptorSet) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&descriptor.FileDescriptorSet{") + if this.File != nil { + s = append(s, "File: "+fmt.Sprintf("%#v", this.File)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *FileDescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 16) + s = append(s, "&descriptor.FileDescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.Package != nil { + s = append(s, "Package: "+valueToGoStringDescriptor(this.Package, "string")+",\n") + } + if this.Dependency != nil { + s = append(s, "Dependency: "+fmt.Sprintf("%#v", this.Dependency)+",\n") + } + if this.PublicDependency != nil { + s = append(s, "PublicDependency: "+fmt.Sprintf("%#v", this.PublicDependency)+",\n") + } + if this.WeakDependency != nil { + s = append(s, "WeakDependency: "+fmt.Sprintf("%#v", this.WeakDependency)+",\n") + } + if this.MessageType != nil { + s = append(s, "MessageType: "+fmt.Sprintf("%#v", this.MessageType)+",\n") + } + if this.EnumType != nil { + s = append(s, "EnumType: "+fmt.Sprintf("%#v", this.EnumType)+",\n") + } + if this.Service != nil { + s = append(s, "Service: "+fmt.Sprintf("%#v", this.Service)+",\n") + } + if this.Extension != nil { + s = append(s, "Extension: "+fmt.Sprintf("%#v", this.Extension)+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.SourceCodeInfo != nil { + s = append(s, "SourceCodeInfo: "+fmt.Sprintf("%#v", this.SourceCodeInfo)+",\n") + } + if this.Syntax != nil { + s = append(s, "Syntax: "+valueToGoStringDescriptor(this.Syntax, "string")+",\n") + } + 
if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *DescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 14) + s = append(s, "&descriptor.DescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.Field != nil { + s = append(s, "Field: "+fmt.Sprintf("%#v", this.Field)+",\n") + } + if this.Extension != nil { + s = append(s, "Extension: "+fmt.Sprintf("%#v", this.Extension)+",\n") + } + if this.NestedType != nil { + s = append(s, "NestedType: "+fmt.Sprintf("%#v", this.NestedType)+",\n") + } + if this.EnumType != nil { + s = append(s, "EnumType: "+fmt.Sprintf("%#v", this.EnumType)+",\n") + } + if this.ExtensionRange != nil { + s = append(s, "ExtensionRange: "+fmt.Sprintf("%#v", this.ExtensionRange)+",\n") + } + if this.OneofDecl != nil { + s = append(s, "OneofDecl: "+fmt.Sprintf("%#v", this.OneofDecl)+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.ReservedRange != nil { + s = append(s, "ReservedRange: "+fmt.Sprintf("%#v", this.ReservedRange)+",\n") + } + if this.ReservedName != nil { + s = append(s, "ReservedName: "+fmt.Sprintf("%#v", this.ReservedName)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *DescriptorProto_ExtensionRange) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&descriptor.DescriptorProto_ExtensionRange{") + if this.Start != nil { + s = append(s, "Start: "+valueToGoStringDescriptor(this.Start, "int32")+",\n") + } + if this.End != nil { + s = append(s, "End: "+valueToGoStringDescriptor(this.End, "int32")+",\n") + } + if 
this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *DescriptorProto_ReservedRange) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&descriptor.DescriptorProto_ReservedRange{") + if this.Start != nil { + s = append(s, "Start: "+valueToGoStringDescriptor(this.Start, "int32")+",\n") + } + if this.End != nil { + s = append(s, "End: "+valueToGoStringDescriptor(this.End, "int32")+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *ExtensionRangeOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&descriptor.ExtensionRangeOptions{") + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *FieldDescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 14) + s = append(s, "&descriptor.FieldDescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.Number != nil { + s = append(s, "Number: "+valueToGoStringDescriptor(this.Number, "int32")+",\n") + } + if this.Label != nil { + s = append(s, "Label: "+valueToGoStringDescriptor(this.Label, "FieldDescriptorProto_Label")+",\n") + } + if this.Type != nil { + s = append(s, "Type: 
"+valueToGoStringDescriptor(this.Type, "FieldDescriptorProto_Type")+",\n") + } + if this.TypeName != nil { + s = append(s, "TypeName: "+valueToGoStringDescriptor(this.TypeName, "string")+",\n") + } + if this.Extendee != nil { + s = append(s, "Extendee: "+valueToGoStringDescriptor(this.Extendee, "string")+",\n") + } + if this.DefaultValue != nil { + s = append(s, "DefaultValue: "+valueToGoStringDescriptor(this.DefaultValue, "string")+",\n") + } + if this.OneofIndex != nil { + s = append(s, "OneofIndex: "+valueToGoStringDescriptor(this.OneofIndex, "int32")+",\n") + } + if this.JsonName != nil { + s = append(s, "JsonName: "+valueToGoStringDescriptor(this.JsonName, "string")+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *OneofDescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&descriptor.OneofDescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *EnumDescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 9) + s = append(s, "&descriptor.EnumDescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.Value != nil { + s = append(s, "Value: "+fmt.Sprintf("%#v", this.Value)+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + 
if this.ReservedRange != nil { + s = append(s, "ReservedRange: "+fmt.Sprintf("%#v", this.ReservedRange)+",\n") + } + if this.ReservedName != nil { + s = append(s, "ReservedName: "+fmt.Sprintf("%#v", this.ReservedName)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *EnumDescriptorProto_EnumReservedRange) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&descriptor.EnumDescriptorProto_EnumReservedRange{") + if this.Start != nil { + s = append(s, "Start: "+valueToGoStringDescriptor(this.Start, "int32")+",\n") + } + if this.End != nil { + s = append(s, "End: "+valueToGoStringDescriptor(this.End, "int32")+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *EnumValueDescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&descriptor.EnumValueDescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.Number != nil { + s = append(s, "Number: "+valueToGoStringDescriptor(this.Number, "int32")+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *ServiceDescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&descriptor.ServiceDescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.Method != nil { + s = 
append(s, "Method: "+fmt.Sprintf("%#v", this.Method)+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *MethodDescriptorProto) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 10) + s = append(s, "&descriptor.MethodDescriptorProto{") + if this.Name != nil { + s = append(s, "Name: "+valueToGoStringDescriptor(this.Name, "string")+",\n") + } + if this.InputType != nil { + s = append(s, "InputType: "+valueToGoStringDescriptor(this.InputType, "string")+",\n") + } + if this.OutputType != nil { + s = append(s, "OutputType: "+valueToGoStringDescriptor(this.OutputType, "string")+",\n") + } + if this.Options != nil { + s = append(s, "Options: "+fmt.Sprintf("%#v", this.Options)+",\n") + } + if this.ClientStreaming != nil { + s = append(s, "ClientStreaming: "+valueToGoStringDescriptor(this.ClientStreaming, "bool")+",\n") + } + if this.ServerStreaming != nil { + s = append(s, "ServerStreaming: "+valueToGoStringDescriptor(this.ServerStreaming, "bool")+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *FileOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 23) + s = append(s, "&descriptor.FileOptions{") + if this.JavaPackage != nil { + s = append(s, "JavaPackage: "+valueToGoStringDescriptor(this.JavaPackage, "string")+",\n") + } + if this.JavaOuterClassname != nil { + s = append(s, "JavaOuterClassname: "+valueToGoStringDescriptor(this.JavaOuterClassname, "string")+",\n") + } + if this.JavaMultipleFiles != nil { + s = append(s, "JavaMultipleFiles: "+valueToGoStringDescriptor(this.JavaMultipleFiles, "bool")+",\n") 
+ } + if this.JavaGenerateEqualsAndHash != nil { + s = append(s, "JavaGenerateEqualsAndHash: "+valueToGoStringDescriptor(this.JavaGenerateEqualsAndHash, "bool")+",\n") + } + if this.JavaStringCheckUtf8 != nil { + s = append(s, "JavaStringCheckUtf8: "+valueToGoStringDescriptor(this.JavaStringCheckUtf8, "bool")+",\n") + } + if this.OptimizeFor != nil { + s = append(s, "OptimizeFor: "+valueToGoStringDescriptor(this.OptimizeFor, "FileOptions_OptimizeMode")+",\n") + } + if this.GoPackage != nil { + s = append(s, "GoPackage: "+valueToGoStringDescriptor(this.GoPackage, "string")+",\n") + } + if this.CcGenericServices != nil { + s = append(s, "CcGenericServices: "+valueToGoStringDescriptor(this.CcGenericServices, "bool")+",\n") + } + if this.JavaGenericServices != nil { + s = append(s, "JavaGenericServices: "+valueToGoStringDescriptor(this.JavaGenericServices, "bool")+",\n") + } + if this.PyGenericServices != nil { + s = append(s, "PyGenericServices: "+valueToGoStringDescriptor(this.PyGenericServices, "bool")+",\n") + } + if this.PhpGenericServices != nil { + s = append(s, "PhpGenericServices: "+valueToGoStringDescriptor(this.PhpGenericServices, "bool")+",\n") + } + if this.Deprecated != nil { + s = append(s, "Deprecated: "+valueToGoStringDescriptor(this.Deprecated, "bool")+",\n") + } + if this.CcEnableArenas != nil { + s = append(s, "CcEnableArenas: "+valueToGoStringDescriptor(this.CcEnableArenas, "bool")+",\n") + } + if this.ObjcClassPrefix != nil { + s = append(s, "ObjcClassPrefix: "+valueToGoStringDescriptor(this.ObjcClassPrefix, "string")+",\n") + } + if this.CsharpNamespace != nil { + s = append(s, "CsharpNamespace: "+valueToGoStringDescriptor(this.CsharpNamespace, "string")+",\n") + } + if this.SwiftPrefix != nil { + s = append(s, "SwiftPrefix: "+valueToGoStringDescriptor(this.SwiftPrefix, "string")+",\n") + } + if this.PhpClassPrefix != nil { + s = append(s, "PhpClassPrefix: "+valueToGoStringDescriptor(this.PhpClassPrefix, "string")+",\n") + } + if 
this.PhpNamespace != nil { + s = append(s, "PhpNamespace: "+valueToGoStringDescriptor(this.PhpNamespace, "string")+",\n") + } + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *MessageOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 9) + s = append(s, "&descriptor.MessageOptions{") + if this.MessageSetWireFormat != nil { + s = append(s, "MessageSetWireFormat: "+valueToGoStringDescriptor(this.MessageSetWireFormat, "bool")+",\n") + } + if this.NoStandardDescriptorAccessor != nil { + s = append(s, "NoStandardDescriptorAccessor: "+valueToGoStringDescriptor(this.NoStandardDescriptorAccessor, "bool")+",\n") + } + if this.Deprecated != nil { + s = append(s, "Deprecated: "+valueToGoStringDescriptor(this.Deprecated, "bool")+",\n") + } + if this.MapEntry != nil { + s = append(s, "MapEntry: "+valueToGoStringDescriptor(this.MapEntry, "bool")+",\n") + } + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *FieldOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 11) + s = append(s, "&descriptor.FieldOptions{") + if this.Ctype != nil { + s = append(s, "Ctype: "+valueToGoStringDescriptor(this.Ctype, "FieldOptions_CType")+",\n") + } + if this.Packed != nil { + s = append(s, "Packed: 
"+valueToGoStringDescriptor(this.Packed, "bool")+",\n") + } + if this.Jstype != nil { + s = append(s, "Jstype: "+valueToGoStringDescriptor(this.Jstype, "FieldOptions_JSType")+",\n") + } + if this.Lazy != nil { + s = append(s, "Lazy: "+valueToGoStringDescriptor(this.Lazy, "bool")+",\n") + } + if this.Deprecated != nil { + s = append(s, "Deprecated: "+valueToGoStringDescriptor(this.Deprecated, "bool")+",\n") + } + if this.Weak != nil { + s = append(s, "Weak: "+valueToGoStringDescriptor(this.Weak, "bool")+",\n") + } + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *OneofOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&descriptor.OneofOptions{") + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *EnumOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&descriptor.EnumOptions{") + if this.AllowAlias != nil { + s = append(s, "AllowAlias: "+valueToGoStringDescriptor(this.AllowAlias, "bool")+",\n") + } + if this.Deprecated != nil { + s = append(s, "Deprecated: "+valueToGoStringDescriptor(this.Deprecated, "bool")+",\n") + } + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", 
this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *EnumValueOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&descriptor.EnumValueOptions{") + if this.Deprecated != nil { + s = append(s, "Deprecated: "+valueToGoStringDescriptor(this.Deprecated, "bool")+",\n") + } + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *ServiceOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&descriptor.ServiceOptions{") + if this.Deprecated != nil { + s = append(s, "Deprecated: "+valueToGoStringDescriptor(this.Deprecated, "bool")+",\n") + } + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *MethodOptions) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&descriptor.MethodOptions{") + if this.Deprecated != nil { + s = append(s, "Deprecated: "+valueToGoStringDescriptor(this.Deprecated, "bool")+",\n") + } + if this.IdempotencyLevel != nil { + s = 
append(s, "IdempotencyLevel: "+valueToGoStringDescriptor(this.IdempotencyLevel, "MethodOptions_IdempotencyLevel")+",\n") + } + if this.UninterpretedOption != nil { + s = append(s, "UninterpretedOption: "+fmt.Sprintf("%#v", this.UninterpretedOption)+",\n") + } + s = append(s, "XXX_InternalExtensions: "+extensionToGoStringDescriptor(this)+",\n") + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *UninterpretedOption) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 11) + s = append(s, "&descriptor.UninterpretedOption{") + if this.Name != nil { + s = append(s, "Name: "+fmt.Sprintf("%#v", this.Name)+",\n") + } + if this.IdentifierValue != nil { + s = append(s, "IdentifierValue: "+valueToGoStringDescriptor(this.IdentifierValue, "string")+",\n") + } + if this.PositiveIntValue != nil { + s = append(s, "PositiveIntValue: "+valueToGoStringDescriptor(this.PositiveIntValue, "uint64")+",\n") + } + if this.NegativeIntValue != nil { + s = append(s, "NegativeIntValue: "+valueToGoStringDescriptor(this.NegativeIntValue, "int64")+",\n") + } + if this.DoubleValue != nil { + s = append(s, "DoubleValue: "+valueToGoStringDescriptor(this.DoubleValue, "float64")+",\n") + } + if this.StringValue != nil { + s = append(s, "StringValue: "+valueToGoStringDescriptor(this.StringValue, "byte")+",\n") + } + if this.AggregateValue != nil { + s = append(s, "AggregateValue: "+valueToGoStringDescriptor(this.AggregateValue, "string")+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *UninterpretedOption_NamePart) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&descriptor.UninterpretedOption_NamePart{") + if this.NamePart != nil { 
+ s = append(s, "NamePart: "+valueToGoStringDescriptor(this.NamePart, "string")+",\n") + } + if this.IsExtension != nil { + s = append(s, "IsExtension: "+valueToGoStringDescriptor(this.IsExtension, "bool")+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *SourceCodeInfo) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&descriptor.SourceCodeInfo{") + if this.Location != nil { + s = append(s, "Location: "+fmt.Sprintf("%#v", this.Location)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *SourceCodeInfo_Location) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 9) + s = append(s, "&descriptor.SourceCodeInfo_Location{") + if this.Path != nil { + s = append(s, "Path: "+fmt.Sprintf("%#v", this.Path)+",\n") + } + if this.Span != nil { + s = append(s, "Span: "+fmt.Sprintf("%#v", this.Span)+",\n") + } + if this.LeadingComments != nil { + s = append(s, "LeadingComments: "+valueToGoStringDescriptor(this.LeadingComments, "string")+",\n") + } + if this.TrailingComments != nil { + s = append(s, "TrailingComments: "+valueToGoStringDescriptor(this.TrailingComments, "string")+",\n") + } + if this.LeadingDetachedComments != nil { + s = append(s, "LeadingDetachedComments: "+fmt.Sprintf("%#v", this.LeadingDetachedComments)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *GeneratedCodeInfo) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&descriptor.GeneratedCodeInfo{") + if this.Annotation != 
nil { + s = append(s, "Annotation: "+fmt.Sprintf("%#v", this.Annotation)+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *GeneratedCodeInfo_Annotation) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 8) + s = append(s, "&descriptor.GeneratedCodeInfo_Annotation{") + if this.Path != nil { + s = append(s, "Path: "+fmt.Sprintf("%#v", this.Path)+",\n") + } + if this.SourceFile != nil { + s = append(s, "SourceFile: "+valueToGoStringDescriptor(this.SourceFile, "string")+",\n") + } + if this.Begin != nil { + s = append(s, "Begin: "+valueToGoStringDescriptor(this.Begin, "int32")+",\n") + } + if this.End != nil { + s = append(s, "End: "+valueToGoStringDescriptor(this.End, "int32")+",\n") + } + if this.XXX_unrecognized != nil { + s = append(s, "XXX_unrecognized:"+fmt.Sprintf("%#v", this.XXX_unrecognized)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func valueToGoStringDescriptor(v interface{}, typ string) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) +} +func extensionToGoStringDescriptor(m proto.Message) string { + e := proto.GetUnsafeExtensionsMap(m) + if e == nil { + return "nil" + } + s := "proto.NewUnsafeXXX_InternalExtensions(map[int32]proto.Extension{" + keys := make([]int, 0, len(e)) + for k := range e { + keys = append(keys, int(k)) + } + sort.Ints(keys) + ss := []string{} + for _, k := range keys { + ss = append(ss, strconv.Itoa(k)+": "+e[int32(k)].GoString()) + } + s += strings.Join(ss, ",") + "})" + return s +} diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/helper.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/helper.go new file mode 100644 index 00000000000..e0846a357d5 --- 
/dev/null +++ b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/descriptor/helper.go @@ -0,0 +1,390 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package descriptor + +import ( + "strings" +) + +func (msg *DescriptorProto) GetMapFields() (*FieldDescriptorProto, *FieldDescriptorProto) { + if !msg.GetOptions().GetMapEntry() { + return nil, nil + } + return msg.GetField()[0], msg.GetField()[1] +} + +func dotToUnderscore(r rune) rune { + if r == '.' 
{ + return '_' + } + return r +} + +func (field *FieldDescriptorProto) WireType() (wire int) { + switch *field.Type { + case FieldDescriptorProto_TYPE_DOUBLE: + return 1 + case FieldDescriptorProto_TYPE_FLOAT: + return 5 + case FieldDescriptorProto_TYPE_INT64: + return 0 + case FieldDescriptorProto_TYPE_UINT64: + return 0 + case FieldDescriptorProto_TYPE_INT32: + return 0 + case FieldDescriptorProto_TYPE_UINT32: + return 0 + case FieldDescriptorProto_TYPE_FIXED64: + return 1 + case FieldDescriptorProto_TYPE_FIXED32: + return 5 + case FieldDescriptorProto_TYPE_BOOL: + return 0 + case FieldDescriptorProto_TYPE_STRING: + return 2 + case FieldDescriptorProto_TYPE_GROUP: + return 2 + case FieldDescriptorProto_TYPE_MESSAGE: + return 2 + case FieldDescriptorProto_TYPE_BYTES: + return 2 + case FieldDescriptorProto_TYPE_ENUM: + return 0 + case FieldDescriptorProto_TYPE_SFIXED32: + return 5 + case FieldDescriptorProto_TYPE_SFIXED64: + return 1 + case FieldDescriptorProto_TYPE_SINT32: + return 0 + case FieldDescriptorProto_TYPE_SINT64: + return 0 + } + panic("unreachable") +} + +func (field *FieldDescriptorProto) GetKeyUint64() (x uint64) { + packed := field.IsPacked() + wireType := field.WireType() + fieldNumber := field.GetNumber() + if packed { + wireType = 2 + } + x = uint64(uint32(fieldNumber)<<3 | uint32(wireType)) + return x +} + +func (field *FieldDescriptorProto) GetKey3Uint64() (x uint64) { + packed := field.IsPacked3() + wireType := field.WireType() + fieldNumber := field.GetNumber() + if packed { + wireType = 2 + } + x = uint64(uint32(fieldNumber)<<3 | uint32(wireType)) + return x +} + +func (field *FieldDescriptorProto) GetKey() []byte { + x := field.GetKeyUint64() + i := 0 + keybuf := make([]byte, 0) + for i = 0; x > 127; i++ { + keybuf = append(keybuf, 0x80|uint8(x&0x7F)) + x >>= 7 + } + keybuf = append(keybuf, uint8(x)) + return keybuf +} + +func (field *FieldDescriptorProto) GetKey3() []byte { + x := field.GetKey3Uint64() + i := 0 + keybuf := make([]byte, 0) 
+ for i = 0; x > 127; i++ { + keybuf = append(keybuf, 0x80|uint8(x&0x7F)) + x >>= 7 + } + keybuf = append(keybuf, uint8(x)) + return keybuf +} + +func (desc *FileDescriptorSet) GetField(packageName, messageName, fieldName string) *FieldDescriptorProto { + msg := desc.GetMessage(packageName, messageName) + if msg == nil { + return nil + } + for _, field := range msg.GetField() { + if field.GetName() == fieldName { + return field + } + } + return nil +} + +func (file *FileDescriptorProto) GetMessage(typeName string) *DescriptorProto { + for _, msg := range file.GetMessageType() { + if msg.GetName() == typeName { + return msg + } + nes := file.GetNestedMessage(msg, strings.TrimPrefix(typeName, msg.GetName()+".")) + if nes != nil { + return nes + } + } + return nil +} + +func (file *FileDescriptorProto) GetNestedMessage(msg *DescriptorProto, typeName string) *DescriptorProto { + for _, nes := range msg.GetNestedType() { + if nes.GetName() == typeName { + return nes + } + res := file.GetNestedMessage(nes, strings.TrimPrefix(typeName, nes.GetName()+".")) + if res != nil { + return res + } + } + return nil +} + +func (desc *FileDescriptorSet) GetMessage(packageName string, typeName string) *DescriptorProto { + for _, file := range desc.GetFile() { + if strings.Map(dotToUnderscore, file.GetPackage()) != strings.Map(dotToUnderscore, packageName) { + continue + } + for _, msg := range file.GetMessageType() { + if msg.GetName() == typeName { + return msg + } + } + for _, msg := range file.GetMessageType() { + for _, nes := range msg.GetNestedType() { + if nes.GetName() == typeName { + return nes + } + if msg.GetName()+"."+nes.GetName() == typeName { + return nes + } + } + } + } + return nil +} + +func (desc *FileDescriptorSet) IsProto3(packageName string, typeName string) bool { + for _, file := range desc.GetFile() { + if strings.Map(dotToUnderscore, file.GetPackage()) != strings.Map(dotToUnderscore, packageName) { + continue + } + for _, msg := range file.GetMessageType() { 
+ if msg.GetName() == typeName { + return file.GetSyntax() == "proto3" + } + } + for _, msg := range file.GetMessageType() { + for _, nes := range msg.GetNestedType() { + if nes.GetName() == typeName { + return file.GetSyntax() == "proto3" + } + if msg.GetName()+"."+nes.GetName() == typeName { + return file.GetSyntax() == "proto3" + } + } + } + } + return false +} + +func (msg *DescriptorProto) IsExtendable() bool { + return len(msg.GetExtensionRange()) > 0 +} + +func (desc *FileDescriptorSet) FindExtension(packageName string, typeName string, fieldName string) (extPackageName string, field *FieldDescriptorProto) { + parent := desc.GetMessage(packageName, typeName) + if parent == nil { + return "", nil + } + if !parent.IsExtendable() { + return "", nil + } + extendee := "." + packageName + "." + typeName + for _, file := range desc.GetFile() { + for _, ext := range file.GetExtension() { + if strings.Map(dotToUnderscore, file.GetPackage()) == strings.Map(dotToUnderscore, packageName) { + if !(ext.GetExtendee() == typeName || ext.GetExtendee() == extendee) { + continue + } + } else { + if ext.GetExtendee() != extendee { + continue + } + } + if ext.GetName() == fieldName { + return file.GetPackage(), ext + } + } + } + return "", nil +} + +func (desc *FileDescriptorSet) FindExtensionByFieldNumber(packageName string, typeName string, fieldNum int32) (extPackageName string, field *FieldDescriptorProto) { + parent := desc.GetMessage(packageName, typeName) + if parent == nil { + return "", nil + } + if !parent.IsExtendable() { + return "", nil + } + extendee := "." + packageName + "." 
+ typeName + for _, file := range desc.GetFile() { + for _, ext := range file.GetExtension() { + if strings.Map(dotToUnderscore, file.GetPackage()) == strings.Map(dotToUnderscore, packageName) { + if !(ext.GetExtendee() == typeName || ext.GetExtendee() == extendee) { + continue + } + } else { + if ext.GetExtendee() != extendee { + continue + } + } + if ext.GetNumber() == fieldNum { + return file.GetPackage(), ext + } + } + } + return "", nil +} + +func (desc *FileDescriptorSet) FindMessage(packageName string, typeName string, fieldName string) (msgPackageName string, msgName string) { + parent := desc.GetMessage(packageName, typeName) + if parent == nil { + return "", "" + } + field := parent.GetFieldDescriptor(fieldName) + if field == nil { + var extPackageName string + extPackageName, field = desc.FindExtension(packageName, typeName, fieldName) + if field == nil { + return "", "" + } + packageName = extPackageName + } + typeNames := strings.Split(field.GetTypeName(), ".") + if len(typeNames) == 1 { + msg := desc.GetMessage(packageName, typeName) + if msg == nil { + return "", "" + } + return packageName, msg.GetName() + } + if len(typeNames) > 2 { + for i := 1; i < len(typeNames)-1; i++ { + packageName = strings.Join(typeNames[1:len(typeNames)-i], ".") + typeName = strings.Join(typeNames[len(typeNames)-i:], ".") + msg := desc.GetMessage(packageName, typeName) + if msg != nil { + typeNames := strings.Split(msg.GetName(), ".") + if len(typeNames) == 1 { + return packageName, msg.GetName() + } + return strings.Join(typeNames[1:len(typeNames)-1], "."), typeNames[len(typeNames)-1] + } + } + } + return "", "" +} + +func (msg *DescriptorProto) GetFieldDescriptor(fieldName string) *FieldDescriptorProto { + for _, field := range msg.GetField() { + if field.GetName() == fieldName { + return field + } + } + return nil +} + +func (desc *FileDescriptorSet) GetEnum(packageName string, typeName string) *EnumDescriptorProto { + for _, file := range desc.GetFile() { + if 
strings.Map(dotToUnderscore, file.GetPackage()) != strings.Map(dotToUnderscore, packageName) { + continue + } + for _, enum := range file.GetEnumType() { + if enum.GetName() == typeName { + return enum + } + } + } + return nil +} + +func (f *FieldDescriptorProto) IsEnum() bool { + return *f.Type == FieldDescriptorProto_TYPE_ENUM +} + +func (f *FieldDescriptorProto) IsMessage() bool { + return *f.Type == FieldDescriptorProto_TYPE_MESSAGE +} + +func (f *FieldDescriptorProto) IsBytes() bool { + return *f.Type == FieldDescriptorProto_TYPE_BYTES +} + +func (f *FieldDescriptorProto) IsRepeated() bool { + return f.Label != nil && *f.Label == FieldDescriptorProto_LABEL_REPEATED +} + +func (f *FieldDescriptorProto) IsString() bool { + return *f.Type == FieldDescriptorProto_TYPE_STRING +} + +func (f *FieldDescriptorProto) IsBool() bool { + return *f.Type == FieldDescriptorProto_TYPE_BOOL +} + +func (f *FieldDescriptorProto) IsRequired() bool { + return f.Label != nil && *f.Label == FieldDescriptorProto_LABEL_REQUIRED +} + +func (f *FieldDescriptorProto) IsPacked() bool { + return f.Options != nil && f.GetOptions().GetPacked() +} + +func (f *FieldDescriptorProto) IsPacked3() bool { + if f.IsRepeated() && f.IsScalar() { + if f.Options == nil || f.GetOptions().Packed == nil { + return true + } + return f.Options != nil && f.GetOptions().GetPacked() + } + return false +} + +func (m *DescriptorProto) HasExtension() bool { + return len(m.ExtensionRange) > 0 +} diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/generator/generator.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/generator/generator.go new file mode 100644 index 00000000000..519e22d093d --- /dev/null +++ b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/generator/generator.go @@ -0,0 +1,3431 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. 
+// http://github.com/gogo/protobuf +// +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* + The code generator for the plugin for the Google protocol buffer compiler. + It generates Go code from the protocol buffer description files read by the + main routine. 
+*/ +package generator + +import ( + "bufio" + "bytes" + "compress/gzip" + "fmt" + "go/parser" + "go/printer" + "go/token" + "log" + "os" + "path" + "sort" + "strconv" + "strings" + "unicode" + "unicode/utf8" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + plugin "github.com/gogo/protobuf/protoc-gen-gogo/plugin" +) + +// generatedCodeVersion indicates a version of the generated code. +// It is incremented whenever an incompatibility between the generated code and +// proto package is introduced; the generated code references +// a constant, proto.ProtoPackageIsVersionN (where N is generatedCodeVersion). +const generatedCodeVersion = 2 + +// A Plugin provides functionality to add to the output during Go code generation, +// such as to produce RPC stubs. +type Plugin interface { + // Name identifies the plugin. + Name() string + // Init is called once after data structures are built but before + // code generation begins. + Init(g *Generator) + // Generate produces the code generated by the plugin for this file, + // except for the imports, by calling the generator's methods P, In, and Out. + Generate(file *FileDescriptor) + // GenerateImports produces the import declarations for this file. + // It is called after Generate. + GenerateImports(file *FileDescriptor) +} + +type pluginSlice []Plugin + +func (ps pluginSlice) Len() int { + return len(ps) +} + +func (ps pluginSlice) Less(i, j int) bool { + return ps[i].Name() < ps[j].Name() +} + +func (ps pluginSlice) Swap(i, j int) { + ps[i], ps[j] = ps[j], ps[i] +} + +var plugins pluginSlice + +// RegisterPlugin installs a (second-order) plugin to be run when the Go output is generated. +// It is typically called during initialization. 
+func RegisterPlugin(p Plugin) { + plugins = append(plugins, p) +} + +// Each type we import as a protocol buffer (other than FileDescriptorProto) needs +// a pointer to the FileDescriptorProto that represents it. These types achieve that +// wrapping by placing each Proto inside a struct with the pointer to its File. The +// structs have the same names as their contents, with "Proto" removed. +// FileDescriptor is used to store the things that it points to. + +// The file and package name method are common to messages and enums. +type common struct { + file *descriptor.FileDescriptorProto // File this object comes from. +} + +// PackageName is name in the package clause in the generated file. +func (c *common) PackageName() string { return uniquePackageOf(c.file) } + +func (c *common) File() *descriptor.FileDescriptorProto { return c.file } + +func fileIsProto3(file *descriptor.FileDescriptorProto) bool { + return file.GetSyntax() == "proto3" +} + +func (c *common) proto3() bool { return fileIsProto3(c.file) } + +// Descriptor represents a protocol buffer message. +type Descriptor struct { + common + *descriptor.DescriptorProto + parent *Descriptor // The containing message, if any. + nested []*Descriptor // Inner messages, if any. + enums []*EnumDescriptor // Inner enums, if any. + ext []*ExtensionDescriptor // Extensions, if any. + typename []string // Cached typename vector. + index int // The index into the container, whether the file or another message. + path string // The SourceCodeInfo path as comma-separated integers. + group bool +} + +// TypeName returns the elements of the dotted type name. +// The package name is not part of this name. 
+func (d *Descriptor) TypeName() []string { + if d.typename != nil { + return d.typename + } + n := 0 + for parent := d; parent != nil; parent = parent.parent { + n++ + } + s := make([]string, n, n) + for parent := d; parent != nil; parent = parent.parent { + n-- + s[n] = parent.GetName() + } + d.typename = s + return s +} + +func (d *Descriptor) allowOneof() bool { + return true +} + +// EnumDescriptor describes an enum. If it's at top level, its parent will be nil. +// Otherwise it will be the descriptor of the message in which it is defined. +type EnumDescriptor struct { + common + *descriptor.EnumDescriptorProto + parent *Descriptor // The containing message, if any. + typename []string // Cached typename vector. + index int // The index into the container, whether the file or a message. + path string // The SourceCodeInfo path as comma-separated integers. +} + +// TypeName returns the elements of the dotted type name. +// The package name is not part of this name. +func (e *EnumDescriptor) TypeName() (s []string) { + if e.typename != nil { + return e.typename + } + name := e.GetName() + if e.parent == nil { + s = make([]string, 1) + } else { + pname := e.parent.TypeName() + s = make([]string, len(pname)+1) + copy(s, pname) + } + s[len(s)-1] = name + e.typename = s + return s +} + +// alias provides the TypeName corrected for the application of any naming +// extensions on the enum type. It should be used for generating references to +// the Go types and for calculating prefixes. +func (e *EnumDescriptor) alias() (s []string) { + s = e.TypeName() + if gogoproto.IsEnumCustomName(e.EnumDescriptorProto) { + s[len(s)-1] = gogoproto.GetEnumCustomName(e.EnumDescriptorProto) + } + + return +} + +// Everything but the last element of the full type name, CamelCased. +// The values of type Foo.Bar are call Foo_value1... not Foo_Bar_value1... . 
+func (e *EnumDescriptor) prefix() string { + typeName := e.alias() + if e.parent == nil { + // If the enum is not part of a message, the prefix is just the type name. + return CamelCase(typeName[len(typeName)-1]) + "_" + } + return CamelCaseSlice(typeName[0:len(typeName)-1]) + "_" +} + +// The integer value of the named constant in this enumerated type. +func (e *EnumDescriptor) integerValueAsString(name string) string { + for _, c := range e.Value { + if c.GetName() == name { + return fmt.Sprint(c.GetNumber()) + } + } + log.Fatal("cannot find value for enum constant") + return "" +} + +// ExtensionDescriptor describes an extension. If it's at top level, its parent will be nil. +// Otherwise it will be the descriptor of the message in which it is defined. +type ExtensionDescriptor struct { + common + *descriptor.FieldDescriptorProto + parent *Descriptor // The containing message, if any. +} + +// TypeName returns the elements of the dotted type name. +// The package name is not part of this name. +func (e *ExtensionDescriptor) TypeName() (s []string) { + name := e.GetName() + if e.parent == nil { + // top-level extension + s = make([]string, 1) + } else { + pname := e.parent.TypeName() + s = make([]string, len(pname)+1) + copy(s, pname) + } + s[len(s)-1] = name + return s +} + +// DescName returns the variable name used for the generated descriptor. +func (e *ExtensionDescriptor) DescName() string { + // The full type name. + typeName := e.TypeName() + // Each scope of the extension is individually CamelCased, and all are joined with "_" with an "E_" prefix. + for i, s := range typeName { + typeName[i] = CamelCase(s) + } + return "E_" + strings.Join(typeName, "_") +} + +// ImportedDescriptor describes a type that has been publicly imported from another file. 
+type ImportedDescriptor struct { + common + o Object +} + +func (id *ImportedDescriptor) TypeName() []string { return id.o.TypeName() } + +// FileDescriptor describes an protocol buffer descriptor file (.proto). +// It includes slices of all the messages and enums defined within it. +// Those slices are constructed by WrapTypes. +type FileDescriptor struct { + *descriptor.FileDescriptorProto + desc []*Descriptor // All the messages defined in this file. + enum []*EnumDescriptor // All the enums defined in this file. + ext []*ExtensionDescriptor // All the top-level extensions defined in this file. + imp []*ImportedDescriptor // All types defined in files publicly imported by this file. + + // Comments, stored as a map of path (comma-separated integers) to the comment. + comments map[string]*descriptor.SourceCodeInfo_Location + + // The full list of symbols that are exported, + // as a map from the exported object to its symbols. + // This is used for supporting public imports. + exported map[Object][]symbol + + index int // The index of this file in the list of files to generate code for + + proto3 bool // whether to generate proto3 code for this file +} + +// PackageName is the package name we'll use in the generated code to refer to this file. +func (d *FileDescriptor) PackageName() string { return uniquePackageOf(d.FileDescriptorProto) } + +// VarName is the variable name we'll use in the generated code to refer +// to the compressed bytes of this descriptor. It is not exported, so +// it is only valid inside the generated package. +func (d *FileDescriptor) VarName() string { return fmt.Sprintf("fileDescriptor%v", FileName(d)) } + +// goPackageOption interprets the file's go_package option. +// If there is no go_package, it returns ("", "", false). +// If there's a simple name, it returns ("", pkg, true). +// If the option implies an import path, it returns (impPath, pkg, true). 
+func (d *FileDescriptor) goPackageOption() (impPath, pkg string, ok bool) { + pkg = d.GetOptions().GetGoPackage() + if pkg == "" { + return + } + ok = true + // The presence of a slash implies there's an import path. + slash := strings.LastIndex(pkg, "/") + if slash < 0 { + return + } + impPath, pkg = pkg, pkg[slash+1:] + // A semicolon-delimited suffix overrides the package name. + sc := strings.IndexByte(impPath, ';') + if sc < 0 { + return + } + impPath, pkg = impPath[:sc], impPath[sc+1:] + return +} + +// goPackageName returns the Go package name to use in the +// generated Go file. The result explicit reports whether the name +// came from an option go_package statement. If explicit is false, +// the name was derived from the protocol buffer's package statement +// or the input file name. +func (d *FileDescriptor) goPackageName() (name string, explicit bool) { + // Does the file have a "go_package" option? + if _, pkg, ok := d.goPackageOption(); ok { + return pkg, true + } + + // Does the file have a package clause? + if pkg := d.GetPackage(); pkg != "" { + return pkg, false + } + // Use the file base name. + return baseName(d.GetName()), false +} + +// goFileName returns the output name for the generated Go file. +func (d *FileDescriptor) goFileName() string { + name := *d.Name + if ext := path.Ext(name); ext == ".proto" || ext == ".protodevel" { + name = name[:len(name)-len(ext)] + } + name += ".pb.go" + + // Does the file have a "go_package" option? + // If it does, it may override the filename. + if impPath, _, ok := d.goPackageOption(); ok && impPath != "" { + // Replace the existing dirname with the declared import path. + _, name = path.Split(name) + name = path.Join(impPath, name) + return name + } + + return name +} + +func (d *FileDescriptor) addExport(obj Object, sym symbol) { + d.exported[obj] = append(d.exported[obj], sym) +} + +// symbol is an interface representing an exported Go symbol. 
+type symbol interface { + // GenerateAlias should generate an appropriate alias + // for the symbol from the named package. + GenerateAlias(g *Generator, pkg string) +} + +type messageSymbol struct { + sym string + hasExtensions, isMessageSet bool + hasOneof bool + getters []getterSymbol +} + +type getterSymbol struct { + name string + typ string + typeName string // canonical name in proto world; empty for proto.Message and similar + genType bool // whether typ contains a generated type (message/group/enum) +} + +func (ms *messageSymbol) GenerateAlias(g *Generator, pkg string) { + remoteSym := pkg + "." + ms.sym + + g.P("type ", ms.sym, " ", remoteSym) + g.P("func (m *", ms.sym, ") Reset() { (*", remoteSym, ")(m).Reset() }") + g.P("func (m *", ms.sym, ") String() string { return (*", remoteSym, ")(m).String() }") + g.P("func (*", ms.sym, ") ProtoMessage() {}") + if ms.hasExtensions { + g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange ", + "{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }") + if ms.isMessageSet { + g.P("func (m *", ms.sym, ") Marshal() ([]byte, error) ", + "{ return (*", remoteSym, ")(m).Marshal() }") + g.P("func (m *", ms.sym, ") Unmarshal(buf []byte) error ", + "{ return (*", remoteSym, ")(m).Unmarshal(buf) }") + } + } + if ms.hasOneof { + // Oneofs and public imports do not mix well. + // We can make them work okay for the binary format, + // but they're going to break weirdly for text/JSON. 
+ enc := "_" + ms.sym + "_OneofMarshaler" + dec := "_" + ms.sym + "_OneofUnmarshaler" + size := "_" + ms.sym + "_OneofSizer" + encSig := "(msg " + g.Pkg["proto"] + ".Message, b *" + g.Pkg["proto"] + ".Buffer) error" + decSig := "(msg " + g.Pkg["proto"] + ".Message, tag, wire int, b *" + g.Pkg["proto"] + ".Buffer) (bool, error)" + sizeSig := "(msg " + g.Pkg["proto"] + ".Message) int" + g.P("func (m *", ms.sym, ") XXX_OneofFuncs() (func", encSig, ", func", decSig, ", func", sizeSig, ", []interface{}) {") + g.P("return ", enc, ", ", dec, ", ", size, ", nil") + g.P("}") + + g.P("func ", enc, encSig, " {") + g.P("m := msg.(*", ms.sym, ")") + g.P("m0 := (*", remoteSym, ")(m)") + g.P("enc, _, _, _ := m0.XXX_OneofFuncs()") + g.P("return enc(m0, b)") + g.P("}") + + g.P("func ", dec, decSig, " {") + g.P("m := msg.(*", ms.sym, ")") + g.P("m0 := (*", remoteSym, ")(m)") + g.P("_, dec, _, _ := m0.XXX_OneofFuncs()") + g.P("return dec(m0, tag, wire, b)") + g.P("}") + + g.P("func ", size, sizeSig, " {") + g.P("m := msg.(*", ms.sym, ")") + g.P("m0 := (*", remoteSym, ")(m)") + g.P("_, _, size, _ := m0.XXX_OneofFuncs()") + g.P("return size(m0)") + g.P("}") + } + for _, get := range ms.getters { + + if get.typeName != "" { + g.RecordTypeUse(get.typeName) + } + typ := get.typ + val := "(*" + remoteSym + ")(m)." + get.name + "()" + if get.genType { + // typ will be "*pkg.T" (message/group) or "pkg.T" (enum) + // or "map[t]*pkg.T" (map to message/enum). + // The first two of those might have a "[]" prefix if it is repeated. + // Drop any package qualifier since we have hoisted the type into this package. + rep := strings.HasPrefix(typ, "[]") + if rep { + typ = typ[2:] + } + isMap := strings.HasPrefix(typ, "map[") + star := typ[0] == '*' + if !isMap { // map types handled lower down + typ = typ[strings.Index(typ, ".")+1:] + } + if star { + typ = "*" + typ + } + if rep { + // Go does not permit conversion between slice types where both + // element types are named. 
That means we need to generate a bit + // of code in this situation. + // typ is the element type. + // val is the expression to get the slice from the imported type. + + ctyp := typ // conversion type expression; "Foo" or "(*Foo)" + if star { + ctyp = "(" + typ + ")" + } + + g.P("func (m *", ms.sym, ") ", get.name, "() []", typ, " {") + g.In() + g.P("o := ", val) + g.P("if o == nil {") + g.In() + g.P("return nil") + g.Out() + g.P("}") + g.P("s := make([]", typ, ", len(o))") + g.P("for i, x := range o {") + g.In() + g.P("s[i] = ", ctyp, "(x)") + g.Out() + g.P("}") + g.P("return s") + g.Out() + g.P("}") + continue + } + if isMap { + // Split map[keyTyp]valTyp. + bra, ket := strings.Index(typ, "["), strings.Index(typ, "]") + keyTyp, valTyp := typ[bra+1:ket], typ[ket+1:] + // Drop any package qualifier. + // Only the value type may be foreign. + star := valTyp[0] == '*' + valTyp = valTyp[strings.Index(valTyp, ".")+1:] + if star { + valTyp = "*" + valTyp + } + + maptyp := "map[" + keyTyp + "]" + valTyp + g.P("func (m *", ms.sym, ") ", get.name, "() ", typ, " {") + g.P("o := ", val) + g.P("if o == nil { return nil }") + g.P("s := make(", maptyp, ", len(o))") + g.P("for k, v := range o {") + g.P("s[k] = (", valTyp, ")(v)") + g.P("}") + g.P("return s") + g.P("}") + continue + } + // Convert imported type into the forwarding type. + val = "(" + typ + ")(" + val + ")" + } + + g.P("func (m *", ms.sym, ") ", get.name, "() ", typ, " { return ", val, " }") + } + +} + +type enumSymbol struct { + name string + proto3 bool // Whether this came from a proto3 file. 
+} + +func (es enumSymbol) GenerateAlias(g *Generator, pkg string) { + s := es.name + g.P("type ", s, " ", pkg, ".", s) + g.P("var ", s, "_name = ", pkg, ".", s, "_name") + g.P("var ", s, "_value = ", pkg, ".", s, "_value") + g.P("func (x ", s, ") String() string { return (", pkg, ".", s, ")(x).String() }") + if !es.proto3 { + g.P("func (x ", s, ") Enum() *", s, "{ return (*", s, ")((", pkg, ".", s, ")(x).Enum()) }") + g.P("func (x *", s, ") UnmarshalJSON(data []byte) error { return (*", pkg, ".", s, ")(x).UnmarshalJSON(data) }") + } +} + +type constOrVarSymbol struct { + sym string + typ string // either "const" or "var" + cast string // if non-empty, a type cast is required (used for enums) +} + +func (cs constOrVarSymbol) GenerateAlias(g *Generator, pkg string) { + v := pkg + "." + cs.sym + if cs.cast != "" { + v = cs.cast + "(" + v + ")" + } + g.P(cs.typ, " ", cs.sym, " = ", v) +} + +// Object is an interface abstracting the abilities shared by enums, messages, extensions and imported objects. +type Object interface { + PackageName() string // The name we use in our output (a_b_c), possibly renamed for uniqueness. + TypeName() []string + File() *descriptor.FileDescriptorProto +} + +// Each package name we generate must be unique. The package we're generating +// gets its own name but every other package must have a unique name that does +// not conflict in the code we generate. These names are chosen globally (although +// they don't have to be, it simplifies things to do them globally). +func uniquePackageOf(fd *descriptor.FileDescriptorProto) string { + s, ok := uniquePackageName[fd] + if !ok { + log.Fatal("internal error: no package name defined for " + fd.GetName()) + } + return s +} + +// Generator is the type whose methods generate the output, stored in the associated response structure. +type Generator struct { + *bytes.Buffer + + Request *plugin.CodeGeneratorRequest // The input. + Response *plugin.CodeGeneratorResponse // The output. 
+ + Param map[string]string // Command-line parameters. + PackageImportPath string // Go import path of the package we're generating code for + ImportPrefix string // String to prefix to imported package file names. + ImportMap map[string]string // Mapping from .proto file name to import path + + Pkg map[string]string // The names under which we import support packages + + packageName string // What we're calling ourselves. + allFiles []*FileDescriptor // All files in the tree + allFilesByName map[string]*FileDescriptor // All files by filename. + genFiles []*FileDescriptor // Those files we will generate output for. + file *FileDescriptor // The file we are compiling now. + usedPackages map[string]bool // Names of packages used in current file. + typeNameToObject map[string]Object // Key is a fully-qualified name in input syntax. + init []string // Lines to emit in the init function. + indent string + writeOutput bool + + customImports []string + writtenImports map[string]bool // For de-duplicating written imports +} + +// New creates a new generator and allocates the request and response protobufs. +func New() *Generator { + g := new(Generator) + g.Buffer = new(bytes.Buffer) + g.Request = new(plugin.CodeGeneratorRequest) + g.Response = new(plugin.CodeGeneratorResponse) + g.writtenImports = make(map[string]bool) + uniquePackageName = make(map[*descriptor.FileDescriptorProto]string) + pkgNamesInUse = make(map[string][]*FileDescriptor) + return g +} + +// Error reports a problem, including an error, and exits the program. +func (g *Generator) Error(err error, msgs ...string) { + s := strings.Join(msgs, " ") + ":" + err.Error() + log.Print("protoc-gen-gogo: error:", s) + os.Exit(1) +} + +// Fail reports a problem and exits the program. 
+func (g *Generator) Fail(msgs ...string) { + s := strings.Join(msgs, " ") + log.Print("protoc-gen-gogo: error:", s) + os.Exit(1) +} + +// CommandLineParameters breaks the comma-separated list of key=value pairs +// in the parameter (a member of the request protobuf) into a key/value map. +// It then sets file name mappings defined by those entries. +func (g *Generator) CommandLineParameters(parameter string) { + g.Param = make(map[string]string) + for _, p := range strings.Split(parameter, ",") { + if i := strings.Index(p, "="); i < 0 { + g.Param[p] = "" + } else { + g.Param[p[0:i]] = p[i+1:] + } + } + + g.ImportMap = make(map[string]string) + pluginList := "none" // Default list of plugin names to enable (empty means all). + for k, v := range g.Param { + switch k { + case "import_prefix": + g.ImportPrefix = v + case "import_path": + g.PackageImportPath = v + case "plugins": + pluginList = v + default: + if len(k) > 0 && k[0] == 'M' { + g.ImportMap[k[1:]] = v + } + } + } + if pluginList == "" { + return + } + if pluginList == "none" { + pluginList = "" + } + gogoPluginNames := []string{"unmarshal", "unsafeunmarshaler", "union", "stringer", "size", "protosizer", "populate", "marshalto", "unsafemarshaler", "gostring", "face", "equal", "enumstringer", "embedcheck", "description", "defaultcheck", "oneofcheck", "compare"} + pluginList = strings.Join(append(gogoPluginNames, pluginList), "+") + if pluginList != "" { + // Amend the set of plugins. + enabled := make(map[string]bool) + for _, name := range strings.Split(pluginList, "+") { + enabled[name] = true + } + var nplugins pluginSlice + for _, p := range plugins { + if enabled[p.Name()] { + nplugins = append(nplugins, p) + } + } + sort.Sort(nplugins) + plugins = nplugins + } +} + +// DefaultPackageName returns the package name printed for the object. +// If its file is in a different package, it returns the package name we're using for this file, plus ".". +// Otherwise it returns the empty string. 
+func (g *Generator) DefaultPackageName(obj Object) string { + pkg := obj.PackageName() + if pkg == g.packageName { + return "" + } + return pkg + "." +} + +// For each input file, the unique package name to use, underscored. +var uniquePackageName = make(map[*descriptor.FileDescriptorProto]string) + +// Package names already registered. Key is the name from the .proto file; +// value is the name that appears in the generated code. +var pkgNamesInUse = make(map[string][]*FileDescriptor) + +// Create and remember a guaranteed unique package name for this file descriptor. +// Pkg is the candidate name. If f is nil, it's a builtin package like "proto" and +// has no file descriptor. +func RegisterUniquePackageName(pkg string, f *FileDescriptor) string { + if f == nil { + // For builtin and standard lib packages, try to use only + // the last component of the package path. + pkg = pkg[strings.LastIndex(pkg, "/")+1:] + } + + // Convert dots to underscores before finding a unique alias. + pkg = strings.Map(badToUnderscore, pkg) + + var i = -1 + var ptr *FileDescriptor = nil + for i, ptr = range pkgNamesInUse[pkg] { + if ptr == f { + if i == 0 { + return pkg + } + return pkg + strconv.Itoa(i) + } + } + + pkgNamesInUse[pkg] = append(pkgNamesInUse[pkg], f) + i += 1 + + if i > 0 { + pkg = pkg + strconv.Itoa(i) + } + + if f != nil { + uniquePackageName[f.FileDescriptorProto] = pkg + } + return pkg +} + +var isGoKeyword = map[string]bool{ + "break": true, + "case": true, + "chan": true, + "const": true, + "continue": true, + "default": true, + "else": true, + "defer": true, + "fallthrough": true, + "for": true, + "func": true, + "go": true, + "goto": true, + "if": true, + "import": true, + "interface": true, + "map": true, + "package": true, + "range": true, + "return": true, + "select": true, + "struct": true, + "switch": true, + "type": true, + "var": true, +} + +// defaultGoPackage returns the package name to use, +// derived from the import path of the package we're 
building code for. +func (g *Generator) defaultGoPackage() string { + p := g.PackageImportPath + if i := strings.LastIndex(p, "/"); i >= 0 { + p = p[i+1:] + } + if p == "" { + return "" + } + + p = strings.Map(badToUnderscore, p) + // Identifier must not be keyword: insert _. + if isGoKeyword[p] { + p = "_" + p + } + // Identifier must not begin with digit: insert _. + if r, _ := utf8.DecodeRuneInString(p); unicode.IsDigit(r) { + p = "_" + p + } + return p +} + +// SetPackageNames sets the package name for this run. +// The package name must agree across all files being generated. +// It also defines unique package names for all imported files. +func (g *Generator) SetPackageNames() { + // Register the name for this package. It will be the first name + // registered so is guaranteed to be unmodified. + pkg, explicit := g.genFiles[0].goPackageName() + + // Check all files for an explicit go_package option. + for _, f := range g.genFiles { + thisPkg, thisExplicit := f.goPackageName() + if thisExplicit { + if !explicit { + // Let this file's go_package option serve for all input files. + pkg, explicit = thisPkg, true + } else if thisPkg != pkg { + g.Fail("inconsistent package names:", thisPkg, pkg) + } + } + } + + // If we don't have an explicit go_package option but we have an + // import path, use that. + if !explicit { + p := g.defaultGoPackage() + if p != "" { + pkg, explicit = p, true + } + } + + // If there was no go_package and no import path to use, + // double-check that all the inputs have the same implicit + // Go package name. + if !explicit { + for _, f := range g.genFiles { + thisPkg, _ := f.goPackageName() + if thisPkg != pkg { + g.Fail("inconsistent package names:", thisPkg, pkg) + } + } + } + + g.packageName = RegisterUniquePackageName(pkg, g.genFiles[0]) + + // Register the support package names. They might collide with the + // name of a package we import. 
+ g.Pkg = map[string]string{ + "fmt": RegisterUniquePackageName("fmt", nil), + "math": RegisterUniquePackageName("math", nil), + "proto": RegisterUniquePackageName("proto", nil), + "golang_proto": RegisterUniquePackageName("golang_proto", nil), + } + +AllFiles: + for _, f := range g.allFiles { + for _, genf := range g.genFiles { + if f == genf { + // In this package already. + uniquePackageName[f.FileDescriptorProto] = g.packageName + continue AllFiles + } + } + // The file is a dependency, so we want to ignore its go_package option + // because that is only relevant for its specific generated output. + pkg := f.GetPackage() + if pkg == "" { + pkg = baseName(*f.Name) + } + RegisterUniquePackageName(pkg, f) + } +} + +// WrapTypes walks the incoming data, wrapping DescriptorProtos, EnumDescriptorProtos +// and FileDescriptorProtos into file-referenced objects within the Generator. +// It also creates the list of files to generate and so should be called before GenerateAllFiles. +func (g *Generator) WrapTypes() { + g.allFiles = make([]*FileDescriptor, 0, len(g.Request.ProtoFile)) + g.allFilesByName = make(map[string]*FileDescriptor, len(g.allFiles)) + for _, f := range g.Request.ProtoFile { + // We must wrap the descriptors before we wrap the enums + descs := wrapDescriptors(f) + g.buildNestedDescriptors(descs) + enums := wrapEnumDescriptors(f, descs) + g.buildNestedEnums(descs, enums) + exts := wrapExtensions(f) + fd := &FileDescriptor{ + FileDescriptorProto: f, + desc: descs, + enum: enums, + ext: exts, + exported: make(map[Object][]symbol), + proto3: fileIsProto3(f), + } + extractComments(fd) + g.allFiles = append(g.allFiles, fd) + g.allFilesByName[f.GetName()] = fd + } + for _, fd := range g.allFiles { + fd.imp = wrapImported(fd.FileDescriptorProto, g) + } + + g.genFiles = make([]*FileDescriptor, 0, len(g.Request.FileToGenerate)) + for _, fileName := range g.Request.FileToGenerate { + fd := g.allFilesByName[fileName] + if fd == nil { + g.Fail("could not find file 
named", fileName) + } + fd.index = len(g.genFiles) + g.genFiles = append(g.genFiles, fd) + } +} + +// Scan the descriptors in this file. For each one, build the slice of nested descriptors +func (g *Generator) buildNestedDescriptors(descs []*Descriptor) { + for _, desc := range descs { + if len(desc.NestedType) != 0 { + for _, nest := range descs { + if nest.parent == desc { + desc.nested = append(desc.nested, nest) + } + } + if len(desc.nested) != len(desc.NestedType) { + g.Fail("internal error: nesting failure for", desc.GetName()) + } + } + } +} + +func (g *Generator) buildNestedEnums(descs []*Descriptor, enums []*EnumDescriptor) { + for _, desc := range descs { + if len(desc.EnumType) != 0 { + for _, enum := range enums { + if enum.parent == desc { + desc.enums = append(desc.enums, enum) + } + } + if len(desc.enums) != len(desc.EnumType) { + g.Fail("internal error: enum nesting failure for", desc.GetName()) + } + } + } +} + +// Construct the Descriptor +func newDescriptor(desc *descriptor.DescriptorProto, parent *Descriptor, file *descriptor.FileDescriptorProto, index int) *Descriptor { + d := &Descriptor{ + common: common{file}, + DescriptorProto: desc, + parent: parent, + index: index, + } + if parent == nil { + d.path = fmt.Sprintf("%d,%d", messagePath, index) + } else { + d.path = fmt.Sprintf("%s,%d,%d", parent.path, messageMessagePath, index) + } + + // The only way to distinguish a group from a message is whether + // the containing message has a TYPE_GROUP field that matches. + if parent != nil { + parts := d.TypeName() + if file.Package != nil { + parts = append([]string{*file.Package}, parts...) + } + exp := "." 
+ strings.Join(parts, ".") + for _, field := range parent.Field { + if field.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP && field.GetTypeName() == exp { + d.group = true + break + } + } + } + + for _, field := range desc.Extension { + d.ext = append(d.ext, &ExtensionDescriptor{common{file}, field, d}) + } + + return d +} + +// Return a slice of all the Descriptors defined within this file +func wrapDescriptors(file *descriptor.FileDescriptorProto) []*Descriptor { + sl := make([]*Descriptor, 0, len(file.MessageType)+10) + for i, desc := range file.MessageType { + sl = wrapThisDescriptor(sl, desc, nil, file, i) + } + return sl +} + +// Wrap this Descriptor, recursively +func wrapThisDescriptor(sl []*Descriptor, desc *descriptor.DescriptorProto, parent *Descriptor, file *descriptor.FileDescriptorProto, index int) []*Descriptor { + sl = append(sl, newDescriptor(desc, parent, file, index)) + me := sl[len(sl)-1] + for i, nested := range desc.NestedType { + sl = wrapThisDescriptor(sl, nested, me, file, i) + } + return sl +} + +// Construct the EnumDescriptor +func newEnumDescriptor(desc *descriptor.EnumDescriptorProto, parent *Descriptor, file *descriptor.FileDescriptorProto, index int) *EnumDescriptor { + ed := &EnumDescriptor{ + common: common{file}, + EnumDescriptorProto: desc, + parent: parent, + index: index, + } + if parent == nil { + ed.path = fmt.Sprintf("%d,%d", enumPath, index) + } else { + ed.path = fmt.Sprintf("%s,%d,%d", parent.path, messageEnumPath, index) + } + return ed +} + +// Return a slice of all the EnumDescriptors defined within this file +func wrapEnumDescriptors(file *descriptor.FileDescriptorProto, descs []*Descriptor) []*EnumDescriptor { + sl := make([]*EnumDescriptor, 0, len(file.EnumType)+10) + // Top-level enums. + for i, enum := range file.EnumType { + sl = append(sl, newEnumDescriptor(enum, nil, file, i)) + } + // Enums within messages. Enums within embedded messages appear in the outer-most message. 
+ for _, nested := range descs { + for i, enum := range nested.EnumType { + sl = append(sl, newEnumDescriptor(enum, nested, file, i)) + } + } + return sl +} + +// Return a slice of all the top-level ExtensionDescriptors defined within this file. +func wrapExtensions(file *descriptor.FileDescriptorProto) []*ExtensionDescriptor { + var sl []*ExtensionDescriptor + for _, field := range file.Extension { + sl = append(sl, &ExtensionDescriptor{common{file}, field, nil}) + } + return sl +} + +// Return a slice of all the types that are publicly imported into this file. +func wrapImported(file *descriptor.FileDescriptorProto, g *Generator) (sl []*ImportedDescriptor) { + for _, index := range file.PublicDependency { + df := g.fileByName(file.Dependency[index]) + for _, d := range df.desc { + if d.GetOptions().GetMapEntry() { + continue + } + sl = append(sl, &ImportedDescriptor{common{file}, d}) + } + for _, e := range df.enum { + sl = append(sl, &ImportedDescriptor{common{file}, e}) + } + for _, ext := range df.ext { + sl = append(sl, &ImportedDescriptor{common{file}, ext}) + } + } + return +} + +func extractComments(file *FileDescriptor) { + file.comments = make(map[string]*descriptor.SourceCodeInfo_Location) + for _, loc := range file.GetSourceCodeInfo().GetLocation() { + if loc.LeadingComments == nil { + continue + } + var p []string + for _, n := range loc.Path { + p = append(p, strconv.Itoa(int(n))) + } + file.comments[strings.Join(p, ",")] = loc + } +} + +// BuildTypeNameMap builds the map from fully qualified type names to objects. +// The key names for the map come from the input data, which puts a period at the beginning. +// It should be called after SetPackageNames and before GenerateAllFiles. +func (g *Generator) BuildTypeNameMap() { + g.typeNameToObject = make(map[string]Object) + for _, f := range g.allFiles { + // The names in this loop are defined by the proto world, not us, so the + // package name may be empty. 
If so, the dotted package name of X will + // be ".X"; otherwise it will be ".pkg.X". + dottedPkg := "." + f.GetPackage() + if dottedPkg != "." { + dottedPkg += "." + } + for _, enum := range f.enum { + name := dottedPkg + dottedSlice(enum.TypeName()) + g.typeNameToObject[name] = enum + } + for _, desc := range f.desc { + name := dottedPkg + dottedSlice(desc.TypeName()) + g.typeNameToObject[name] = desc + } + } +} + +// ObjectNamed, given a fully-qualified input type name as it appears in the input data, +// returns the descriptor for the message or enum with that name. +func (g *Generator) ObjectNamed(typeName string) Object { + o, ok := g.typeNameToObject[typeName] + if !ok { + g.Fail("can't find object with type", typeName) + } + + // If the file of this object isn't a direct dependency of the current file, + // or in the current file, then this object has been publicly imported into + // a dependency of the current file. + // We should return the ImportedDescriptor object for it instead. + direct := *o.File().Name == *g.file.Name + if !direct { + for _, dep := range g.file.Dependency { + if *g.fileByName(dep).Name == *o.File().Name { + direct = true + break + } + } + } + if !direct { + found := false + Loop: + for _, dep := range g.file.Dependency { + df := g.fileByName(*g.fileByName(dep).Name) + for _, td := range df.imp { + if td.o == o { + // Found it! + o = td + found = true + break Loop + } + } + } + if !found { + log.Printf("protoc-gen-gogo: WARNING: failed finding publicly imported dependency for %v, used in %v", typeName, *g.file.Name) + } + } + + return o +} + +// P prints the arguments to the generated output. It handles strings and int32s, plus +// handling indirections because they may be *string, etc. 
+func (g *Generator) P(str ...interface{}) { + if !g.writeOutput { + return + } + g.WriteString(g.indent) + for _, v := range str { + switch s := v.(type) { + case string: + g.WriteString(s) + case *string: + g.WriteString(*s) + case bool: + fmt.Fprintf(g, "%t", s) + case *bool: + fmt.Fprintf(g, "%t", *s) + case int: + fmt.Fprintf(g, "%d", s) + case *int32: + fmt.Fprintf(g, "%d", *s) + case *int64: + fmt.Fprintf(g, "%d", *s) + case float64: + fmt.Fprintf(g, "%g", s) + case *float64: + fmt.Fprintf(g, "%g", *s) + default: + g.Fail(fmt.Sprintf("unknown type in printer: %T", v)) + } + } + g.WriteByte('\n') +} + +// addInitf stores the given statement to be printed inside the file's init function. +// The statement is given as a format specifier and arguments. +func (g *Generator) addInitf(stmt string, a ...interface{}) { + g.init = append(g.init, fmt.Sprintf(stmt, a...)) +} + +func (g *Generator) PrintImport(alias, pkg string) { + statement := "import " + alias + " " + strconv.Quote(pkg) + if g.writtenImports[statement] { + return + } + g.P(statement) + g.writtenImports[statement] = true +} + +// In Indents the output one tab stop. +func (g *Generator) In() { g.indent += "\t" } + +// Out unindents the output one tab stop. +func (g *Generator) Out() { + if len(g.indent) > 0 { + g.indent = g.indent[1:] + } +} + +// GenerateAllFiles generates the output for all the files we're outputting. +func (g *Generator) GenerateAllFiles() { + // Initialize the plugins + for _, p := range plugins { + p.Init(g) + } + // Generate the output. The generator runs for every file, even the files + // that we don't generate output for, so that we can collate the full list + // of exported symbols to support public imports. 
+ genFileMap := make(map[*FileDescriptor]bool, len(g.genFiles)) + for _, file := range g.genFiles { + genFileMap[file] = true + } + for _, file := range g.allFiles { + g.Reset() + g.writeOutput = genFileMap[file] + g.generate(file) + if !g.writeOutput { + continue + } + g.Response.File = append(g.Response.File, &plugin.CodeGeneratorResponse_File{ + Name: proto.String(file.goFileName()), + Content: proto.String(g.String()), + }) + } +} + +// Run all the plugins associated with the file. +func (g *Generator) runPlugins(file *FileDescriptor) { + for _, p := range plugins { + p.Generate(file) + } +} + +// FileOf return the FileDescriptor for this FileDescriptorProto. +func (g *Generator) FileOf(fd *descriptor.FileDescriptorProto) *FileDescriptor { + for _, file := range g.allFiles { + if file.FileDescriptorProto == fd { + return file + } + } + g.Fail("could not find file in table:", fd.GetName()) + return nil +} + +// Fill the response protocol buffer with the generated output for all the files we're +// supposed to generate. +func (g *Generator) generate(file *FileDescriptor) { + g.customImports = make([]string, 0) + g.file = g.FileOf(file.FileDescriptorProto) + g.usedPackages = make(map[string]bool) + + if g.file.index == 0 { + // For one file in the package, assert version compatibility. 
+ g.P("// This is a compile-time assertion to ensure that this generated file") + g.P("// is compatible with the proto package it is being compiled against.") + g.P("// A compilation error at this line likely means your copy of the") + g.P("// proto package needs to be updated.") + if gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { + g.P("const _ = ", g.Pkg["proto"], ".GoGoProtoPackageIsVersion", generatedCodeVersion, " // please upgrade the proto package") + } else { + g.P("const _ = ", g.Pkg["proto"], ".ProtoPackageIsVersion", generatedCodeVersion, " // please upgrade the proto package") + } + g.P() + } + // Reset on each file + g.writtenImports = make(map[string]bool) + for _, td := range g.file.imp { + g.generateImported(td) + } + for _, enum := range g.file.enum { + g.generateEnum(enum) + } + for _, desc := range g.file.desc { + // Don't generate virtual messages for maps. + if desc.GetOptions().GetMapEntry() { + continue + } + g.generateMessage(desc) + } + for _, ext := range g.file.ext { + g.generateExtension(ext) + } + g.generateInitFunction() + + // Run the plugins before the imports so we know which imports are necessary. + g.runPlugins(file) + + g.generateFileDescriptor(file) + + // Generate header and imports last, though they appear first in the output. + rem := g.Buffer + g.Buffer = new(bytes.Buffer) + g.generateHeader() + g.generateImports() + if !g.writeOutput { + return + } + g.Write(rem.Bytes()) + + // Reformat generated code. + fset := token.NewFileSet() + raw := g.Bytes() + ast, err := parser.ParseFile(fset, "", g, parser.ParseComments) + if err != nil { + // Print out the bad code with line numbers. + // This should never happen in practice, but it can while changing generated code, + // so consider this a debugging aid. 
+ var src bytes.Buffer + s := bufio.NewScanner(bytes.NewReader(raw)) + for line := 1; s.Scan(); line++ { + fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes()) + } + if serr := s.Err(); serr != nil { + g.Fail("bad Go source code was generated:", err.Error(), "\n"+string(raw)) + } else { + g.Fail("bad Go source code was generated:", err.Error(), "\n"+src.String()) + } + } + g.Reset() + err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(g, fset, ast) + if err != nil { + g.Fail("generated Go source code could not be reformatted:", err.Error()) + } +} + +// Generate the header, including package definition +func (g *Generator) generateHeader() { + g.P("// Code generated by protoc-gen-gogo. DO NOT EDIT.") + g.P("// source: ", *g.file.Name) + g.P() + + name := g.file.PackageName() + + if g.file.index == 0 { + // Generate package docs for the first file in the package. + g.P("/*") + g.P("Package ", name, " is a generated protocol buffer package.") + g.P() + if loc, ok := g.file.comments[strconv.Itoa(packagePath)]; ok { + // not using g.PrintComments because this is a /* */ comment block. + text := strings.TrimSuffix(loc.GetLeadingComments(), "\n") + for _, line := range strings.Split(text, "\n") { + line = strings.TrimPrefix(line, " ") + // ensure we don't escape from the block comment + line = strings.Replace(line, "*/", "* /", -1) + g.P(line) + } + g.P() + } + var topMsgs []string + g.P("It is generated from these files:") + for _, f := range g.genFiles { + g.P("\t", f.Name) + for _, msg := range f.desc { + if msg.parent != nil { + continue + } + topMsgs = append(topMsgs, CamelCaseSlice(msg.TypeName())) + } + } + g.P() + g.P("It has these top-level messages:") + for _, msg := range topMsgs { + g.P("\t", msg) + } + g.P("*/") + } + + g.P("package ", name) + g.P() +} + +// PrintComments prints any comments from the source .proto file. +// The path is a comma-separated list of integers. 
+// It returns an indication of whether any comments were printed. +// See descriptor.proto for its format. +func (g *Generator) PrintComments(path string) bool { + if !g.writeOutput { + return false + } + if loc, ok := g.file.comments[path]; ok { + text := strings.TrimSuffix(loc.GetLeadingComments(), "\n") + for _, line := range strings.Split(text, "\n") { + g.P("// ", strings.TrimPrefix(line, " ")) + } + return true + } + return false +} + +// Comments returns any comments from the source .proto file and empty string if comments not found. +// The path is a comma-separated list of intergers. +// See descriptor.proto for its format. +func (g *Generator) Comments(path string) string { + loc, ok := g.file.comments[path] + if !ok { + return "" + } + text := strings.TrimSuffix(loc.GetLeadingComments(), "\n") + return text +} + +func (g *Generator) fileByName(filename string) *FileDescriptor { + return g.allFilesByName[filename] +} + +// weak returns whether the ith import of the current file is a weak import. +func (g *Generator) weak(i int32) bool { + for _, j := range g.file.WeakDependency { + if j == i { + return true + } + } + return false +} + +// Generate the imports +func (g *Generator) generateImports() { + // We almost always need a proto import. Rather than computing when we + // do, which is tricky when there's a plugin, just import it and + // reference it later. The same argument applies to the fmt and math packages. 
+ if gogoproto.ImportsGoGoProto(g.file.FileDescriptorProto) { + g.PrintImport(g.Pkg["proto"], g.ImportPrefix+"github.com/gogo/protobuf/proto") + if gogoproto.RegistersGolangProto(g.file.FileDescriptorProto) { + g.PrintImport(g.Pkg["golang_proto"], g.ImportPrefix+"github.com/golang/protobuf/proto") + } + } else { + g.PrintImport(g.Pkg["proto"], g.ImportPrefix+"github.com/golang/protobuf/proto") + } + g.PrintImport(g.Pkg["fmt"], "fmt") + g.PrintImport(g.Pkg["math"], "math") + + for i, s := range g.file.Dependency { + fd := g.fileByName(s) + // Do not import our own package. + if fd.PackageName() == g.packageName { + continue + } + filename := fd.goFileName() + // By default, import path is the dirname of the Go filename. + importPath := path.Dir(filename) + if substitution, ok := g.ImportMap[s]; ok { + importPath = substitution + } + importPath = g.ImportPrefix + importPath + // Skip weak imports. + if g.weak(int32(i)) { + g.P("// skipping weak import ", fd.PackageName(), " ", strconv.Quote(importPath)) + continue + } + // We need to import all the dependencies, even if we don't reference them, + // because other code and tools depend on having the full transitive closure + // of protocol buffer types in the binary. 
+ if _, ok := g.usedPackages[fd.PackageName()]; ok { + g.PrintImport(fd.PackageName(), importPath) + } else { + g.P("import _ ", strconv.Quote(importPath)) + } + } + g.P() + for _, s := range g.customImports { + s1 := strings.Map(badToUnderscore, s) + g.PrintImport(s1, s) + } + g.P() + // TODO: may need to worry about uniqueness across plugins + for _, p := range plugins { + p.GenerateImports(g.file) + g.P() + } + g.P("// Reference imports to suppress errors if they are not otherwise used.") + g.P("var _ = ", g.Pkg["proto"], ".Marshal") + if gogoproto.ImportsGoGoProto(g.file.FileDescriptorProto) && gogoproto.RegistersGolangProto(g.file.FileDescriptorProto) { + g.P("var _ = ", g.Pkg["golang_proto"], ".Marshal") + } + g.P("var _ = ", g.Pkg["fmt"], ".Errorf") + g.P("var _ = ", g.Pkg["math"], ".Inf") + for _, cimport := range g.customImports { + if cimport == "time" { + g.P("var _ = time.Kitchen") + break + } + } + g.P() +} + +func (g *Generator) generateImported(id *ImportedDescriptor) { + // Don't generate public import symbols for files that we are generating + // code for, since those symbols will already be in this package. + // We can't simply avoid creating the ImportedDescriptor objects, + // because g.genFiles isn't populated at that stage. + tn := id.TypeName() + sn := tn[len(tn)-1] + df := g.FileOf(id.o.File()) + filename := *df.Name + for _, fd := range g.genFiles { + if *fd.Name == filename { + g.P("// Ignoring public import of ", sn, " from ", filename) + g.P() + return + } + } + g.P("// ", sn, " from public import ", filename) + g.usedPackages[df.PackageName()] = true + + for _, sym := range df.exported[id.o] { + sym.GenerateAlias(g, df.PackageName()) + } + + g.P() +} + +// Generate the enum definitions for this EnumDescriptor. +func (g *Generator) generateEnum(enum *EnumDescriptor) { + // The full type name + typeName := enum.alias() + // The full type name, CamelCased. 
+ ccTypeName := CamelCaseSlice(typeName) + ccPrefix := enum.prefix() + + g.PrintComments(enum.path) + if !gogoproto.EnabledGoEnumPrefix(enum.file, enum.EnumDescriptorProto) { + ccPrefix = "" + } + + if gogoproto.HasEnumDecl(enum.file, enum.EnumDescriptorProto) { + g.P("type ", ccTypeName, " int32") + g.file.addExport(enum, enumSymbol{ccTypeName, enum.proto3()}) + g.P("const (") + g.In() + for i, e := range enum.Value { + g.PrintComments(fmt.Sprintf("%s,%d,%d", enum.path, enumValuePath, i)) + name := *e.Name + if gogoproto.IsEnumValueCustomName(e) { + name = gogoproto.GetEnumValueCustomName(e) + } + name = ccPrefix + name + + g.P(name, " ", ccTypeName, " = ", e.Number) + g.file.addExport(enum, constOrVarSymbol{name, "const", ccTypeName}) + } + g.Out() + g.P(")") + } + + g.P("var ", ccTypeName, "_name = map[int32]string{") + g.In() + generated := make(map[int32]bool) // avoid duplicate values + for _, e := range enum.Value { + duplicate := "" + if _, present := generated[*e.Number]; present { + duplicate = "// Duplicate value: " + } + g.P(duplicate, e.Number, ": ", strconv.Quote(*e.Name), ",") + generated[*e.Number] = true + } + g.Out() + g.P("}") + g.P("var ", ccTypeName, "_value = map[string]int32{") + g.In() + for _, e := range enum.Value { + g.P(strconv.Quote(*e.Name), ": ", e.Number, ",") + } + g.Out() + g.P("}") + + if !enum.proto3() { + g.P("func (x ", ccTypeName, ") Enum() *", ccTypeName, " {") + g.In() + g.P("p := new(", ccTypeName, ")") + g.P("*p = x") + g.P("return p") + g.Out() + g.P("}") + } + + if gogoproto.IsGoEnumStringer(g.file.FileDescriptorProto, enum.EnumDescriptorProto) { + g.P("func (x ", ccTypeName, ") String() string {") + g.In() + g.P("return ", g.Pkg["proto"], ".EnumName(", ccTypeName, "_name, int32(x))") + g.Out() + g.P("}") + } + + if !enum.proto3() && !gogoproto.IsGoEnumStringer(g.file.FileDescriptorProto, enum.EnumDescriptorProto) { + g.P("func (x ", ccTypeName, ") MarshalJSON() ([]byte, error) {") + g.In() + g.P("return ", 
g.Pkg["proto"], ".MarshalJSONEnum(", ccTypeName, "_name, int32(x))") + g.Out() + g.P("}") + } + if !enum.proto3() { + g.P("func (x *", ccTypeName, ") UnmarshalJSON(data []byte) error {") + g.In() + g.P("value, err := ", g.Pkg["proto"], ".UnmarshalJSONEnum(", ccTypeName, `_value, data, "`, ccTypeName, `")`) + g.P("if err != nil {") + g.In() + g.P("return err") + g.Out() + g.P("}") + g.P("*x = ", ccTypeName, "(value)") + g.P("return nil") + g.Out() + g.P("}") + } + + var indexes []string + for m := enum.parent; m != nil; m = m.parent { + // XXX: skip groups? + indexes = append([]string{strconv.Itoa(m.index)}, indexes...) + } + indexes = append(indexes, strconv.Itoa(enum.index)) + g.P("func (", ccTypeName, ") EnumDescriptor() ([]byte, []int) { return ", g.file.VarName(), ", []int{", strings.Join(indexes, ", "), "} }") + if enum.file.GetPackage() == "google.protobuf" && enum.GetName() == "NullValue" { + g.P("func (", ccTypeName, `) XXX_WellKnownType() string { return "`, enum.GetName(), `" }`) + } + g.P() +} + +// The tag is a string like "varint,2,opt,name=fieldname,def=7" that +// identifies details of the field for the protocol buffer marshaling and unmarshaling +// code. The fields are: +// wire encoding +// protocol tag number +// opt,req,rep for optional, required, or repeated +// packed whether the encoding is "packed" (optional; repeated primitives only) +// name= the original declared name +// enum= the name of the enum type if it is an enum-typed field. +// proto3 if this field is in a proto3 message +// def= string representation of the default value, if any. +// The default value must be in a representation that can be used at run-time +// to generate the default value. Thus bools become 0 and 1, for instance. 
+func (g *Generator) goTag(message *Descriptor, field *descriptor.FieldDescriptorProto, wiretype string) string { + optrepreq := "" + switch { + case isOptional(field): + optrepreq = "opt" + case isRequired(field): + optrepreq = "req" + case isRepeated(field): + optrepreq = "rep" + } + var defaultValue string + if dv := field.DefaultValue; dv != nil { // set means an explicit default + defaultValue = *dv + // Some types need tweaking. + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_BOOL: + if defaultValue == "true" { + defaultValue = "1" + } else { + defaultValue = "0" + } + case descriptor.FieldDescriptorProto_TYPE_STRING, + descriptor.FieldDescriptorProto_TYPE_BYTES: + // Nothing to do. Quoting is done for the whole tag. + case descriptor.FieldDescriptorProto_TYPE_ENUM: + // For enums we need to provide the integer constant. + obj := g.ObjectNamed(field.GetTypeName()) + if id, ok := obj.(*ImportedDescriptor); ok { + // It is an enum that was publicly imported. + // We need the underlying type. + obj = id.o + } + enum, ok := obj.(*EnumDescriptor) + if !ok { + log.Printf("obj is a %T", obj) + if id, ok := obj.(*ImportedDescriptor); ok { + log.Printf("id.o is a %T", id.o) + } + g.Fail("unknown enum type", CamelCaseSlice(obj.TypeName())) + } + defaultValue = enum.integerValueAsString(defaultValue) + } + defaultValue = ",def=" + defaultValue + } + enum := "" + if *field.Type == descriptor.FieldDescriptorProto_TYPE_ENUM { + // We avoid using obj.PackageName(), because we want to use the + // original (proto-world) package name. + obj := g.ObjectNamed(field.GetTypeName()) + if id, ok := obj.(*ImportedDescriptor); ok { + obj = id.o + } + enum = ",enum=" + if pkg := obj.File().GetPackage(); pkg != "" { + enum += pkg + "." 
+ } + enum += CamelCaseSlice(obj.TypeName()) + } + packed := "" + if (field.Options != nil && field.Options.GetPacked()) || + // Per https://developers.google.com/protocol-buffers/docs/proto3#simple: + // "In proto3, repeated fields of scalar numeric types use packed encoding by default." + (message.proto3() && (field.Options == nil || field.Options.Packed == nil) && + isRepeated(field) && IsScalar(field)) { + packed = ",packed" + } + fieldName := field.GetName() + name := fieldName + if *field.Type == descriptor.FieldDescriptorProto_TYPE_GROUP { + // We must use the type name for groups instead of + // the field name to preserve capitalization. + // type_name in FieldDescriptorProto is fully-qualified, + // but we only want the local part. + name = *field.TypeName + if i := strings.LastIndex(name, "."); i >= 0 { + name = name[i+1:] + } + } + if json := field.GetJsonName(); json != "" && json != name { + // TODO: escaping might be needed, in which case + // perhaps this should be in its own "json" tag. 
+ name += ",json=" + json + } + name = ",name=" + name + + embed := "" + if gogoproto.IsEmbed(field) { + embed = ",embedded=" + fieldName + } + + ctype := "" + if gogoproto.IsCustomType(field) { + ctype = ",customtype=" + gogoproto.GetCustomType(field) + } + + casttype := "" + if gogoproto.IsCastType(field) { + casttype = ",casttype=" + gogoproto.GetCastType(field) + } + + castkey := "" + if gogoproto.IsCastKey(field) { + castkey = ",castkey=" + gogoproto.GetCastKey(field) + } + + castvalue := "" + if gogoproto.IsCastValue(field) { + castvalue = ",castvalue=" + gogoproto.GetCastValue(field) + // record the original message type for jsonpb reconstruction + desc := g.ObjectNamed(field.GetTypeName()) + if d, ok := desc.(*Descriptor); ok && d.GetOptions().GetMapEntry() { + valueField := d.Field[1] + if valueField.IsMessage() { + castvalue += ",castvaluetype=" + strings.TrimPrefix(valueField.GetTypeName(), ".") + } + } + } + + if message.proto3() { + // We only need the extra tag for []byte fields; + // no need to add noise for the others. 
+ if *field.Type != descriptor.FieldDescriptorProto_TYPE_MESSAGE && + *field.Type != descriptor.FieldDescriptorProto_TYPE_GROUP && + !field.IsRepeated() { + name += ",proto3" + } + } + oneof := "" + if field.OneofIndex != nil { + oneof = ",oneof" + } + stdtime := "" + if gogoproto.IsStdTime(field) { + stdtime = ",stdtime" + } + stdduration := "" + if gogoproto.IsStdDuration(field) { + stdduration = ",stdduration" + } + return strconv.Quote(fmt.Sprintf("%s,%d,%s%s%s%s%s%s%s%s%s%s%s%s%s", + wiretype, + field.GetNumber(), + optrepreq, + packed, + name, + enum, + oneof, + defaultValue, + embed, + ctype, + casttype, + castkey, + castvalue, + stdtime, + stdduration)) +} + +func needsStar(field *descriptor.FieldDescriptorProto, proto3 bool, allowOneOf bool) bool { + if isRepeated(field) && + (*field.Type != descriptor.FieldDescriptorProto_TYPE_MESSAGE || gogoproto.IsCustomType(field)) && + (*field.Type != descriptor.FieldDescriptorProto_TYPE_GROUP) { + return false + } + if *field.Type == descriptor.FieldDescriptorProto_TYPE_BYTES && !gogoproto.IsCustomType(field) { + return false + } + if !gogoproto.IsNullable(field) { + return false + } + if field.OneofIndex != nil && allowOneOf && + (*field.Type != descriptor.FieldDescriptorProto_TYPE_MESSAGE) && + (*field.Type != descriptor.FieldDescriptorProto_TYPE_GROUP) { + return false + } + if proto3 && + (*field.Type != descriptor.FieldDescriptorProto_TYPE_MESSAGE) && + (*field.Type != descriptor.FieldDescriptorProto_TYPE_GROUP) && + !gogoproto.IsCustomType(field) { + return false + } + return true +} + +// TypeName is the printed name appropriate for an item. If the object is in the current file, +// TypeName drops the package name and underscores the rest. +// Otherwise the object is from another package; and the result is the underscored +// package name followed by the item name. +// The result always has an initial capital. 
+func (g *Generator) TypeName(obj Object) string { + return g.DefaultPackageName(obj) + CamelCaseSlice(obj.TypeName()) +} + +// TypeNameWithPackage is like TypeName, but always includes the package +// name even if the object is in our own package. +func (g *Generator) TypeNameWithPackage(obj Object) string { + return obj.PackageName() + CamelCaseSlice(obj.TypeName()) +} + +// GoType returns a string representing the type name, and the wire type +func (g *Generator) GoType(message *Descriptor, field *descriptor.FieldDescriptorProto) (typ string, wire string) { + // TODO: Options. + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + typ, wire = "float64", "fixed64" + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + typ, wire = "float32", "fixed32" + case descriptor.FieldDescriptorProto_TYPE_INT64: + typ, wire = "int64", "varint" + case descriptor.FieldDescriptorProto_TYPE_UINT64: + typ, wire = "uint64", "varint" + case descriptor.FieldDescriptorProto_TYPE_INT32: + typ, wire = "int32", "varint" + case descriptor.FieldDescriptorProto_TYPE_UINT32: + typ, wire = "uint32", "varint" + case descriptor.FieldDescriptorProto_TYPE_FIXED64: + typ, wire = "uint64", "fixed64" + case descriptor.FieldDescriptorProto_TYPE_FIXED32: + typ, wire = "uint32", "fixed32" + case descriptor.FieldDescriptorProto_TYPE_BOOL: + typ, wire = "bool", "varint" + case descriptor.FieldDescriptorProto_TYPE_STRING: + typ, wire = "string", "bytes" + case descriptor.FieldDescriptorProto_TYPE_GROUP: + desc := g.ObjectNamed(field.GetTypeName()) + typ, wire = g.TypeName(desc), "group" + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + desc := g.ObjectNamed(field.GetTypeName()) + typ, wire = g.TypeName(desc), "bytes" + case descriptor.FieldDescriptorProto_TYPE_BYTES: + typ, wire = "[]byte", "bytes" + case descriptor.FieldDescriptorProto_TYPE_ENUM: + desc := g.ObjectNamed(field.GetTypeName()) + typ, wire = g.TypeName(desc), "varint" + case 
descriptor.FieldDescriptorProto_TYPE_SFIXED32: + typ, wire = "int32", "fixed32" + case descriptor.FieldDescriptorProto_TYPE_SFIXED64: + typ, wire = "int64", "fixed64" + case descriptor.FieldDescriptorProto_TYPE_SINT32: + typ, wire = "int32", "zigzag32" + case descriptor.FieldDescriptorProto_TYPE_SINT64: + typ, wire = "int64", "zigzag64" + default: + g.Fail("unknown type for", field.GetName()) + } + switch { + case gogoproto.IsCustomType(field) && gogoproto.IsCastType(field): + g.Fail(field.GetName() + " cannot be custom type and cast type") + case gogoproto.IsCustomType(field): + var packageName string + var err error + packageName, typ, err = getCustomType(field) + if err != nil { + g.Fail(err.Error()) + } + if len(packageName) > 0 { + g.customImports = append(g.customImports, packageName) + } + case gogoproto.IsCastType(field): + var packageName string + var err error + packageName, typ, err = getCastType(field) + if err != nil { + g.Fail(err.Error()) + } + if len(packageName) > 0 { + g.customImports = append(g.customImports, packageName) + } + case gogoproto.IsStdTime(field): + g.customImports = append(g.customImports, "time") + typ = "time.Time" + case gogoproto.IsStdDuration(field): + g.customImports = append(g.customImports, "time") + typ = "time.Duration" + } + if needsStar(field, g.file.proto3 && field.Extendee == nil, message != nil && message.allowOneof()) { + typ = "*" + typ + } + if isRepeated(field) { + typ = "[]" + typ + } + return +} + +// GoMapDescriptor is a full description of the map output struct. 
+type GoMapDescriptor struct { + GoType string + + KeyField *descriptor.FieldDescriptorProto + KeyAliasField *descriptor.FieldDescriptorProto + KeyTag string + + ValueField *descriptor.FieldDescriptorProto + ValueAliasField *descriptor.FieldDescriptorProto + ValueTag string +} + +func (g *Generator) GoMapType(d *Descriptor, field *descriptor.FieldDescriptorProto) *GoMapDescriptor { + if d == nil { + byName := g.ObjectNamed(field.GetTypeName()) + desc, ok := byName.(*Descriptor) + if byName == nil || !ok || !desc.GetOptions().GetMapEntry() { + g.Fail(fmt.Sprintf("field %s is not a map", field.GetTypeName())) + return nil + } + d = desc + } + + m := &GoMapDescriptor{ + KeyField: d.Field[0], + ValueField: d.Field[1], + } + + // Figure out the Go types and tags for the key and value types. + m.KeyAliasField, m.ValueAliasField = g.GetMapKeyField(field, m.KeyField), g.GetMapValueField(field, m.ValueField) + keyType, keyWire := g.GoType(d, m.KeyAliasField) + valType, valWire := g.GoType(d, m.ValueAliasField) + + m.KeyTag, m.ValueTag = g.goTag(d, m.KeyField, keyWire), g.goTag(d, m.ValueField, valWire) + + if gogoproto.IsCastType(field) { + var packageName string + var err error + packageName, typ, err := getCastType(field) + if err != nil { + g.Fail(err.Error()) + } + if len(packageName) > 0 { + g.customImports = append(g.customImports, packageName) + } + m.GoType = typ + return m + } + + // We don't use stars, except for message-typed values. + // Message and enum types are the only two possibly foreign types used in maps, + // so record their use. They are not permitted as map keys. 
+ keyType = strings.TrimPrefix(keyType, "*") + switch *m.ValueAliasField.Type { + case descriptor.FieldDescriptorProto_TYPE_ENUM: + valType = strings.TrimPrefix(valType, "*") + g.RecordTypeUse(m.ValueAliasField.GetTypeName()) + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if !gogoproto.IsNullable(m.ValueAliasField) { + valType = strings.TrimPrefix(valType, "*") + } + if !gogoproto.IsStdTime(field) && !gogoproto.IsStdDuration(field) { + g.RecordTypeUse(m.ValueAliasField.GetTypeName()) + } + default: + if gogoproto.IsCustomType(m.ValueAliasField) { + if !gogoproto.IsNullable(m.ValueAliasField) { + valType = strings.TrimPrefix(valType, "*") + } + g.RecordTypeUse(m.ValueAliasField.GetTypeName()) + } else { + valType = strings.TrimPrefix(valType, "*") + } + } + + m.GoType = fmt.Sprintf("map[%s]%s", keyType, valType) + return m +} + +func (g *Generator) RecordTypeUse(t string) { + if obj, ok := g.typeNameToObject[t]; ok { + // Call ObjectNamed to get the true object to record the use. + obj = g.ObjectNamed(t) + g.usedPackages[obj.PackageName()] = true + } +} + +// Method names that may be generated. Fields with these names get an +// underscore appended. Any change to this set is a potential incompatible +// API change because it changes generated field names. +var methodNames = [...]string{ + "Reset", + "String", + "ProtoMessage", + "Marshal", + "Unmarshal", + "ExtensionRangeArray", + "ExtensionMap", + "Descriptor", + "MarshalTo", + "Equal", + "VerboseEqual", + "GoString", + "ProtoSize", +} + +// Names of messages in the `google.protobuf` package for which +// we will generate XXX_WellKnownType methods. 
+var wellKnownTypes = map[string]bool{ + "Any": true, + "Duration": true, + "Empty": true, + "Struct": true, + "Timestamp": true, + + "Value": true, + "ListValue": true, + "DoubleValue": true, + "FloatValue": true, + "Int64Value": true, + "UInt64Value": true, + "Int32Value": true, + "UInt32Value": true, + "BoolValue": true, + "StringValue": true, + "BytesValue": true, +} + +// Generate the type and default constant definitions for this Descriptor. +func (g *Generator) generateMessage(message *Descriptor) { + // The full type name + typeName := message.TypeName() + // The full type name, CamelCased. + ccTypeName := CamelCaseSlice(typeName) + + usedNames := make(map[string]bool) + for _, n := range methodNames { + usedNames[n] = true + } + if !gogoproto.IsProtoSizer(message.file, message.DescriptorProto) { + usedNames["Size"] = true + } + fieldNames := make(map[*descriptor.FieldDescriptorProto]string) + fieldGetterNames := make(map[*descriptor.FieldDescriptorProto]string) + fieldTypes := make(map[*descriptor.FieldDescriptorProto]string) + mapFieldTypes := make(map[*descriptor.FieldDescriptorProto]string) + + oneofFieldName := make(map[int32]string) // indexed by oneof_index field of FieldDescriptorProto + oneofDisc := make(map[int32]string) // name of discriminator method + oneofTypeName := make(map[*descriptor.FieldDescriptorProto]string) // without star + oneofInsertPoints := make(map[int32]int) // oneof_index => offset of g.Buffer + + // allocNames finds a conflict-free variation of the given strings, + // consistently mutating their suffixes. + // It returns the same number of strings. 
+ allocNames := func(ns ...string) []string { + Loop: + for { + for _, n := range ns { + if usedNames[n] { + for i := range ns { + ns[i] += "_" + } + continue Loop + } + } + for _, n := range ns { + usedNames[n] = true + } + return ns + } + } + + for _, field := range message.Field { + // Allocate the getter and the field at the same time so name + // collisions create field/method consistent names. + // TODO: This allocation occurs based on the order of the fields + // in the proto file, meaning that a change in the field + // ordering can change generated Method/Field names. + base := CamelCase(*field.Name) + if gogoproto.IsCustomName(field) { + base = gogoproto.GetCustomName(field) + } + ns := allocNames(base, "Get"+base) + fieldName, fieldGetterName := ns[0], ns[1] + fieldNames[field] = fieldName + fieldGetterNames[field] = fieldGetterName + } + + if gogoproto.HasTypeDecl(message.file, message.DescriptorProto) { + g.PrintComments(message.path) + g.P("type ", ccTypeName, " struct {") + g.In() + + for i, field := range message.Field { + fieldName := fieldNames[field] + typename, wiretype := g.GoType(message, field) + jsonName := *field.Name + jsonTag := jsonName + ",omitempty" + repeatedNativeType := (!field.IsMessage() && !gogoproto.IsCustomType(field) && field.IsRepeated()) + if !gogoproto.IsNullable(field) && !repeatedNativeType { + jsonTag = jsonName + } + gogoJsonTag := gogoproto.GetJsonTag(field) + if gogoJsonTag != nil { + jsonTag = *gogoJsonTag + } + gogoMoreTags := gogoproto.GetMoreTags(field) + moreTags := "" + if gogoMoreTags != nil { + moreTags = " " + *gogoMoreTags + } + tag := fmt.Sprintf("protobuf:%s json:%q%s", g.goTag(message, field, wiretype), jsonTag, moreTags) + if *field.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE && gogoproto.IsEmbed(field) { + fieldName = "" + } + + oneof := field.OneofIndex != nil && message.allowOneof() + if oneof && oneofFieldName[*field.OneofIndex] == "" { + odp := message.OneofDecl[int(*field.OneofIndex)] + 
fname := allocNames(CamelCase(odp.GetName()))[0] + + // This is the first field of a oneof we haven't seen before. + // Generate the union field. + com := g.PrintComments(fmt.Sprintf("%s,%d,%d", message.path, messageOneofPath, *field.OneofIndex)) + if com { + g.P("//") + } + g.P("// Types that are valid to be assigned to ", fname, ":") + // Generate the rest of this comment later, + // when we've computed any disambiguation. + oneofInsertPoints[*field.OneofIndex] = g.Buffer.Len() + + dname := "is" + ccTypeName + "_" + fname + oneofFieldName[*field.OneofIndex] = fname + oneofDisc[*field.OneofIndex] = dname + otag := `protobuf_oneof:"` + odp.GetName() + `"` + g.P(fname, " ", dname, " `", otag, "`") + } + + if *field.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE { + desc := g.ObjectNamed(field.GetTypeName()) + if d, ok := desc.(*Descriptor); ok && d.GetOptions().GetMapEntry() { + m := g.GoMapType(d, field) + typename = m.GoType + mapFieldTypes[field] = typename // record for the getter generation + + tag += fmt.Sprintf(" protobuf_key:%s protobuf_val:%s", m.KeyTag, m.ValueTag) + } + } + + fieldTypes[field] = typename + + if oneof { + tname := ccTypeName + "_" + fieldName + // It is possible for this to collide with a message or enum + // nested in this message. Check for collisions. 
+ for { + ok := true + for _, desc := range message.nested { + if CamelCaseSlice(desc.TypeName()) == tname { + ok = false + break + } + } + for _, enum := range message.enums { + if CamelCaseSlice(enum.TypeName()) == tname { + ok = false + break + } + } + if !ok { + tname += "_" + continue + } + break + } + + oneofTypeName[field] = tname + continue + } + + g.PrintComments(fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, i)) + g.P(fieldName, "\t", typename, "\t`", tag, "`") + if !gogoproto.IsStdTime(field) && !gogoproto.IsStdDuration(field) { + g.RecordTypeUse(field.GetTypeName()) + } + } + if len(message.ExtensionRange) > 0 { + if gogoproto.HasExtensionsMap(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P(g.Pkg["proto"], ".XXX_InternalExtensions `json:\"-\"`") + } else { + g.P("XXX_extensions\t\t[]byte `protobuf:\"bytes,0,opt\" json:\"-\"`") + } + } + if gogoproto.HasUnrecognized(g.file.FileDescriptorProto, message.DescriptorProto) && !message.proto3() { + g.P("XXX_unrecognized\t[]byte `json:\"-\"`") + } + g.Out() + g.P("}") + } else { + // Even if the type does not need to be generated, we need to iterate + // over all its fields to be able to mark as used any imported types + // used by those fields. + for _, field := range message.Field { + if !gogoproto.IsStdTime(field) && !gogoproto.IsStdDuration(field) { + g.RecordTypeUse(field.GetTypeName()) + } + } + } + + // Update g.Buffer to list valid oneof types. + // We do this down here, after we've disambiguated the oneof type names. + // We go in reverse order of insertion point to avoid invalidating offsets. 
+ for oi := int32(len(message.OneofDecl)); oi >= 0; oi-- { + ip := oneofInsertPoints[oi] + all := g.Buffer.Bytes() + rem := all[ip:] + g.Buffer = bytes.NewBuffer(all[:ip:ip]) // set cap so we don't scribble on rem + for _, field := range message.Field { + if field.OneofIndex == nil || *field.OneofIndex != oi { + continue + } + g.P("//\t*", oneofTypeName[field]) + } + g.Buffer.Write(rem) + } + + // Reset, String and ProtoMessage methods. + g.P("func (m *", ccTypeName, ") Reset() { *m = ", ccTypeName, "{} }") + if gogoproto.EnabledGoStringer(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P("func (m *", ccTypeName, ") String() string { return ", g.Pkg["proto"], ".CompactTextString(m) }") + } + g.P("func (*", ccTypeName, ") ProtoMessage() {}") + var indexes []string + for m := message; m != nil; m = m.parent { + indexes = append([]string{strconv.Itoa(m.index)}, indexes...) + } + g.P("func (*", ccTypeName, ") Descriptor() ([]byte, []int) { return ", g.file.VarName(), ", []int{", strings.Join(indexes, ", "), "} }") + // TODO: Revisit the decision to use a XXX_WellKnownType method + // if we change proto.MessageName to work with multiple equivalents. + if message.file.GetPackage() == "google.protobuf" && wellKnownTypes[message.GetName()] { + g.P("func (*", ccTypeName, `) XXX_WellKnownType() string { return "`, message.GetName(), `" }`) + } + // Extension support methods + var hasExtensions, isMessageSet bool + if len(message.ExtensionRange) > 0 { + hasExtensions = true + // message_set_wire_format only makes sense when extensions are defined. 
+ if opts := message.Options; opts != nil && opts.GetMessageSetWireFormat() { + isMessageSet = true + g.P() + g.P("func (m *", ccTypeName, ") Marshal() ([]byte, error) {") + g.In() + g.P("return ", g.Pkg["proto"], ".MarshalMessageSet(&m.XXX_InternalExtensions)") + g.Out() + g.P("}") + g.P("func (m *", ccTypeName, ") Unmarshal(buf []byte) error {") + g.In() + g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSet(buf, &m.XXX_InternalExtensions)") + g.Out() + g.P("}") + g.P("func (m *", ccTypeName, ") MarshalJSON() ([]byte, error) {") + g.In() + g.P("return ", g.Pkg["proto"], ".MarshalMessageSetJSON(&m.XXX_InternalExtensions)") + g.Out() + g.P("}") + g.P("func (m *", ccTypeName, ") UnmarshalJSON(buf []byte) error {") + g.In() + g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSetJSON(buf, &m.XXX_InternalExtensions)") + g.Out() + g.P("}") + g.P("// ensure ", ccTypeName, " satisfies proto.Marshaler and proto.Unmarshaler") + g.P("var _ ", g.Pkg["proto"], ".Marshaler = (*", ccTypeName, ")(nil)") + g.P("var _ ", g.Pkg["proto"], ".Unmarshaler = (*", ccTypeName, ")(nil)") + } + + g.P() + g.P("var extRange_", ccTypeName, " = []", g.Pkg["proto"], ".ExtensionRange{") + g.In() + for _, r := range message.ExtensionRange { + end := fmt.Sprint(*r.End - 1) // make range inclusive on both ends + g.P("{Start: ", r.Start, ", End: ", end, "},") + } + g.Out() + g.P("}") + g.P("func (*", ccTypeName, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange {") + g.In() + g.P("return extRange_", ccTypeName) + g.Out() + g.P("}") + if !gogoproto.HasExtensionsMap(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P("func (m *", ccTypeName, ") GetExtensions() *[]byte {") + g.In() + g.P("if m.XXX_extensions == nil {") + g.In() + g.P("m.XXX_extensions = make([]byte, 0)") + g.Out() + g.P("}") + g.P("return &m.XXX_extensions") + g.Out() + g.P("}") + } + } + + // Default constants + defNames := make(map[*descriptor.FieldDescriptorProto]string) + for _, field := range message.Field { + 
def := field.GetDefaultValue() + if def == "" { + continue + } + if !gogoproto.IsNullable(field) { + g.Fail("illegal default value: ", field.GetName(), " in ", message.GetName(), " is not nullable and is thus not allowed to have a default value") + } + fieldname := "Default_" + ccTypeName + "_" + CamelCase(*field.Name) + defNames[field] = fieldname + typename, _ := g.GoType(message, field) + if typename[0] == '*' { + typename = typename[1:] + } + kind := "const " + switch { + case typename == "bool": + case typename == "string": + def = strconv.Quote(def) + case typename == "[]byte": + def = "[]byte(" + strconv.Quote(unescape(def)) + ")" + kind = "var " + case def == "inf", def == "-inf", def == "nan": + // These names are known to, and defined by, the protocol language. + switch def { + case "inf": + def = "math.Inf(1)" + case "-inf": + def = "math.Inf(-1)" + case "nan": + def = "math.NaN()" + } + if *field.Type == descriptor.FieldDescriptorProto_TYPE_FLOAT { + def = "float32(" + def + ")" + } + kind = "var " + case *field.Type == descriptor.FieldDescriptorProto_TYPE_ENUM: + // Must be an enum. Need to construct the prefixed name. + obj := g.ObjectNamed(field.GetTypeName()) + var enum *EnumDescriptor + if id, ok := obj.(*ImportedDescriptor); ok { + // The enum type has been publicly imported. 
+ enum, _ = id.o.(*EnumDescriptor) + } else { + enum, _ = obj.(*EnumDescriptor) + } + if enum == nil { + log.Printf("don't know how to generate constant for %s", fieldname) + continue + } + + // hunt down the actual enum corresponding to the default + var enumValue *descriptor.EnumValueDescriptorProto + for _, ev := range enum.Value { + if def == ev.GetName() { + enumValue = ev + } + } + + if enumValue != nil { + if gogoproto.IsEnumValueCustomName(enumValue) { + def = gogoproto.GetEnumValueCustomName(enumValue) + } + } else { + g.Fail(fmt.Sprintf("could not resolve default enum value for %v.%v", + g.DefaultPackageName(obj), def)) + + } + + if gogoproto.EnabledGoEnumPrefix(enum.file, enum.EnumDescriptorProto) { + def = g.DefaultPackageName(obj) + enum.prefix() + def + } else { + def = g.DefaultPackageName(obj) + def + } + } + g.P(kind, fieldname, " ", typename, " = ", def) + g.file.addExport(message, constOrVarSymbol{fieldname, kind, ""}) + } + g.P() + + // Oneof per-field types, discriminants and getters. + if message.allowOneof() { + // Generate unexported named types for the discriminant interfaces. + // We shouldn't have to do this, but there was (~19 Aug 2015) a compiler/linker bug + // that was triggered by using anonymous interfaces here. + // TODO: Revisit this and consider reverting back to anonymous interfaces. 
+ for oi := range message.OneofDecl { + dname := oneofDisc[int32(oi)] + g.P("type ", dname, " interface {") + g.In() + g.P(dname, "()") + if gogoproto.HasEqual(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P(`Equal(interface{}) bool`) + } + if gogoproto.HasVerboseEqual(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P(`VerboseEqual(interface{}) error`) + } + if gogoproto.IsMarshaler(g.file.FileDescriptorProto, message.DescriptorProto) || + gogoproto.IsUnsafeMarshaler(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P(`MarshalTo([]byte) (int, error)`) + } + if gogoproto.IsSizer(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P(`Size() int`) + } + if gogoproto.IsProtoSizer(g.file.FileDescriptorProto, message.DescriptorProto) { + g.P(`ProtoSize() int`) + } + g.Out() + g.P("}") + } + g.P() + for _, field := range message.Field { + if field.OneofIndex == nil { + continue + } + _, wiretype := g.GoType(message, field) + tag := "protobuf:" + g.goTag(message, field, wiretype) + g.P("type ", oneofTypeName[field], " struct{ ", fieldNames[field], " ", fieldTypes[field], " `", tag, "` }") + if !gogoproto.IsStdTime(field) && !gogoproto.IsStdDuration(field) { + g.RecordTypeUse(field.GetTypeName()) + } + } + g.P() + for _, field := range message.Field { + if field.OneofIndex == nil { + continue + } + g.P("func (*", oneofTypeName[field], ") ", oneofDisc[*field.OneofIndex], "() {}") + } + g.P() + for oi := range message.OneofDecl { + fname := oneofFieldName[int32(oi)] + g.P("func (m *", ccTypeName, ") Get", fname, "() ", oneofDisc[int32(oi)], " {") + g.P("if m != nil { return m.", fname, " }") + g.P("return nil") + g.P("}") + } + g.P() + } + + // Field getters + var getters []getterSymbol + for _, field := range message.Field { + oneof := field.OneofIndex != nil && message.allowOneof() + if !oneof && !gogoproto.HasGoGetters(g.file.FileDescriptorProto, message.DescriptorProto) { + continue + } + if gogoproto.IsEmbed(field) || 
gogoproto.IsCustomType(field) { + continue + } + fname := fieldNames[field] + typename, _ := g.GoType(message, field) + if t, ok := mapFieldTypes[field]; ok { + typename = t + } + mname := fieldGetterNames[field] + star := "" + if (*field.Type != descriptor.FieldDescriptorProto_TYPE_MESSAGE) && + (*field.Type != descriptor.FieldDescriptorProto_TYPE_GROUP) && + needsStar(field, g.file.proto3, message != nil && message.allowOneof()) && typename[0] == '*' { + typename = typename[1:] + star = "*" + } + + // Only export getter symbols for basic types, + // and for messages and enums in the same package. + // Groups are not exported. + // Foreign types can't be hoisted through a public import because + // the importer may not already be importing the defining .proto. + // As an example, imagine we have an import tree like this: + // A.proto -> B.proto -> C.proto + // If A publicly imports B, we need to generate the getters from B in A's output, + // but if one such getter returns something from C then we cannot do that + // because A is not importing C already. + var getter, genType bool + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_GROUP: + getter = false + case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_ENUM: + // Only export getter if its return type is in this package. 
+ getter = g.ObjectNamed(field.GetTypeName()).PackageName() == message.PackageName() + genType = true + default: + getter = true + } + if getter { + getters = append(getters, getterSymbol{ + name: mname, + typ: typename, + typeName: field.GetTypeName(), + genType: genType, + }) + } + + g.P("func (m *", ccTypeName, ") "+mname+"() "+typename+" {") + g.In() + def, hasDef := defNames[field] + typeDefaultIsNil := false // whether this field type's default value is a literal nil unless specified + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_BYTES: + typeDefaultIsNil = !hasDef + case descriptor.FieldDescriptorProto_TYPE_GROUP, descriptor.FieldDescriptorProto_TYPE_MESSAGE: + typeDefaultIsNil = gogoproto.IsNullable(field) + } + if isRepeated(field) { + typeDefaultIsNil = true + } + if typeDefaultIsNil && !oneof { + // A bytes field with no explicit default needs less generated code, + // as does a message or group field, or a repeated field. + g.P("if m != nil {") + g.In() + g.P("return m." + fname) + g.Out() + g.P("}") + g.P("return nil") + g.Out() + g.P("}") + g.P() + continue + } + if !gogoproto.IsNullable(field) { + g.P("if m != nil {") + g.In() + g.P("return m." + fname) + g.Out() + g.P("}") + } else if !oneof { + if message.proto3() { + g.P("if m != nil {") + } else { + g.P("if m != nil && m." + fname + " != nil {") + } + g.In() + g.P("return " + star + "m." + fname) + g.Out() + g.P("}") + } else { + uname := oneofFieldName[*field.OneofIndex] + tname := oneofTypeName[field] + g.P("if x, ok := m.Get", uname, "().(*", tname, "); ok {") + g.P("return x.", fname) + g.P("}") + } + if hasDef { + if *field.Type != descriptor.FieldDescriptorProto_TYPE_BYTES { + g.P("return " + def) + } else { + // The default is a []byte var. + // Make a copy when returning it to be safe. 
+ g.P("return append([]byte(nil), ", def, "...)") + } + } else { + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_GROUP, + descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if field.OneofIndex != nil { + g.P(`return nil`) + } else { + goTyp, _ := g.GoType(message, field) + goTypName := GoTypeToName(goTyp) + if !gogoproto.IsNullable(field) && gogoproto.IsStdDuration(field) { + g.P("return 0") + } else { + g.P("return ", goTypName, "{}") + } + } + case descriptor.FieldDescriptorProto_TYPE_BOOL: + g.P("return false") + case descriptor.FieldDescriptorProto_TYPE_STRING: + g.P(`return ""`) + case descriptor.FieldDescriptorProto_TYPE_BYTES: + // This is only possible for oneof fields. + g.P("return nil") + case descriptor.FieldDescriptorProto_TYPE_ENUM: + // The default default for an enum is the first value in the enum, + // not zero. + obj := g.ObjectNamed(field.GetTypeName()) + var enum *EnumDescriptor + if id, ok := obj.(*ImportedDescriptor); ok { + // The enum type has been publicly imported. 
+ enum, _ = id.o.(*EnumDescriptor) + } else { + enum, _ = obj.(*EnumDescriptor) + } + if enum == nil { + log.Printf("don't know how to generate getter for %s", field.GetName()) + continue + } + if len(enum.Value) == 0 { + g.P("return 0 // empty enum") + } else { + first := enum.Value[0].GetName() + if gogoproto.IsEnumValueCustomName(enum.Value[0]) { + first = gogoproto.GetEnumValueCustomName(enum.Value[0]) + } + + if gogoproto.EnabledGoEnumPrefix(enum.file, enum.EnumDescriptorProto) { + g.P("return ", g.DefaultPackageName(obj)+enum.prefix()+first) + } else { + g.P("return ", g.DefaultPackageName(obj)+first) + } + } + default: + g.P("return 0") + } + } + g.Out() + g.P("}") + g.P() + } + + if !message.group { + ms := &messageSymbol{ + sym: ccTypeName, + hasExtensions: hasExtensions, + isMessageSet: isMessageSet, + hasOneof: len(message.OneofDecl) > 0, + getters: getters, + } + g.file.addExport(message, ms) + } + + // Oneof functions + if len(message.OneofDecl) > 0 && message.allowOneof() { + fieldWire := make(map[*descriptor.FieldDescriptorProto]string) + + // method + enc := "_" + ccTypeName + "_OneofMarshaler" + dec := "_" + ccTypeName + "_OneofUnmarshaler" + size := "_" + ccTypeName + "_OneofSizer" + encSig := "(msg " + g.Pkg["proto"] + ".Message, b *" + g.Pkg["proto"] + ".Buffer) error" + decSig := "(msg " + g.Pkg["proto"] + ".Message, tag, wire int, b *" + g.Pkg["proto"] + ".Buffer) (bool, error)" + sizeSig := "(msg " + g.Pkg["proto"] + ".Message) (n int)" + + g.P("// XXX_OneofFuncs is for the internal use of the proto package.") + g.P("func (*", ccTypeName, ") XXX_OneofFuncs() (func", encSig, ", func", decSig, ", func", sizeSig, ", []interface{}) {") + g.P("return ", enc, ", ", dec, ", ", size, ", []interface{}{") + for _, field := range message.Field { + if field.OneofIndex == nil { + continue + } + g.P("(*", oneofTypeName[field], ")(nil),") + } + g.P("}") + g.P("}") + g.P() + + // marshaler + g.P("func ", enc, encSig, " {") + g.P("m := msg.(*", ccTypeName, 
")") + for oi, odp := range message.OneofDecl { + g.P("// ", odp.GetName()) + fname := oneofFieldName[int32(oi)] + g.P("switch x := m.", fname, ".(type) {") + for _, field := range message.Field { + if field.OneofIndex == nil || int(*field.OneofIndex) != oi { + continue + } + g.P("case *", oneofTypeName[field], ":") + var wire, pre, post string + val := "x." + fieldNames[field] // overridden for TYPE_BOOL + canFail := false // only TYPE_MESSAGE and TYPE_GROUP can fail + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + wire = "WireFixed64" + pre = "b.EncodeFixed64(" + g.Pkg["math"] + ".Float64bits(" + post = "))" + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + wire = "WireFixed32" + pre = "b.EncodeFixed32(uint64(" + g.Pkg["math"] + ".Float32bits(" + post = ")))" + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64: + wire = "WireVarint" + pre, post = "b.EncodeVarint(uint64(", "))" + case descriptor.FieldDescriptorProto_TYPE_INT32, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM: + wire = "WireVarint" + pre, post = "b.EncodeVarint(uint64(", "))" + case descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + wire = "WireFixed64" + pre, post = "b.EncodeFixed64(uint64(", "))" + case descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + wire = "WireFixed32" + pre, post = "b.EncodeFixed32(uint64(", "))" + case descriptor.FieldDescriptorProto_TYPE_BOOL: + // bool needs special handling. 
+ g.P("t := uint64(0)") + g.P("if ", val, " { t = 1 }") + val = "t" + wire = "WireVarint" + pre, post = "b.EncodeVarint(", ")" + case descriptor.FieldDescriptorProto_TYPE_STRING: + wire = "WireBytes" + pre, post = "b.EncodeStringBytes(", ")" + case descriptor.FieldDescriptorProto_TYPE_GROUP: + wire = "WireStartGroup" + pre, post = "b.Marshal(", ")" + canFail = true + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + wire = "WireBytes" + pre, post = "b.EncodeMessage(", ")" + canFail = true + case descriptor.FieldDescriptorProto_TYPE_BYTES: + wire = "WireBytes" + pre, post = "b.EncodeRawBytes(", ")" + case descriptor.FieldDescriptorProto_TYPE_SINT32: + wire = "WireVarint" + pre, post = "b.EncodeZigzag32(uint64(", "))" + case descriptor.FieldDescriptorProto_TYPE_SINT64: + wire = "WireVarint" + pre, post = "b.EncodeZigzag64(uint64(", "))" + default: + g.Fail("unhandled oneof field type ", field.Type.String()) + } + fieldWire[field] = wire + g.P("_ = b.EncodeVarint(", field.Number, "<<3|", g.Pkg["proto"], ".", wire, ")") + if *field.Type == descriptor.FieldDescriptorProto_TYPE_BYTES && gogoproto.IsCustomType(field) { + g.P(`dAtA, err := `, val, `.Marshal()`) + g.P(`if err != nil {`) + g.In() + g.P(`return err`) + g.Out() + g.P(`}`) + val = "dAtA" + } else if gogoproto.IsStdTime(field) { + pkg := g.useTypes() + if gogoproto.IsNullable(field) { + g.P(`dAtA, err := `, pkg, `.StdTimeMarshal(*`, val, `)`) + } else { + g.P(`dAtA, err := `, pkg, `.StdTimeMarshal(`, val, `)`) + } + g.P(`if err != nil {`) + g.In() + g.P(`return err`) + g.Out() + g.P(`}`) + val = "dAtA" + pre, post = "b.EncodeRawBytes(", ")" + } else if gogoproto.IsStdDuration(field) { + pkg := g.useTypes() + if gogoproto.IsNullable(field) { + g.P(`dAtA, err := `, pkg, `.StdDurationMarshal(*`, val, `)`) + } else { + g.P(`dAtA, err := `, pkg, `.StdDurationMarshal(`, val, `)`) + } + g.P(`if err != nil {`) + g.In() + g.P(`return err`) + g.Out() + g.P(`}`) + val = "dAtA" + pre, post = "b.EncodeRawBytes(", ")" + } 
+ if !canFail { + g.P("_ = ", pre, val, post) + } else { + g.P("if err := ", pre, val, post, "; err != nil {") + g.In() + g.P("return err") + g.Out() + g.P("}") + } + if *field.Type == descriptor.FieldDescriptorProto_TYPE_GROUP { + g.P("_ = b.EncodeVarint(", field.Number, "<<3|", g.Pkg["proto"], ".WireEndGroup)") + } + } + g.P("case nil:") + g.P("default: return ", g.Pkg["fmt"], `.Errorf("`, ccTypeName, ".", fname, ` has unexpected type %T", x)`) + g.P("}") + } + g.P("return nil") + g.P("}") + g.P() + + // unmarshaler + g.P("func ", dec, decSig, " {") + g.P("m := msg.(*", ccTypeName, ")") + g.P("switch tag {") + for _, field := range message.Field { + if field.OneofIndex == nil { + continue + } + odp := message.OneofDecl[int(*field.OneofIndex)] + g.P("case ", field.Number, ": // ", odp.GetName(), ".", *field.Name) + g.P("if wire != ", g.Pkg["proto"], ".", fieldWire[field], " {") + g.P("return true, ", g.Pkg["proto"], ".ErrInternalBadWireType") + g.P("}") + lhs := "x, err" // overridden for TYPE_MESSAGE and TYPE_GROUP + var dec, cast, cast2 string + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + dec, cast = "b.DecodeFixed64()", g.Pkg["math"]+".Float64frombits" + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + dec, cast, cast2 = "b.DecodeFixed32()", "uint32", g.Pkg["math"]+".Float32frombits" + case descriptor.FieldDescriptorProto_TYPE_INT64: + dec, cast = "b.DecodeVarint()", "int64" + case descriptor.FieldDescriptorProto_TYPE_UINT64: + dec = "b.DecodeVarint()" + case descriptor.FieldDescriptorProto_TYPE_INT32: + dec, cast = "b.DecodeVarint()", "int32" + case descriptor.FieldDescriptorProto_TYPE_FIXED64: + dec = "b.DecodeFixed64()" + case descriptor.FieldDescriptorProto_TYPE_FIXED32: + dec, cast = "b.DecodeFixed32()", "uint32" + case descriptor.FieldDescriptorProto_TYPE_BOOL: + dec = "b.DecodeVarint()" + // handled specially below + case descriptor.FieldDescriptorProto_TYPE_STRING: + dec = "b.DecodeStringBytes()" + case 
descriptor.FieldDescriptorProto_TYPE_GROUP: + g.P("msg := new(", fieldTypes[field][1:], ")") // drop star + lhs = "err" + dec = "b.DecodeGroup(msg)" + // handled specially below + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if gogoproto.IsStdTime(field) || gogoproto.IsStdDuration(field) { + dec = "b.DecodeRawBytes(true)" + } else { + g.P("msg := new(", fieldTypes[field][1:], ")") // drop star + lhs = "err" + dec = "b.DecodeMessage(msg)" + } + // handled specially below + case descriptor.FieldDescriptorProto_TYPE_BYTES: + dec = "b.DecodeRawBytes(true)" + case descriptor.FieldDescriptorProto_TYPE_UINT32: + dec, cast = "b.DecodeVarint()", "uint32" + case descriptor.FieldDescriptorProto_TYPE_ENUM: + dec, cast = "b.DecodeVarint()", fieldTypes[field] + case descriptor.FieldDescriptorProto_TYPE_SFIXED32: + dec, cast = "b.DecodeFixed32()", "int32" + case descriptor.FieldDescriptorProto_TYPE_SFIXED64: + dec, cast = "b.DecodeFixed64()", "int64" + case descriptor.FieldDescriptorProto_TYPE_SINT32: + dec, cast = "b.DecodeZigzag32()", "int32" + case descriptor.FieldDescriptorProto_TYPE_SINT64: + dec, cast = "b.DecodeZigzag64()", "int64" + default: + g.Fail("unhandled oneof field type ", field.Type.String()) + } + g.P(lhs, " := ", dec) + val := "x" + if *field.Type == descriptor.FieldDescriptorProto_TYPE_BYTES && gogoproto.IsCustomType(field) { + g.P(`if err != nil {`) + g.In() + g.P(`return true, err`) + g.Out() + g.P(`}`) + _, ctyp, err := GetCustomType(field) + if err != nil { + panic(err) + } + g.P(`var cc `, ctyp) + g.P(`c := &cc`) + g.P(`err = c.Unmarshal(`, val, `)`) + val = "*c" + } else if gogoproto.IsStdTime(field) { + pkg := g.useTypes() + g.P(`if err != nil {`) + g.In() + g.P(`return true, err`) + g.Out() + g.P(`}`) + g.P(`c := new(time.Time)`) + g.P(`if err2 := `, pkg, `.StdTimeUnmarshal(c, `, val, `); err2 != nil {`) + g.In() + g.P(`return true, err`) + g.Out() + g.P(`}`) + val = "c" + } else if gogoproto.IsStdDuration(field) { + pkg := g.useTypes() + 
g.P(`if err != nil {`) + g.In() + g.P(`return true, err`) + g.Out() + g.P(`}`) + g.P(`c := new(time.Duration)`) + g.P(`if err2 := `, pkg, `.StdDurationUnmarshal(c, `, val, `); err2 != nil {`) + g.In() + g.P(`return true, err`) + g.Out() + g.P(`}`) + val = "c" + } + if cast != "" { + val = cast + "(" + val + ")" + } + if cast2 != "" { + val = cast2 + "(" + val + ")" + } + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_BOOL: + val += " != 0" + case descriptor.FieldDescriptorProto_TYPE_GROUP, + descriptor.FieldDescriptorProto_TYPE_MESSAGE: + if !gogoproto.IsStdTime(field) && !gogoproto.IsStdDuration(field) { + val = "msg" + } + } + if gogoproto.IsCastType(field) { + _, typ, err := getCastType(field) + if err != nil { + g.Fail(err.Error()) + } + val = typ + "(" + val + ")" + } + g.P("m.", oneofFieldName[*field.OneofIndex], " = &", oneofTypeName[field], "{", val, "}") + g.P("return true, err") + } + g.P("default: return false, nil") + g.P("}") + g.P("}") + g.P() + + // sizer + g.P("func ", size, sizeSig, " {") + g.P("m := msg.(*", ccTypeName, ")") + for oi, odp := range message.OneofDecl { + g.P("// ", odp.GetName()) + fname := oneofFieldName[int32(oi)] + g.P("switch x := m.", fname, ".(type) {") + for _, field := range message.Field { + if field.OneofIndex == nil || int(*field.OneofIndex) != oi { + continue + } + g.P("case *", oneofTypeName[field], ":") + val := "x." 
+ fieldNames[field] + var wire, varint, fixed string + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: + wire = "WireFixed64" + fixed = "8" + case descriptor.FieldDescriptorProto_TYPE_FLOAT: + wire = "WireFixed32" + fixed = "4" + case descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_INT32, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM: + wire = "WireVarint" + varint = val + case descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_SFIXED64: + wire = "WireFixed64" + fixed = "8" + case descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED32: + wire = "WireFixed32" + fixed = "4" + case descriptor.FieldDescriptorProto_TYPE_BOOL: + wire = "WireVarint" + fixed = "1" + case descriptor.FieldDescriptorProto_TYPE_STRING: + wire = "WireBytes" + fixed = "len(" + val + ")" + varint = fixed + case descriptor.FieldDescriptorProto_TYPE_GROUP: + wire = "WireStartGroup" + fixed = g.Pkg["proto"] + ".Size(" + val + ")" + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + wire = "WireBytes" + if gogoproto.IsStdTime(field) { + if gogoproto.IsNullable(field) { + val = "*" + val + } + pkg := g.useTypes() + g.P("s := ", pkg, ".SizeOfStdTime(", val, ")") + } else if gogoproto.IsStdDuration(field) { + if gogoproto.IsNullable(field) { + val = "*" + val + } + pkg := g.useTypes() + g.P("s := ", pkg, ".SizeOfStdDuration(", val, ")") + } else { + g.P("s := ", g.Pkg["proto"], ".Size(", val, ")") + } + fixed = "s" + varint = fixed + case descriptor.FieldDescriptorProto_TYPE_BYTES: + wire = "WireBytes" + if gogoproto.IsCustomType(field) { + fixed = val + ".Size()" + } else { + fixed = "len(" + val + ")" + } + varint = fixed + case descriptor.FieldDescriptorProto_TYPE_SINT32: + wire = "WireVarint" + varint = "(uint32(" + val + ") << 1) ^ uint32((int32(" + val + ") >> 31))" + case 
descriptor.FieldDescriptorProto_TYPE_SINT64: + wire = "WireVarint" + varint = "uint64(" + val + " << 1) ^ uint64((int64(" + val + ") >> 63))" + default: + g.Fail("unhandled oneof field type ", field.Type.String()) + } + g.P("n += ", g.Pkg["proto"], ".SizeVarint(", field.Number, "<<3|", g.Pkg["proto"], ".", wire, ")") + if varint != "" { + g.P("n += ", g.Pkg["proto"], ".SizeVarint(uint64(", varint, "))") + } + if fixed != "" { + g.P("n += ", fixed) + } + if *field.Type == descriptor.FieldDescriptorProto_TYPE_GROUP { + g.P("n += ", g.Pkg["proto"], ".SizeVarint(", field.Number, "<<3|", g.Pkg["proto"], ".WireEndGroup)") + } + } + g.P("case nil:") + g.P("default:") + g.P("panic(", g.Pkg["fmt"], ".Sprintf(\"proto: unexpected type %T in oneof\", x))") + g.P("}") + } + g.P("return n") + g.P("}") + g.P() + } + + for _, ext := range message.ext { + g.generateExtension(ext) + } + + fullName := strings.Join(message.TypeName(), ".") + if g.file.Package != nil { + fullName = *g.file.Package + "." + fullName + } + + g.addInitf("%s.RegisterType((*%s)(nil), %q)", g.Pkg["proto"], ccTypeName, fullName) + if gogoproto.ImportsGoGoProto(g.file.FileDescriptorProto) && gogoproto.RegistersGolangProto(g.file.FileDescriptorProto) { + g.addInitf("%s.RegisterType((*%s)(nil), %q)", g.Pkg["golang_proto"], ccTypeName, fullName) + } +} + +var escapeChars = [256]byte{ + 'a': '\a', 'b': '\b', 'f': '\f', 'n': '\n', 'r': '\r', 't': '\t', 'v': '\v', '\\': '\\', '"': '"', '\'': '\'', '?': '?', +} + +// unescape reverses the "C" escaping that protoc does for default values of bytes fields. +// It is best effort in that it effectively ignores malformed input. Seemingly invalid escape +// sequences are conveyed, unmodified, into the decoded result. +func unescape(s string) string { + // NB: Sadly, we can't use strconv.Unquote because protoc will escape both + // single and double quotes, but strconv.Unquote only allows one or the + // other (based on actual surrounding quotes of its input argument). 
+ + var out []byte + for len(s) > 0 { + // regular character, or too short to be valid escape + if s[0] != '\\' || len(s) < 2 { + out = append(out, s[0]) + s = s[1:] + } else if c := escapeChars[s[1]]; c != 0 { + // escape sequence + out = append(out, c) + s = s[2:] + } else if s[1] == 'x' || s[1] == 'X' { + // hex escape, e.g. "\x80 + if len(s) < 4 { + // too short to be valid + out = append(out, s[:2]...) + s = s[2:] + continue + } + v, err := strconv.ParseUint(s[2:4], 16, 8) + if err != nil { + out = append(out, s[:4]...) + } else { + out = append(out, byte(v)) + } + s = s[4:] + } else if '0' <= s[1] && s[1] <= '7' { + // octal escape, can vary from 1 to 3 octal digits; e.g., "\0" "\40" or "\164" + // so consume up to 2 more bytes or up to end-of-string + n := len(s[1:]) - len(strings.TrimLeft(s[1:], "01234567")) + if n > 3 { + n = 3 + } + v, err := strconv.ParseUint(s[1:1+n], 8, 8) + if err != nil { + out = append(out, s[:1+n]...) + } else { + out = append(out, byte(v)) + } + s = s[1+n:] + } else { + // bad escape, just propagate the slash as-is + out = append(out, s[0]) + s = s[1:] + } + } + + return string(out) +} + +func (g *Generator) generateExtension(ext *ExtensionDescriptor) { + ccTypeName := ext.DescName() + + extObj := g.ObjectNamed(*ext.Extendee) + var extDesc *Descriptor + if id, ok := extObj.(*ImportedDescriptor); ok { + // This is extending a publicly imported message. + // We need the underlying type for goTag. 
+ extDesc = id.o.(*Descriptor) + } else { + extDesc = extObj.(*Descriptor) + } + extendedType := "*" + g.TypeName(extObj) // always use the original + field := ext.FieldDescriptorProto + fieldType, wireType := g.GoType(ext.parent, field) + tag := g.goTag(extDesc, field, wireType) + g.RecordTypeUse(*ext.Extendee) + if n := ext.FieldDescriptorProto.TypeName; n != nil { + // foreign extension type + g.RecordTypeUse(*n) + } + + typeName := ext.TypeName() + + // Special case for proto2 message sets: If this extension is extending + // proto2_bridge.MessageSet, and its final name component is "message_set_extension", + // then drop that last component. + mset := false + if extendedType == "*proto2_bridge.MessageSet" && typeName[len(typeName)-1] == "message_set_extension" { + typeName = typeName[:len(typeName)-1] + mset = true + } + + // For text formatting, the package must be exactly what the .proto file declares, + // ignoring overrides such as the go_package option, and with no dot/underscore mapping. + extName := strings.Join(typeName, ".") + if g.file.Package != nil { + extName = *g.file.Package + "." + extName + } + + g.P("var ", ccTypeName, " = &", g.Pkg["proto"], ".ExtensionDesc{") + g.In() + g.P("ExtendedType: (", extendedType, ")(nil),") + g.P("ExtensionType: (", fieldType, ")(nil),") + g.P("Field: ", field.Number, ",") + g.P(`Name: "`, extName, `",`) + g.P("Tag: ", tag, ",") + g.P(`Filename: "`, g.file.GetName(), `",`) + + g.Out() + g.P("}") + g.P() + + if mset { + // Generate a bit more code to register with message_set.go. 
+ g.addInitf("%s.RegisterMessageSetType((%s)(nil), %d, %q)", g.Pkg["proto"], fieldType, *field.Number, extName) + if gogoproto.ImportsGoGoProto(g.file.FileDescriptorProto) && gogoproto.RegistersGolangProto(g.file.FileDescriptorProto) { + g.addInitf("%s.RegisterMessageSetType((%s)(nil), %d, %q)", g.Pkg["golang_proto"], fieldType, *field.Number, extName) + } + } + + g.file.addExport(ext, constOrVarSymbol{ccTypeName, "var", ""}) +} + +func (g *Generator) generateInitFunction() { + for _, enum := range g.file.enum { + g.generateEnumRegistration(enum) + } + for _, d := range g.file.desc { + for _, ext := range d.ext { + g.generateExtensionRegistration(ext) + } + } + for _, ext := range g.file.ext { + g.generateExtensionRegistration(ext) + } + if len(g.init) == 0 { + return + } + g.P("func init() {") + g.In() + for _, l := range g.init { + g.P(l) + } + g.Out() + g.P("}") + g.init = nil +} + +func (g *Generator) generateFileDescriptor(file *FileDescriptor) { + // Make a copy and trim source_code_info data. + // TODO: Trim this more when we know exactly what we need. 
+ pb := proto.Clone(file.FileDescriptorProto).(*descriptor.FileDescriptorProto) + pb.SourceCodeInfo = nil + + b, err := proto.Marshal(pb) + if err != nil { + g.Fail(err.Error()) + } + + var buf bytes.Buffer + w, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression) + w.Write(b) + w.Close() + b = buf.Bytes() + + v := file.VarName() + g.P() + g.P("func init() { ", g.Pkg["proto"], ".RegisterFile(", strconv.Quote(*file.Name), ", ", v, ") }") + if gogoproto.ImportsGoGoProto(g.file.FileDescriptorProto) && gogoproto.RegistersGolangProto(g.file.FileDescriptorProto) { + g.P("func init() { ", g.Pkg["golang_proto"], ".RegisterFile(", strconv.Quote(*file.Name), ", ", v, ") }") + } + g.P("var ", v, " = []byte{") + g.In() + g.P("// ", len(b), " bytes of a gzipped FileDescriptorProto") + for len(b) > 0 { + n := 16 + if n > len(b) { + n = len(b) + } + + s := "" + for _, c := range b[:n] { + s += fmt.Sprintf("0x%02x,", c) + } + g.P(s) + + b = b[n:] + } + g.Out() + g.P("}") +} + +func (g *Generator) generateEnumRegistration(enum *EnumDescriptor) { + // // We always print the full (proto-world) package name here. + pkg := enum.File().GetPackage() + if pkg != "" { + pkg += "." + } + // The full type name + typeName := enum.TypeName() + // The full type name, CamelCased. 
+ ccTypeName := CamelCaseSlice(typeName) + g.addInitf("%s.RegisterEnum(%q, %[3]s_name, %[3]s_value)", g.Pkg["proto"], pkg+ccTypeName, ccTypeName) + if gogoproto.ImportsGoGoProto(g.file.FileDescriptorProto) && gogoproto.RegistersGolangProto(g.file.FileDescriptorProto) { + g.addInitf("%s.RegisterEnum(%q, %[3]s_name, %[3]s_value)", g.Pkg["golang_proto"], pkg+ccTypeName, ccTypeName) + } +} + +func (g *Generator) generateExtensionRegistration(ext *ExtensionDescriptor) { + g.addInitf("%s.RegisterExtension(%s)", g.Pkg["proto"], ext.DescName()) + if gogoproto.ImportsGoGoProto(g.file.FileDescriptorProto) && gogoproto.RegistersGolangProto(g.file.FileDescriptorProto) { + g.addInitf("%s.RegisterExtension(%s)", g.Pkg["golang_proto"], ext.DescName()) + } +} + +// And now lots of helper functions. + +// Is c an ASCII lower-case letter? +func isASCIILower(c byte) bool { + return 'a' <= c && c <= 'z' +} + +// Is c an ASCII digit? +func isASCIIDigit(c byte) bool { + return '0' <= c && c <= '9' +} + +// CamelCase returns the CamelCased name. +// If there is an interior underscore followed by a lower case letter, +// drop the underscore and convert the letter to upper case. +// There is a remote possibility of this rewrite causing a name collision, +// but it's so remote we're prepared to pretend it's nonexistent - since the +// C++ generator lowercases names, it's extremely unlikely to have two fields +// with different capitalizations. +// In short, _my_field_name_2 becomes XMyFieldName_2. +func CamelCase(s string) string { + if s == "" { + return "" + } + t := make([]byte, 0, 32) + i := 0 + if s[0] == '_' { + // Need a capital letter; drop the '_'. + t = append(t, 'X') + i++ + } + // Invariant: if the next letter is lower case, it must be converted + // to upper case. + // That is, we process a word at a time, where words are marked by _ or + // upper case letter. Digits are treated as words. 
+ for ; i < len(s); i++ { + c := s[i] + if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) { + continue // Skip the underscore in s. + } + if isASCIIDigit(c) { + t = append(t, c) + continue + } + // Assume we have a letter now - if not, it's a bogus identifier. + // The next word is a sequence of characters that must start upper case. + if isASCIILower(c) { + c ^= ' ' // Make it a capital letter. + } + t = append(t, c) // Guaranteed not lower case. + // Accept lower case sequence that follows. + for i+1 < len(s) && isASCIILower(s[i+1]) { + i++ + t = append(t, s[i]) + } + } + return string(t) +} + +// CamelCaseSlice is like CamelCase, but the argument is a slice of strings to +// be joined with "_". +func CamelCaseSlice(elem []string) string { return CamelCase(strings.Join(elem, "_")) } + +// dottedSlice turns a sliced name into a dotted name. +func dottedSlice(elem []string) string { return strings.Join(elem, ".") } + +// Is this field optional? +func isOptional(field *descriptor.FieldDescriptorProto) bool { + return field.Label != nil && *field.Label == descriptor.FieldDescriptorProto_LABEL_OPTIONAL +} + +// Is this field required? +func isRequired(field *descriptor.FieldDescriptorProto) bool { + return field.Label != nil && *field.Label == descriptor.FieldDescriptorProto_LABEL_REQUIRED +} + +// Is this field repeated? +func isRepeated(field *descriptor.FieldDescriptorProto) bool { + return field.Label != nil && *field.Label == descriptor.FieldDescriptorProto_LABEL_REPEATED +} + +// Is this field a scalar numeric type? 
+func IsScalar(field *descriptor.FieldDescriptorProto) bool { + if field.Type == nil { + return false + } + switch *field.Type { + case descriptor.FieldDescriptorProto_TYPE_DOUBLE, + descriptor.FieldDescriptorProto_TYPE_FLOAT, + descriptor.FieldDescriptorProto_TYPE_INT64, + descriptor.FieldDescriptorProto_TYPE_UINT64, + descriptor.FieldDescriptorProto_TYPE_INT32, + descriptor.FieldDescriptorProto_TYPE_FIXED64, + descriptor.FieldDescriptorProto_TYPE_FIXED32, + descriptor.FieldDescriptorProto_TYPE_BOOL, + descriptor.FieldDescriptorProto_TYPE_UINT32, + descriptor.FieldDescriptorProto_TYPE_ENUM, + descriptor.FieldDescriptorProto_TYPE_SFIXED32, + descriptor.FieldDescriptorProto_TYPE_SFIXED64, + descriptor.FieldDescriptorProto_TYPE_SINT32, + descriptor.FieldDescriptorProto_TYPE_SINT64: + return true + default: + return false + } +} + +// badToUnderscore is the mapping function used to generate Go names from package names, +// which can be dotted in the input .proto file. It replaces non-identifier characters such as +// dot or dash with underscore. +func badToUnderscore(r rune) rune { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { + return r + } + return '_' +} + +// baseName returns the last path element of the name, with the last dotted suffix removed. +func baseName(name string) string { + // First, find the last element + if i := strings.LastIndex(name, "/"); i >= 0 { + name = name[i+1:] + } + // Now drop the suffix + if i := strings.LastIndex(name, "."); i >= 0 { + name = name[0:i] + } + return name +} + +// The SourceCodeInfo message describes the location of elements of a parsed +// .proto file by way of a "path", which is a sequence of integers that +// describe the route from a FileDescriptorProto to the relevant submessage. +// The path alternates between a field number of a repeated field, and an index +// into that repeated field. The constants below define the field numbers that +// are used. 
+// +// See descriptor.proto for more information about this. +const ( + // tag numbers in FileDescriptorProto + packagePath = 2 // package + messagePath = 4 // message_type + enumPath = 5 // enum_type + // tag numbers in DescriptorProto + messageFieldPath = 2 // field + messageMessagePath = 3 // nested_type + messageEnumPath = 4 // enum_type + messageOneofPath = 8 // oneof_decl + // tag numbers in EnumDescriptorProto + enumValuePath = 2 // value +) diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/generator/helper.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/generator/helper.go new file mode 100644 index 00000000000..d7a406e7ccd --- /dev/null +++ b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/generator/helper.go @@ -0,0 +1,447 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2013, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package generator + +import ( + "bytes" + "go/parser" + "go/printer" + "go/token" + "path" + "strings" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + plugin "github.com/gogo/protobuf/protoc-gen-gogo/plugin" +) + +func (d *FileDescriptor) Messages() []*Descriptor { + return d.desc +} + +func (d *FileDescriptor) Enums() []*EnumDescriptor { + return d.enum +} + +func (d *Descriptor) IsGroup() bool { + return d.group +} + +func (g *Generator) IsGroup(field *descriptor.FieldDescriptorProto) bool { + if d, ok := g.typeNameToObject[field.GetTypeName()].(*Descriptor); ok { + return d.IsGroup() + } + return false +} + +func (g *Generator) TypeNameByObject(typeName string) Object { + o, ok := g.typeNameToObject[typeName] + if !ok { + g.Fail("can't find object with type", typeName) + } + return o +} + +func (g *Generator) OneOfTypeName(message *Descriptor, field *descriptor.FieldDescriptorProto) string { + typeName := message.TypeName() + ccTypeName := CamelCaseSlice(typeName) + fieldName := g.GetOneOfFieldName(message, field) + tname := ccTypeName + "_" + fieldName + // It is possible for this to collide with a message or enum + // nested in this message. Check for collisions. 
+ ok := true + for _, desc := range message.nested { + if strings.Join(desc.TypeName(), "_") == tname { + ok = false + break + } + } + for _, enum := range message.enums { + if strings.Join(enum.TypeName(), "_") == tname { + ok = false + break + } + } + if !ok { + tname += "_" + } + return tname +} + +type PluginImports interface { + NewImport(pkg string) Single + GenerateImports(file *FileDescriptor) +} + +type pluginImports struct { + generator *Generator + singles []Single +} + +func NewPluginImports(generator *Generator) *pluginImports { + return &pluginImports{generator, make([]Single, 0)} +} + +func (this *pluginImports) NewImport(pkg string) Single { + imp := newImportedPackage(this.generator.ImportPrefix, pkg) + this.singles = append(this.singles, imp) + return imp +} + +func (this *pluginImports) GenerateImports(file *FileDescriptor) { + for _, s := range this.singles { + if s.IsUsed() { + this.generator.PrintImport(s.Name(), s.Location()) + } + } +} + +type Single interface { + Use() string + IsUsed() bool + Name() string + Location() string +} + +type importedPackage struct { + used bool + pkg string + name string + importPrefix string +} + +func newImportedPackage(importPrefix, pkg string) *importedPackage { + return &importedPackage{ + pkg: pkg, + importPrefix: importPrefix, + } +} + +func (this *importedPackage) Use() string { + if !this.used { + this.name = RegisterUniquePackageName(this.pkg, nil) + this.used = true + } + return this.name +} + +func (this *importedPackage) IsUsed() bool { + return this.used +} + +func (this *importedPackage) Name() string { + return this.name +} + +func (this *importedPackage) Location() string { + return this.importPrefix + this.pkg +} + +func (g *Generator) GetFieldName(message *Descriptor, field *descriptor.FieldDescriptorProto) string { + goTyp, _ := g.GoType(message, field) + fieldname := CamelCase(*field.Name) + if gogoproto.IsCustomName(field) { + fieldname = gogoproto.GetCustomName(field) + } + if 
gogoproto.IsEmbed(field) { + fieldname = EmbedFieldName(goTyp) + } + if field.OneofIndex != nil { + fieldname = message.OneofDecl[int(*field.OneofIndex)].GetName() + fieldname = CamelCase(fieldname) + } + for _, f := range methodNames { + if f == fieldname { + return fieldname + "_" + } + } + if !gogoproto.IsProtoSizer(message.file, message.DescriptorProto) { + if fieldname == "Size" { + return fieldname + "_" + } + } + return fieldname +} + +func (g *Generator) GetOneOfFieldName(message *Descriptor, field *descriptor.FieldDescriptorProto) string { + goTyp, _ := g.GoType(message, field) + fieldname := CamelCase(*field.Name) + if gogoproto.IsCustomName(field) { + fieldname = gogoproto.GetCustomName(field) + } + if gogoproto.IsEmbed(field) { + fieldname = EmbedFieldName(goTyp) + } + for _, f := range methodNames { + if f == fieldname { + return fieldname + "_" + } + } + if !gogoproto.IsProtoSizer(message.file, message.DescriptorProto) { + if fieldname == "Size" { + return fieldname + "_" + } + } + return fieldname +} + +func (g *Generator) IsMap(field *descriptor.FieldDescriptorProto) bool { + if !field.IsMessage() { + return false + } + byName := g.ObjectNamed(field.GetTypeName()) + desc, ok := byName.(*Descriptor) + if byName == nil || !ok || !desc.GetOptions().GetMapEntry() { + return false + } + return true +} + +func (g *Generator) GetMapKeyField(field, keyField *descriptor.FieldDescriptorProto) *descriptor.FieldDescriptorProto { + if !gogoproto.IsCastKey(field) { + return keyField + } + keyField = proto.Clone(keyField).(*descriptor.FieldDescriptorProto) + if keyField.Options == nil { + keyField.Options = &descriptor.FieldOptions{} + } + keyType := gogoproto.GetCastKey(field) + if err := proto.SetExtension(keyField.Options, gogoproto.E_Casttype, &keyType); err != nil { + g.Fail(err.Error()) + } + return keyField +} + +func (g *Generator) GetMapValueField(field, valField *descriptor.FieldDescriptorProto) *descriptor.FieldDescriptorProto { + if 
gogoproto.IsCustomType(field) && gogoproto.IsCastValue(field) { + g.Fail("cannot have a customtype and casttype: ", field.String()) + } + valField = proto.Clone(valField).(*descriptor.FieldDescriptorProto) + if valField.Options == nil { + valField.Options = &descriptor.FieldOptions{} + } + + stdtime := gogoproto.IsStdTime(field) + if stdtime { + if err := proto.SetExtension(valField.Options, gogoproto.E_Stdtime, &stdtime); err != nil { + g.Fail(err.Error()) + } + } + + stddur := gogoproto.IsStdDuration(field) + if stddur { + if err := proto.SetExtension(valField.Options, gogoproto.E_Stdduration, &stddur); err != nil { + g.Fail(err.Error()) + } + } + + if valType := gogoproto.GetCastValue(field); len(valType) > 0 { + if err := proto.SetExtension(valField.Options, gogoproto.E_Casttype, &valType); err != nil { + g.Fail(err.Error()) + } + } + if valType := gogoproto.GetCustomType(field); len(valType) > 0 { + if err := proto.SetExtension(valField.Options, gogoproto.E_Customtype, &valType); err != nil { + g.Fail(err.Error()) + } + } + + nullable := gogoproto.IsNullable(field) + if err := proto.SetExtension(valField.Options, gogoproto.E_Nullable, &nullable); err != nil { + g.Fail(err.Error()) + } + return valField +} + +// GoMapValueTypes returns the map value Go type and the alias map value Go type (for casting), taking into +// account whether the map is nullable or the value is a message. 
+func GoMapValueTypes(mapField, valueField *descriptor.FieldDescriptorProto, goValueType, goValueAliasType string) (nullable bool, outGoType string, outGoAliasType string) { + nullable = gogoproto.IsNullable(mapField) && (valueField.IsMessage() || gogoproto.IsCustomType(mapField)) + if nullable { + // ensure the non-aliased Go value type is a pointer for consistency + if strings.HasPrefix(goValueType, "*") { + outGoType = goValueType + } else { + outGoType = "*" + goValueType + } + outGoAliasType = goValueAliasType + } else { + outGoType = strings.Replace(goValueType, "*", "", 1) + outGoAliasType = strings.Replace(goValueAliasType, "*", "", 1) + } + return +} + +func GoTypeToName(goTyp string) string { + return strings.Replace(strings.Replace(goTyp, "*", "", -1), "[]", "", -1) +} + +func EmbedFieldName(goTyp string) string { + goTyp = GoTypeToName(goTyp) + goTyps := strings.Split(goTyp, ".") + if len(goTyps) == 1 { + return goTyp + } + if len(goTyps) == 2 { + return goTyps[1] + } + panic("unreachable") +} + +func (g *Generator) GeneratePlugin(p Plugin) { + plugins = []Plugin{p} + p.Init(g) + // Generate the output. The generator runs for every file, even the files + // that we don't generate output for, so that we can collate the full list + // of exported symbols to support public imports. 
+ genFileMap := make(map[*FileDescriptor]bool, len(g.genFiles)) + for _, file := range g.genFiles { + genFileMap[file] = true + } + for _, file := range g.allFiles { + g.Reset() + g.writeOutput = genFileMap[file] + g.generatePlugin(file, p) + if !g.writeOutput { + continue + } + g.Response.File = append(g.Response.File, &plugin.CodeGeneratorResponse_File{ + Name: proto.String(file.goFileName()), + Content: proto.String(g.String()), + }) + } +} + +func (g *Generator) SetFile(file *descriptor.FileDescriptorProto) { + g.file = g.FileOf(file) +} + +func (g *Generator) generatePlugin(file *FileDescriptor, p Plugin) { + g.writtenImports = make(map[string]bool) + g.file = g.FileOf(file.FileDescriptorProto) + g.usedPackages = make(map[string]bool) + + // Run the plugins before the imports so we know which imports are necessary. + p.Generate(file) + + // Generate header and imports last, though they appear first in the output. + rem := g.Buffer + g.Buffer = new(bytes.Buffer) + g.generateHeader() + p.GenerateImports(g.file) + g.generateImports() + if !g.writeOutput { + return + } + g.Write(rem.Bytes()) + + // Reformat generated code. 
+ contents := string(g.Buffer.Bytes()) + fset := token.NewFileSet() + ast, err := parser.ParseFile(fset, "", g, parser.ParseComments) + if err != nil { + g.Fail("bad Go source code was generated:", contents, err.Error()) + return + } + g.Reset() + err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(g, fset, ast) + if err != nil { + g.Fail("generated Go source code could not be reformatted:", err.Error()) + } +} + +func GetCustomType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) { + return getCustomType(field) +} + +func getCustomType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) { + if field.Options != nil { + var v interface{} + v, err = proto.GetExtension(field.Options, gogoproto.E_Customtype) + if err == nil && v.(*string) != nil { + ctype := *(v.(*string)) + packageName, typ = splitCPackageType(ctype) + return packageName, typ, nil + } + } + return "", "", err +} + +func splitCPackageType(ctype string) (packageName string, typ string) { + ss := strings.Split(ctype, ".") + if len(ss) == 1 { + return "", ctype + } + packageName = strings.Join(ss[0:len(ss)-1], ".") + typeName := ss[len(ss)-1] + importStr := strings.Map(badToUnderscore, packageName) + typ = importStr + "." 
+ typeName + return packageName, typ +} + +func getCastType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) { + if field.Options != nil { + var v interface{} + v, err = proto.GetExtension(field.Options, gogoproto.E_Casttype) + if err == nil && v.(*string) != nil { + ctype := *(v.(*string)) + packageName, typ = splitCPackageType(ctype) + return packageName, typ, nil + } + } + return "", "", err +} + +func FileName(file *FileDescriptor) string { + fname := path.Base(file.FileDescriptorProto.GetName()) + fname = strings.Replace(fname, ".proto", "", -1) + fname = strings.Replace(fname, "-", "_", -1) + fname = strings.Replace(fname, ".", "_", -1) + return CamelCase(fname) +} + +func (g *Generator) AllFiles() *descriptor.FileDescriptorSet { + set := &descriptor.FileDescriptorSet{} + set.File = make([]*descriptor.FileDescriptorProto, len(g.allFiles)) + for i := range g.allFiles { + set.File[i] = g.allFiles[i].FileDescriptorProto + } + return set +} + +func (d *Descriptor) Path() string { + return d.path +} + +func (g *Generator) useTypes() string { + pkg := strings.Map(badToUnderscore, "github.com/gogo/protobuf/types") + g.customImports = append(g.customImports, "github.com/gogo/protobuf/types") + return pkg +} diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/grpc/grpc.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/grpc/grpc.go new file mode 100644 index 00000000000..06abe9b6af2 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/grpc/grpc.go @@ -0,0 +1,462 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2015 The Go Authors. All rights reserved. 
+// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package grpc outputs gRPC service descriptions in Go code. +// It runs as a plugin for the Go protocol buffer compiler plugin. +// It is linked in to protoc-gen-go. +package grpc + +import ( + "fmt" + "strconv" + "strings" + + pb "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" +) + +// generatedCodeVersion indicates a version of the generated code. 
+// It is incremented whenever an incompatibility between the generated code and +// the grpc package is introduced; the generated code references +// a constant, grpc.SupportPackageIsVersionN (where N is generatedCodeVersion). +const generatedCodeVersion = 4 + +// Paths for packages used by code generated in this file, +// relative to the import_prefix of the generator.Generator. +const ( + contextPkgPath = "golang.org/x/net/context" + grpcPkgPath = "google.golang.org/grpc" +) + +func init() { + generator.RegisterPlugin(new(grpc)) +} + +// grpc is an implementation of the Go protocol buffer compiler's +// plugin architecture. It generates bindings for gRPC support. +type grpc struct { + gen *generator.Generator +} + +// Name returns the name of this plugin, "grpc". +func (g *grpc) Name() string { + return "grpc" +} + +// The names for packages imported in the generated code. +// They may vary from the final path component of the import path +// if the name is used by other packages. +var ( + contextPkg string + grpcPkg string +) + +// Init initializes the plugin. +func (g *grpc) Init(gen *generator.Generator) { + g.gen = gen + contextPkg = generator.RegisterUniquePackageName("context", nil) + grpcPkg = generator.RegisterUniquePackageName("grpc", nil) +} + +// Given a type name defined in a .proto, return its object. +// Also record that we're using it, to guarantee the associated import. +func (g *grpc) objectNamed(name string) generator.Object { + g.gen.RecordTypeUse(name) + return g.gen.ObjectNamed(name) +} + +// Given a type name defined in a .proto, return its name as we will print it. +func (g *grpc) typeName(str string) string { + return g.gen.TypeName(g.objectNamed(str)) +} + +// P forwards to g.gen.P. +func (g *grpc) P(args ...interface{}) { g.gen.P(args...) } + +// Generate generates code for the services in the given file. 
+func (g *grpc) Generate(file *generator.FileDescriptor) { + if len(file.FileDescriptorProto.Service) == 0 { + return + } + + g.P("// Reference imports to suppress errors if they are not otherwise used.") + g.P("var _ ", contextPkg, ".Context") + g.P("var _ ", grpcPkg, ".ClientConn") + g.P() + + // Assert version compatibility. + g.P("// This is a compile-time assertion to ensure that this generated file") + g.P("// is compatible with the grpc package it is being compiled against.") + g.P("const _ = ", grpcPkg, ".SupportPackageIsVersion", generatedCodeVersion) + g.P() + + for i, service := range file.FileDescriptorProto.Service { + g.generateService(file, service, i) + } +} + +// GenerateImports generates the import declaration for this file. +func (g *grpc) GenerateImports(file *generator.FileDescriptor) { + if len(file.FileDescriptorProto.Service) == 0 { + return + } + imports := generator.NewPluginImports(g.gen) + for _, i := range []string{contextPkgPath, grpcPkgPath} { + imports.NewImport(i).Use() + } + imports.GenerateImports(file) +} + +// reservedClientName records whether a client name is reserved on the client side. +var reservedClientName = map[string]bool{ +// TODO: do we need any in gRPC? +} + +func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] } + +// generateService generates all the code for the named service. +func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) { + path := fmt.Sprintf("6,%d", index) // 6 means service. + + origServName := service.GetName() + fullServName := origServName + if pkg := file.GetPackage(); pkg != "" { + fullServName = pkg + "." + fullServName + } + servName := generator.CamelCase(origServName) + + g.P() + g.P("// Client API for ", servName, " service") + g.P() + + // Client interface. 
+ g.P("type ", servName, "Client interface {") + for i, method := range service.Method { + g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service. + g.P(g.generateClientSignature(servName, method)) + } + g.P("}") + g.P() + + // Client structure. + g.P("type ", unexport(servName), "Client struct {") + g.P("cc *", grpcPkg, ".ClientConn") + g.P("}") + g.P() + + // NewClient factory. + g.P("func New", servName, "Client (cc *", grpcPkg, ".ClientConn) ", servName, "Client {") + g.P("return &", unexport(servName), "Client{cc}") + g.P("}") + g.P() + + var methodIndex, streamIndex int + serviceDescVar := "_" + servName + "_serviceDesc" + // Client method implementations. + for _, method := range service.Method { + var descExpr string + if !method.GetServerStreaming() && !method.GetClientStreaming() { + // Unary RPC method + descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex) + methodIndex++ + } else { + // Streaming RPC method + descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex) + streamIndex++ + } + g.generateClientMethod(servName, fullServName, serviceDescVar, method, descExpr) + } + + g.P("// Server API for ", servName, " service") + g.P() + + // Server interface. + serverType := servName + "Server" + g.P("type ", serverType, " interface {") + for i, method := range service.Method { + g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service. + g.P(g.generateServerSignature(servName, method)) + } + g.P("}") + g.P() + + // Server registration. + g.P("func Register", servName, "Server(s *", grpcPkg, ".Server, srv ", serverType, ") {") + g.P("s.RegisterService(&", serviceDescVar, `, srv)`) + g.P("}") + g.P() + + // Server handler implementations. + var handlerNames []string + for _, method := range service.Method { + hname := g.generateServerMethod(servName, fullServName, method) + handlerNames = append(handlerNames, hname) + } + + // Service descriptor. 
+ g.P("var ", serviceDescVar, " = ", grpcPkg, ".ServiceDesc {") + g.P("ServiceName: ", strconv.Quote(fullServName), ",") + g.P("HandlerType: (*", serverType, ")(nil),") + g.P("Methods: []", grpcPkg, ".MethodDesc{") + for i, method := range service.Method { + if method.GetServerStreaming() || method.GetClientStreaming() { + continue + } + g.P("{") + g.P("MethodName: ", strconv.Quote(method.GetName()), ",") + g.P("Handler: ", handlerNames[i], ",") + g.P("},") + } + g.P("},") + g.P("Streams: []", grpcPkg, ".StreamDesc{") + for i, method := range service.Method { + if !method.GetServerStreaming() && !method.GetClientStreaming() { + continue + } + g.P("{") + g.P("StreamName: ", strconv.Quote(method.GetName()), ",") + g.P("Handler: ", handlerNames[i], ",") + if method.GetServerStreaming() { + g.P("ServerStreams: true,") + } + if method.GetClientStreaming() { + g.P("ClientStreams: true,") + } + g.P("},") + } + g.P("},") + g.P("Metadata: \"", file.GetName(), "\",") + g.P("}") + g.P() +} + +// generateClientSignature returns the client-side signature for a method. 
+func (g *grpc) generateClientSignature(servName string, method *pb.MethodDescriptorProto) string { + origMethName := method.GetName() + methName := generator.CamelCase(origMethName) + if reservedClientName[methName] { + methName += "_" + } + reqArg := ", in *" + g.typeName(method.GetInputType()) + if method.GetClientStreaming() { + reqArg = "" + } + respName := "*" + g.typeName(method.GetOutputType()) + if method.GetServerStreaming() || method.GetClientStreaming() { + respName = servName + "_" + generator.CamelCase(origMethName) + "Client" + } + return fmt.Sprintf("%s(ctx %s.Context%s, opts ...%s.CallOption) (%s, error)", methName, contextPkg, reqArg, grpcPkg, respName) +} + +func (g *grpc) generateClientMethod(servName, fullServName, serviceDescVar string, method *pb.MethodDescriptorProto, descExpr string) { + sname := fmt.Sprintf("/%s/%s", fullServName, method.GetName()) + methName := generator.CamelCase(method.GetName()) + inType := g.typeName(method.GetInputType()) + outType := g.typeName(method.GetOutputType()) + + g.P("func (c *", unexport(servName), "Client) ", g.generateClientSignature(servName, method), "{") + if !method.GetServerStreaming() && !method.GetClientStreaming() { + g.P("out := new(", outType, ")") + // TODO: Pass descExpr to Invoke. 
+ g.P("err := ", grpcPkg, `.Invoke(ctx, "`, sname, `", in, out, c.cc, opts...)`) + g.P("if err != nil { return nil, err }") + g.P("return out, nil") + g.P("}") + g.P() + return + } + streamType := unexport(servName) + methName + "Client" + g.P("stream, err := ", grpcPkg, ".NewClientStream(ctx, ", descExpr, `, c.cc, "`, sname, `", opts...)`) + g.P("if err != nil { return nil, err }") + g.P("x := &", streamType, "{stream}") + if !method.GetClientStreaming() { + g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }") + g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") + } + g.P("return x, nil") + g.P("}") + g.P() + + genSend := method.GetClientStreaming() + genRecv := method.GetServerStreaming() + genCloseAndRecv := !method.GetServerStreaming() + + // Stream auxiliary types and methods. + g.P("type ", servName, "_", methName, "Client interface {") + if genSend { + g.P("Send(*", inType, ") error") + } + if genRecv { + g.P("Recv() (*", outType, ", error)") + } + if genCloseAndRecv { + g.P("CloseAndRecv() (*", outType, ", error)") + } + g.P(grpcPkg, ".ClientStream") + g.P("}") + g.P() + + g.P("type ", streamType, " struct {") + g.P(grpcPkg, ".ClientStream") + g.P("}") + g.P() + + if genSend { + g.P("func (x *", streamType, ") Send(m *", inType, ") error {") + g.P("return x.ClientStream.SendMsg(m)") + g.P("}") + g.P() + } + if genRecv { + g.P("func (x *", streamType, ") Recv() (*", outType, ", error) {") + g.P("m := new(", outType, ")") + g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") + g.P("return m, nil") + g.P("}") + g.P() + } + if genCloseAndRecv { + g.P("func (x *", streamType, ") CloseAndRecv() (*", outType, ", error) {") + g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") + g.P("m := new(", outType, ")") + g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") + g.P("return m, nil") + g.P("}") + g.P() + } +} + +// generateServerSignature returns 
the server-side signature for a method. +func (g *grpc) generateServerSignature(servName string, method *pb.MethodDescriptorProto) string { + origMethName := method.GetName() + methName := generator.CamelCase(origMethName) + if reservedClientName[methName] { + methName += "_" + } + + var reqArgs []string + ret := "error" + if !method.GetServerStreaming() && !method.GetClientStreaming() { + reqArgs = append(reqArgs, contextPkg+".Context") + ret = "(*" + g.typeName(method.GetOutputType()) + ", error)" + } + if !method.GetClientStreaming() { + reqArgs = append(reqArgs, "*"+g.typeName(method.GetInputType())) + } + if method.GetServerStreaming() || method.GetClientStreaming() { + reqArgs = append(reqArgs, servName+"_"+generator.CamelCase(origMethName)+"Server") + } + + return methName + "(" + strings.Join(reqArgs, ", ") + ") " + ret +} + +func (g *grpc) generateServerMethod(servName, fullServName string, method *pb.MethodDescriptorProto) string { + methName := generator.CamelCase(method.GetName()) + hname := fmt.Sprintf("_%s_%s_Handler", servName, methName) + inType := g.typeName(method.GetInputType()) + outType := g.typeName(method.GetOutputType()) + + if !method.GetServerStreaming() && !method.GetClientStreaming() { + g.P("func ", hname, "(srv interface{}, ctx ", contextPkg, ".Context, dec func(interface{}) error, interceptor ", grpcPkg, ".UnaryServerInterceptor) (interface{}, error) {") + g.P("in := new(", inType, ")") + g.P("if err := dec(in); err != nil { return nil, err }") + g.P("if interceptor == nil { return srv.(", servName, "Server).", methName, "(ctx, in) }") + g.P("info := &", grpcPkg, ".UnaryServerInfo{") + g.P("Server: srv,") + g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", fullServName, methName)), ",") + g.P("}") + g.P("handler := func(ctx ", contextPkg, ".Context, req interface{}) (interface{}, error) {") + g.P("return srv.(", servName, "Server).", methName, "(ctx, req.(*", inType, "))") + g.P("}") + g.P("return interceptor(ctx, in, info, 
handler)") + g.P("}") + g.P() + return hname + } + streamType := unexport(servName) + methName + "Server" + g.P("func ", hname, "(srv interface{}, stream ", grpcPkg, ".ServerStream) error {") + if !method.GetClientStreaming() { + g.P("m := new(", inType, ")") + g.P("if err := stream.RecvMsg(m); err != nil { return err }") + g.P("return srv.(", servName, "Server).", methName, "(m, &", streamType, "{stream})") + } else { + g.P("return srv.(", servName, "Server).", methName, "(&", streamType, "{stream})") + } + g.P("}") + g.P() + + genSend := method.GetServerStreaming() + genSendAndClose := !method.GetServerStreaming() + genRecv := method.GetClientStreaming() + + // Stream auxiliary types and methods. + g.P("type ", servName, "_", methName, "Server interface {") + if genSend { + g.P("Send(*", outType, ") error") + } + if genSendAndClose { + g.P("SendAndClose(*", outType, ") error") + } + if genRecv { + g.P("Recv() (*", inType, ", error)") + } + g.P(grpcPkg, ".ServerStream") + g.P("}") + g.P() + + g.P("type ", streamType, " struct {") + g.P(grpcPkg, ".ServerStream") + g.P("}") + g.P() + + if genSend { + g.P("func (x *", streamType, ") Send(m *", outType, ") error {") + g.P("return x.ServerStream.SendMsg(m)") + g.P("}") + g.P() + } + if genSendAndClose { + g.P("func (x *", streamType, ") SendAndClose(m *", outType, ") error {") + g.P("return x.ServerStream.SendMsg(m)") + g.P("}") + g.P() + } + if genRecv { + g.P("func (x *", streamType, ") Recv() (*", inType, ", error) {") + g.P("m := new(", inType, ")") + g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }") + g.P("return m, nil") + g.P("}") + g.P() + } + + return hname +} diff --git a/vendor/github.com/gogo/protobuf/protoc-gen-gogo/plugin/plugin.pb.go b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/plugin/plugin.pb.go new file mode 100644 index 00000000000..c673d503558 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/protoc-gen-gogo/plugin/plugin.pb.go @@ -0,0 +1,292 @@ +// Code generated by 
protoc-gen-gogo. DO NOT EDIT. +// source: plugin.proto + +/* +Package plugin_go is a generated protocol buffer package. + +It is generated from these files: + plugin.proto + +It has these top-level messages: + Version + CodeGeneratorRequest + CodeGeneratorResponse +*/ +package plugin_go + +import proto "github.com/gogo/protobuf/proto" +import fmt "fmt" +import math "math" +import google_protobuf "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +// The version number of protocol compiler. +type Version struct { + Major *int32 `protobuf:"varint,1,opt,name=major" json:"major,omitempty"` + Minor *int32 `protobuf:"varint,2,opt,name=minor" json:"minor,omitempty"` + Patch *int32 `protobuf:"varint,3,opt,name=patch" json:"patch,omitempty"` + // A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should + // be empty for mainline stable releases. 
+ Suffix *string `protobuf:"bytes,4,opt,name=suffix" json:"suffix,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Version) Reset() { *m = Version{} } +func (m *Version) String() string { return proto.CompactTextString(m) } +func (*Version) ProtoMessage() {} +func (*Version) Descriptor() ([]byte, []int) { return fileDescriptorPlugin, []int{0} } + +func (m *Version) GetMajor() int32 { + if m != nil && m.Major != nil { + return *m.Major + } + return 0 +} + +func (m *Version) GetMinor() int32 { + if m != nil && m.Minor != nil { + return *m.Minor + } + return 0 +} + +func (m *Version) GetPatch() int32 { + if m != nil && m.Patch != nil { + return *m.Patch + } + return 0 +} + +func (m *Version) GetSuffix() string { + if m != nil && m.Suffix != nil { + return *m.Suffix + } + return "" +} + +// An encoded CodeGeneratorRequest is written to the plugin's stdin. +type CodeGeneratorRequest struct { + // The .proto files that were explicitly listed on the command-line. The + // code generator should generate code only for these files. Each file's + // descriptor will be included in proto_file, below. + FileToGenerate []string `protobuf:"bytes,1,rep,name=file_to_generate,json=fileToGenerate" json:"file_to_generate,omitempty"` + // The generator parameter passed on the command-line. + Parameter *string `protobuf:"bytes,2,opt,name=parameter" json:"parameter,omitempty"` + // FileDescriptorProtos for all files in files_to_generate and everything + // they import. The files will appear in topological order, so each file + // appears before any file that imports it. + // + // protoc guarantees that all proto_files will be written after + // the fields above, even though this is not technically guaranteed by the + // protobuf wire format. This theoretically could allow a plugin to stream + // in the FileDescriptorProtos and handle them one by one rather than read + // the entire set into memory at once. 
However, as of this writing, this + // is not similarly optimized on protoc's end -- it will store all fields in + // memory at once before sending them to the plugin. + // + // Type names of fields and extensions in the FileDescriptorProto are always + // fully qualified. + ProtoFile []*google_protobuf.FileDescriptorProto `protobuf:"bytes,15,rep,name=proto_file,json=protoFile" json:"proto_file,omitempty"` + // The version number of protocol compiler. + CompilerVersion *Version `protobuf:"bytes,3,opt,name=compiler_version,json=compilerVersion" json:"compiler_version,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *CodeGeneratorRequest) Reset() { *m = CodeGeneratorRequest{} } +func (m *CodeGeneratorRequest) String() string { return proto.CompactTextString(m) } +func (*CodeGeneratorRequest) ProtoMessage() {} +func (*CodeGeneratorRequest) Descriptor() ([]byte, []int) { return fileDescriptorPlugin, []int{1} } + +func (m *CodeGeneratorRequest) GetFileToGenerate() []string { + if m != nil { + return m.FileToGenerate + } + return nil +} + +func (m *CodeGeneratorRequest) GetParameter() string { + if m != nil && m.Parameter != nil { + return *m.Parameter + } + return "" +} + +func (m *CodeGeneratorRequest) GetProtoFile() []*google_protobuf.FileDescriptorProto { + if m != nil { + return m.ProtoFile + } + return nil +} + +func (m *CodeGeneratorRequest) GetCompilerVersion() *Version { + if m != nil { + return m.CompilerVersion + } + return nil +} + +// The plugin writes an encoded CodeGeneratorResponse to stdout. +type CodeGeneratorResponse struct { + // Error message. If non-empty, code generation failed. The plugin process + // should exit with status code zero even if it reports an error in this way. + // + // This should be used to indicate errors in .proto files which prevent the + // code generator from generating correct code. 
Errors which indicate a + // problem in protoc itself -- such as the input CodeGeneratorRequest being + // unparseable -- should be reported by writing a message to stderr and + // exiting with a non-zero status code. + Error *string `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"` + File []*CodeGeneratorResponse_File `protobuf:"bytes,15,rep,name=file" json:"file,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *CodeGeneratorResponse) Reset() { *m = CodeGeneratorResponse{} } +func (m *CodeGeneratorResponse) String() string { return proto.CompactTextString(m) } +func (*CodeGeneratorResponse) ProtoMessage() {} +func (*CodeGeneratorResponse) Descriptor() ([]byte, []int) { return fileDescriptorPlugin, []int{2} } + +func (m *CodeGeneratorResponse) GetError() string { + if m != nil && m.Error != nil { + return *m.Error + } + return "" +} + +func (m *CodeGeneratorResponse) GetFile() []*CodeGeneratorResponse_File { + if m != nil { + return m.File + } + return nil +} + +// Represents a single generated file. +type CodeGeneratorResponse_File struct { + // The file name, relative to the output directory. The name must not + // contain "." or ".." components and must be relative, not be absolute (so, + // the file cannot lie outside the output directory). "/" must be used as + // the path separator, not "\". + // + // If the name is omitted, the content will be appended to the previous + // file. This allows the generator to break large files into small chunks, + // and allows the generated text to be streamed back to protoc so that large + // files need not reside completely in memory at one time. Note that as of + // this writing protoc does not optimize for this -- it will read the entire + // CodeGeneratorResponse before writing files to disk. 
+ Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + // If non-empty, indicates that the named file should already exist, and the + // content here is to be inserted into that file at a defined insertion + // point. This feature allows a code generator to extend the output + // produced by another code generator. The original generator may provide + // insertion points by placing special annotations in the file that look + // like: + // @@protoc_insertion_point(NAME) + // The annotation can have arbitrary text before and after it on the line, + // which allows it to be placed in a comment. NAME should be replaced with + // an identifier naming the point -- this is what other generators will use + // as the insertion_point. Code inserted at this point will be placed + // immediately above the line containing the insertion point (thus multiple + // insertions to the same point will come out in the order they were added). + // The double-@ is intended to make it unlikely that the generated code + // could contain things that look like insertion points by accident. + // + // For example, the C++ code generator places the following line in the + // .pb.h files that it generates: + // // @@protoc_insertion_point(namespace_scope) + // This line appears within the scope of the file's package namespace, but + // outside of any particular class. Another plugin can then specify the + // insertion_point "namespace_scope" to generate additional classes or + // other declarations that should be placed in this scope. + // + // Note that if the line containing the insertion point begins with + // whitespace, the same whitespace will be added to every line of the + // inserted text. This is useful for languages like Python, where + // indentation matters. In these languages, the insertion point comment + // should be indented the same amount as any inserted code will need to be + // in order to work correctly in that context. 
+ // + // The code generator that generates the initial file and the one which + // inserts into it must both run as part of a single invocation of protoc. + // Code generators are executed in the order in which they appear on the + // command line. + // + // If |insertion_point| is present, |name| must also be present. + InsertionPoint *string `protobuf:"bytes,2,opt,name=insertion_point,json=insertionPoint" json:"insertion_point,omitempty"` + // The file contents. + Content *string `protobuf:"bytes,15,opt,name=content" json:"content,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *CodeGeneratorResponse_File) Reset() { *m = CodeGeneratorResponse_File{} } +func (m *CodeGeneratorResponse_File) String() string { return proto.CompactTextString(m) } +func (*CodeGeneratorResponse_File) ProtoMessage() {} +func (*CodeGeneratorResponse_File) Descriptor() ([]byte, []int) { + return fileDescriptorPlugin, []int{2, 0} +} + +func (m *CodeGeneratorResponse_File) GetName() string { + if m != nil && m.Name != nil { + return *m.Name + } + return "" +} + +func (m *CodeGeneratorResponse_File) GetInsertionPoint() string { + if m != nil && m.InsertionPoint != nil { + return *m.InsertionPoint + } + return "" +} + +func (m *CodeGeneratorResponse_File) GetContent() string { + if m != nil && m.Content != nil { + return *m.Content + } + return "" +} + +func init() { + proto.RegisterType((*Version)(nil), "google.protobuf.compiler.Version") + proto.RegisterType((*CodeGeneratorRequest)(nil), "google.protobuf.compiler.CodeGeneratorRequest") + proto.RegisterType((*CodeGeneratorResponse)(nil), "google.protobuf.compiler.CodeGeneratorResponse") + proto.RegisterType((*CodeGeneratorResponse_File)(nil), "google.protobuf.compiler.CodeGeneratorResponse.File") +} + +func init() { proto.RegisterFile("plugin.proto", fileDescriptorPlugin) } + +var fileDescriptorPlugin = []byte{ + // 383 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x74, 
0x92, 0xcd, 0x6a, 0xd5, 0x40, + 0x14, 0xc7, 0x89, 0x37, 0xb5, 0xe4, 0xb4, 0x34, 0x65, 0xa8, 0x32, 0x94, 0x2e, 0xe2, 0x45, 0x30, + 0xab, 0x14, 0x8a, 0xe0, 0xbe, 0x15, 0x75, 0xe1, 0xe2, 0x32, 0x88, 0x0b, 0x41, 0x42, 0x4c, 0x4f, + 0xe2, 0x48, 0x32, 0x67, 0x9c, 0x99, 0x88, 0x4f, 0xea, 0x7b, 0xf8, 0x06, 0x32, 0x1f, 0xa9, 0x72, + 0xf1, 0xee, 0xe6, 0xff, 0x3b, 0xf3, 0x71, 0xce, 0x8f, 0x81, 0x53, 0x3d, 0x2d, 0xa3, 0x54, 0x8d, + 0x36, 0xe4, 0x88, 0xf1, 0x91, 0x68, 0x9c, 0x30, 0xa6, 0x2f, 0xcb, 0xd0, 0xf4, 0x34, 0x6b, 0x39, + 0xa1, 0xb9, 0xac, 0x62, 0xe5, 0x7a, 0xad, 0x5c, 0xdf, 0xa3, 0xed, 0x8d, 0xd4, 0x8e, 0x4c, 0xdc, + 0xbd, 0xed, 0xe1, 0xf8, 0x23, 0x1a, 0x2b, 0x49, 0xb1, 0x0b, 0x38, 0x9a, 0xbb, 0x6f, 0x64, 0x78, + 0x56, 0x65, 0xf5, 0x91, 0x88, 0x21, 0x50, 0xa9, 0xc8, 0xf0, 0x47, 0x89, 0xfa, 0xe0, 0xa9, 0xee, + 0x5c, 0xff, 0x95, 0x6f, 0x22, 0x0d, 0x81, 0x3d, 0x85, 0xc7, 0x76, 0x19, 0x06, 0xf9, 0x93, 0xe7, + 0x55, 0x56, 0x17, 0x22, 0xa5, 0xed, 0xef, 0x0c, 0x2e, 0xee, 0xe8, 0x1e, 0xdf, 0xa2, 0x42, 0xd3, + 0x39, 0x32, 0x02, 0xbf, 0x2f, 0x68, 0x1d, 0xab, 0xe1, 0x7c, 0x90, 0x13, 0xb6, 0x8e, 0xda, 0x31, + 0xd6, 0x90, 0x67, 0xd5, 0xa6, 0x2e, 0xc4, 0x99, 0xe7, 0x1f, 0x28, 0x9d, 0x40, 0x76, 0x05, 0x85, + 0xee, 0x4c, 0x37, 0xa3, 0xc3, 0xd8, 0x4a, 0x21, 0xfe, 0x02, 0x76, 0x07, 0x10, 0xc6, 0x69, 0xfd, + 0x29, 0x5e, 0x56, 0x9b, 0xfa, 0xe4, 0xe6, 0x79, 0xb3, 0xaf, 0xe5, 0x8d, 0x9c, 0xf0, 0xf5, 0x83, + 0x80, 0x9d, 0xc7, 0xa2, 0x08, 0x55, 0x5f, 0x61, 0xef, 0xe1, 0x7c, 0x15, 0xd7, 0xfe, 0x88, 0x4e, + 0xc2, 0x78, 0x27, 0x37, 0xcf, 0x9a, 0x43, 0x86, 0x9b, 0x24, 0x4f, 0x94, 0x2b, 0x49, 0x60, 0xfb, + 0x2b, 0x83, 0x27, 0x7b, 0x33, 0x5b, 0x4d, 0xca, 0xa2, 0x77, 0x87, 0xc6, 0x24, 0xcf, 0x85, 0x88, + 0x81, 0xbd, 0x83, 0xfc, 0x9f, 0xe6, 0x5f, 0x1e, 0x7e, 0xf1, 0xbf, 0x97, 0x86, 0xd9, 0x44, 0xb8, + 0xe1, 0xf2, 0x33, 0xe4, 0x61, 0x1e, 0x06, 0xb9, 0xea, 0x66, 0x4c, 0xcf, 0x84, 0x35, 0x7b, 0x01, + 0xa5, 0x54, 0x16, 0x8d, 0x93, 0xa4, 0x5a, 0x4d, 0x52, 0xb9, 0x24, 0xf3, 0xec, 0x01, 0xef, 0x3c, + 0x65, 
0x1c, 0x8e, 0x7b, 0x52, 0x0e, 0x95, 0xe3, 0x65, 0xd8, 0xb0, 0xc6, 0xdb, 0x57, 0x70, 0xd5, + 0xd3, 0x7c, 0xb0, 0xbf, 0xdb, 0xd3, 0x5d, 0xf8, 0x9b, 0x41, 0xaf, 0xfd, 0x54, 0xc4, 0x9f, 0xda, + 0x8e, 0xf4, 0x27, 0x00, 0x00, 0xff, 0xff, 0x7a, 0x72, 0x3d, 0x18, 0xb5, 0x02, 0x00, 0x00, +} diff --git a/vendor/github.com/gogo/protobuf/vanity/command/command.go b/vendor/github.com/gogo/protobuf/vanity/command/command.go new file mode 100644 index 00000000000..eeca42ba0d0 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/vanity/command/command.go @@ -0,0 +1,161 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2015, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package command + +import ( + "fmt" + "go/format" + "io/ioutil" + "os" + "strings" + + _ "github.com/gogo/protobuf/plugin/compare" + _ "github.com/gogo/protobuf/plugin/defaultcheck" + _ "github.com/gogo/protobuf/plugin/description" + _ "github.com/gogo/protobuf/plugin/embedcheck" + _ "github.com/gogo/protobuf/plugin/enumstringer" + _ "github.com/gogo/protobuf/plugin/equal" + _ "github.com/gogo/protobuf/plugin/face" + _ "github.com/gogo/protobuf/plugin/gostring" + _ "github.com/gogo/protobuf/plugin/marshalto" + _ "github.com/gogo/protobuf/plugin/oneofcheck" + _ "github.com/gogo/protobuf/plugin/populate" + _ "github.com/gogo/protobuf/plugin/size" + _ "github.com/gogo/protobuf/plugin/stringer" + "github.com/gogo/protobuf/plugin/testgen" + _ "github.com/gogo/protobuf/plugin/union" + _ "github.com/gogo/protobuf/plugin/unmarshal" + "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/protoc-gen-gogo/generator" + _ "github.com/gogo/protobuf/protoc-gen-gogo/grpc" + plugin "github.com/gogo/protobuf/protoc-gen-gogo/plugin" +) + +func Read() *plugin.CodeGeneratorRequest { + g := generator.New() + data, err := ioutil.ReadAll(os.Stdin) + if err != nil { + g.Error(err, "reading input") + } + + if err := proto.Unmarshal(data, g.Request); err != nil { + g.Error(err, "parsing input proto") + } + + if len(g.Request.FileToGenerate) == 0 { + g.Fail("no files to generate") + } + return g.Request +} + +// filenameSuffix replaces the .pb.go at 
the end of each filename. +func GeneratePlugin(req *plugin.CodeGeneratorRequest, p generator.Plugin, filenameSuffix string) *plugin.CodeGeneratorResponse { + g := generator.New() + g.Request = req + if len(g.Request.FileToGenerate) == 0 { + g.Fail("no files to generate") + } + + g.CommandLineParameters(g.Request.GetParameter()) + + g.WrapTypes() + g.SetPackageNames() + g.BuildTypeNameMap() + g.GeneratePlugin(p) + + for i := 0; i < len(g.Response.File); i++ { + g.Response.File[i].Name = proto.String( + strings.Replace(*g.Response.File[i].Name, ".pb.go", filenameSuffix, -1), + ) + } + if err := goformat(g.Response); err != nil { + g.Error(err) + } + return g.Response +} + +func goformat(resp *plugin.CodeGeneratorResponse) error { + for i := 0; i < len(resp.File); i++ { + formatted, err := format.Source([]byte(resp.File[i].GetContent())) + if err != nil { + return fmt.Errorf("go format error: %v", err) + } + fmts := string(formatted) + resp.File[i].Content = &fmts + } + return nil +} + +func Generate(req *plugin.CodeGeneratorRequest) *plugin.CodeGeneratorResponse { + // Begin by allocating a generator. The request and response structures are stored there + // so we can do error handling easily - the response structure contains the field to + // report failure. + g := generator.New() + g.Request = req + + g.CommandLineParameters(g.Request.GetParameter()) + + // Create a wrapped version of the Descriptors and EnumDescriptors that + // point to the file that defines them. 
+ g.WrapTypes() + + g.SetPackageNames() + g.BuildTypeNameMap() + + g.GenerateAllFiles() + + if err := goformat(g.Response); err != nil { + g.Error(err) + } + + testReq := proto.Clone(req).(*plugin.CodeGeneratorRequest) + + testResp := GeneratePlugin(testReq, testgen.NewPlugin(), "pb_test.go") + + for i := 0; i < len(testResp.File); i++ { + if strings.Contains(*testResp.File[i].Content, `//These tests are generated by github.com/gogo/protobuf/plugin/testgen`) { + g.Response.File = append(g.Response.File, testResp.File[i]) + } + } + + return g.Response +} + +func Write(resp *plugin.CodeGeneratorResponse) { + g := generator.New() + // Send back the results. + data, err := proto.Marshal(resp) + if err != nil { + g.Error(err, "failed to marshal output proto") + } + _, err = os.Stdout.Write(data) + if err != nil { + g.Error(err, "failed to write output proto") + } +} diff --git a/vendor/github.com/gogo/protobuf/vanity/enum.go b/vendor/github.com/gogo/protobuf/vanity/enum.go new file mode 100644 index 00000000000..466d07b54eb --- /dev/null +++ b/vendor/github.com/gogo/protobuf/vanity/enum.go @@ -0,0 +1,78 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2015, The GoGo Authors. rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package vanity + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" +) + +func EnumHasBoolExtension(enum *descriptor.EnumDescriptorProto, extension *proto.ExtensionDesc) bool { + if enum.Options == nil { + return false + } + value, err := proto.GetExtension(enum.Options, extension) + if err != nil { + return false + } + if value == nil { + return false + } + if value.(*bool) == nil { + return false + } + return true +} + +func SetBoolEnumOption(extension *proto.ExtensionDesc, value bool) func(enum *descriptor.EnumDescriptorProto) { + return func(enum *descriptor.EnumDescriptorProto) { + if EnumHasBoolExtension(enum, extension) { + return + } + if enum.Options == nil { + enum.Options = &descriptor.EnumOptions{} + } + if err := proto.SetExtension(enum.Options, extension, &value); err != nil { + panic(err) + } + } +} + +func TurnOffGoEnumPrefix(enum *descriptor.EnumDescriptorProto) { + SetBoolEnumOption(gogoproto.E_GoprotoEnumPrefix, false)(enum) +} + +func TurnOffGoEnumStringer(enum *descriptor.EnumDescriptorProto) { + SetBoolEnumOption(gogoproto.E_GoprotoEnumStringer, 
false)(enum) +} + +func TurnOnEnumStringer(enum *descriptor.EnumDescriptorProto) { + SetBoolEnumOption(gogoproto.E_EnumStringer, true)(enum) +} diff --git a/vendor/github.com/gogo/protobuf/vanity/field.go b/vendor/github.com/gogo/protobuf/vanity/field.go new file mode 100644 index 00000000000..62cdddfabb4 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/vanity/field.go @@ -0,0 +1,90 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2015, The GoGo Authors. rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package vanity + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" +) + +func FieldHasBoolExtension(field *descriptor.FieldDescriptorProto, extension *proto.ExtensionDesc) bool { + if field.Options == nil { + return false + } + value, err := proto.GetExtension(field.Options, extension) + if err != nil { + return false + } + if value == nil { + return false + } + if value.(*bool) == nil { + return false + } + return true +} + +func SetBoolFieldOption(extension *proto.ExtensionDesc, value bool) func(field *descriptor.FieldDescriptorProto) { + return func(field *descriptor.FieldDescriptorProto) { + if FieldHasBoolExtension(field, extension) { + return + } + if field.Options == nil { + field.Options = &descriptor.FieldOptions{} + } + if err := proto.SetExtension(field.Options, extension, &value); err != nil { + panic(err) + } + } +} + +func TurnOffNullable(field *descriptor.FieldDescriptorProto) { + if field.IsRepeated() && !field.IsMessage() { + return + } + SetBoolFieldOption(gogoproto.E_Nullable, false)(field) +} + +func TurnOffNullableForNativeTypes(field *descriptor.FieldDescriptorProto) { + if field.IsRepeated() || field.IsMessage() { + return + } + SetBoolFieldOption(gogoproto.E_Nullable, false)(field) +} + +func TurnOffNullableForNativeTypesWithoutDefaultsOnly(field *descriptor.FieldDescriptorProto) { + if field.IsRepeated() || field.IsMessage() { + return + } + if field.DefaultValue != nil { + return + } + SetBoolFieldOption(gogoproto.E_Nullable, false)(field) +} diff --git a/vendor/github.com/gogo/protobuf/vanity/file.go b/vendor/github.com/gogo/protobuf/vanity/file.go new file mode 100644 index 00000000000..e7b56de1f7a --- /dev/null +++ b/vendor/github.com/gogo/protobuf/vanity/file.go @@ -0,0 +1,181 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2015, The GoGo Authors. All rights reserved. 
+// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package vanity + +import ( + "path/filepath" + + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" +) + +func NotGoogleProtobufDescriptorProto(file *descriptor.FileDescriptorProto) bool { + // can not just check if file.GetName() == "google/protobuf/descriptor.proto" because we do not want to assume compile path + _, fileName := filepath.Split(file.GetName()) + return !(file.GetPackage() == "google.protobuf" && fileName == "descriptor.proto") +} + +func FilterFiles(files []*descriptor.FileDescriptorProto, f func(file *descriptor.FileDescriptorProto) bool) []*descriptor.FileDescriptorProto { + filtered := make([]*descriptor.FileDescriptorProto, 0, len(files)) + for i := range files { + if !f(files[i]) { + continue + } + filtered = append(filtered, files[i]) + } + return filtered +} + +func FileHasBoolExtension(file *descriptor.FileDescriptorProto, extension *proto.ExtensionDesc) bool { + if file.Options == nil { + return false + } + value, err := proto.GetExtension(file.Options, extension) + if err != nil { + return false + } + if value == nil { + return false + } + if value.(*bool) == nil { + return false + } + return true +} + +func SetBoolFileOption(extension *proto.ExtensionDesc, value bool) func(file *descriptor.FileDescriptorProto) { + return func(file *descriptor.FileDescriptorProto) { + if FileHasBoolExtension(file, extension) { + return + } + if file.Options == nil { + file.Options = &descriptor.FileOptions{} + } + if err := proto.SetExtension(file.Options, extension, &value); err != nil { + panic(err) + } + } +} + +func TurnOffGoGettersAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_GoprotoGettersAll, false)(file) +} + +func TurnOffGoEnumPrefixAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_GoprotoEnumPrefixAll, false)(file) +} + +func TurnOffGoStringerAll(file *descriptor.FileDescriptorProto) { + 
SetBoolFileOption(gogoproto.E_GoprotoStringerAll, false)(file) +} + +func TurnOnVerboseEqualAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_VerboseEqualAll, true)(file) +} + +func TurnOnFaceAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_FaceAll, true)(file) +} + +func TurnOnGoStringAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_GostringAll, true)(file) +} + +func TurnOnPopulateAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_PopulateAll, true)(file) +} + +func TurnOnStringerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_StringerAll, true)(file) +} + +func TurnOnEqualAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_EqualAll, true)(file) +} + +func TurnOnDescriptionAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_DescriptionAll, true)(file) +} + +func TurnOnTestGenAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_TestgenAll, true)(file) +} + +func TurnOnBenchGenAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_BenchgenAll, true)(file) +} + +func TurnOnMarshalerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_MarshalerAll, true)(file) +} + +func TurnOnUnmarshalerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_UnmarshalerAll, true)(file) +} + +func TurnOnStable_MarshalerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_StableMarshalerAll, true)(file) +} + +func TurnOnSizerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_SizerAll, true)(file) +} + +func TurnOffGoEnumStringerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_GoprotoEnumStringerAll, false)(file) +} + +func TurnOnEnumStringerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_EnumStringerAll, true)(file) +} + 
+func TurnOnUnsafeUnmarshalerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_UnsafeUnmarshalerAll, true)(file) +} + +func TurnOnUnsafeMarshalerAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_UnsafeMarshalerAll, true)(file) +} + +func TurnOffGoExtensionsMapAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_GoprotoExtensionsMapAll, false)(file) +} + +func TurnOffGoUnrecognizedAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_GoprotoUnrecognizedAll, false)(file) +} + +func TurnOffGogoImport(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_GogoprotoImport, false)(file) +} + +func TurnOnCompareAll(file *descriptor.FileDescriptorProto) { + SetBoolFileOption(gogoproto.E_CompareAll, true)(file) +} diff --git a/vendor/github.com/gogo/protobuf/vanity/foreach.go b/vendor/github.com/gogo/protobuf/vanity/foreach.go new file mode 100644 index 00000000000..888b6d04d59 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/vanity/foreach.go @@ -0,0 +1,125 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2015, The GoGo Authors. All rights reserved. +// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package vanity + +import descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" + +func ForEachFile(files []*descriptor.FileDescriptorProto, f func(file *descriptor.FileDescriptorProto)) { + for _, file := range files { + f(file) + } +} + +func OnlyProto2(files []*descriptor.FileDescriptorProto) []*descriptor.FileDescriptorProto { + outs := make([]*descriptor.FileDescriptorProto, 0, len(files)) + for i, file := range files { + if file.GetSyntax() == "proto3" { + continue + } + outs = append(outs, files[i]) + } + return outs +} + +func OnlyProto3(files []*descriptor.FileDescriptorProto) []*descriptor.FileDescriptorProto { + outs := make([]*descriptor.FileDescriptorProto, 0, len(files)) + for i, file := range files { + if file.GetSyntax() != "proto3" { + continue + } + outs = append(outs, files[i]) + } + return outs +} + +func ForEachMessageInFiles(files []*descriptor.FileDescriptorProto, f func(msg *descriptor.DescriptorProto)) { + for _, file := range files { + ForEachMessage(file.MessageType, f) + } +} + +func ForEachMessage(msgs []*descriptor.DescriptorProto, f func(msg *descriptor.DescriptorProto)) { + for _, msg := range msgs { + f(msg) + 
ForEachMessage(msg.NestedType, f) + } +} + +func ForEachFieldInFilesExcludingExtensions(files []*descriptor.FileDescriptorProto, f func(field *descriptor.FieldDescriptorProto)) { + for _, file := range files { + ForEachFieldExcludingExtensions(file.MessageType, f) + } +} + +func ForEachFieldInFiles(files []*descriptor.FileDescriptorProto, f func(field *descriptor.FieldDescriptorProto)) { + for _, file := range files { + for _, ext := range file.Extension { + f(ext) + } + ForEachField(file.MessageType, f) + } +} + +func ForEachFieldExcludingExtensions(msgs []*descriptor.DescriptorProto, f func(field *descriptor.FieldDescriptorProto)) { + for _, msg := range msgs { + for _, field := range msg.Field { + f(field) + } + ForEachField(msg.NestedType, f) + } +} + +func ForEachField(msgs []*descriptor.DescriptorProto, f func(field *descriptor.FieldDescriptorProto)) { + for _, msg := range msgs { + for _, field := range msg.Field { + f(field) + } + for _, ext := range msg.Extension { + f(ext) + } + ForEachField(msg.NestedType, f) + } +} + +func ForEachEnumInFiles(files []*descriptor.FileDescriptorProto, f func(enum *descriptor.EnumDescriptorProto)) { + for _, file := range files { + for _, enum := range file.EnumType { + f(enum) + } + } +} + +func ForEachEnum(msgs []*descriptor.DescriptorProto, f func(field *descriptor.EnumDescriptorProto)) { + for _, msg := range msgs { + for _, field := range msg.EnumType { + f(field) + } + ForEachEnum(msg.NestedType, f) + } +} diff --git a/vendor/github.com/gogo/protobuf/vanity/msg.go b/vendor/github.com/gogo/protobuf/vanity/msg.go new file mode 100644 index 00000000000..7ff2b9879e5 --- /dev/null +++ b/vendor/github.com/gogo/protobuf/vanity/msg.go @@ -0,0 +1,142 @@ +// Protocol Buffers for Go with Gadgets +// +// Copyright (c) 2015, The GoGo Authors. rights reserved. 
+// http://github.com/gogo/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package vanity + +import ( + "github.com/gogo/protobuf/gogoproto" + "github.com/gogo/protobuf/proto" + descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" +) + +func MessageHasBoolExtension(msg *descriptor.DescriptorProto, extension *proto.ExtensionDesc) bool { + if msg.Options == nil { + return false + } + value, err := proto.GetExtension(msg.Options, extension) + if err != nil { + return false + } + if value == nil { + return false + } + if value.(*bool) == nil { + return false + } + return true +} + +func SetBoolMessageOption(extension *proto.ExtensionDesc, value bool) func(msg *descriptor.DescriptorProto) { + return func(msg *descriptor.DescriptorProto) { + if MessageHasBoolExtension(msg, extension) { + return + } + if msg.Options == nil { + msg.Options = &descriptor.MessageOptions{} + } + if err := proto.SetExtension(msg.Options, extension, &value); err != nil { + panic(err) + } + } +} + +func TurnOffGoGetters(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_GoprotoGetters, false)(msg) +} + +func TurnOffGoStringer(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_GoprotoStringer, false)(msg) +} + +func TurnOnVerboseEqual(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_VerboseEqual, true)(msg) +} + +func TurnOnFace(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Face, true)(msg) +} + +func TurnOnGoString(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Face, true)(msg) +} + +func TurnOnPopulate(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Populate, true)(msg) +} + +func TurnOnStringer(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Stringer, true)(msg) +} + +func TurnOnEqual(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Equal, true)(msg) +} + +func TurnOnDescription(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Description, true)(msg) +} + +func 
TurnOnTestGen(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Testgen, true)(msg) +} + +func TurnOnBenchGen(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Benchgen, true)(msg) +} + +func TurnOnMarshaler(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Marshaler, true)(msg) +} + +func TurnOnUnmarshaler(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Unmarshaler, true)(msg) +} + +func TurnOnSizer(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Sizer, true)(msg) +} + +func TurnOnUnsafeUnmarshaler(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_UnsafeUnmarshaler, true)(msg) +} + +func TurnOnUnsafeMarshaler(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_UnsafeMarshaler, true)(msg) +} + +func TurnOffGoExtensionsMap(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_GoprotoExtensionsMap, false)(msg) +} + +func TurnOffGoUnrecognized(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_GoprotoUnrecognized, false)(msg) +} + +func TurnOnCompare(msg *descriptor.DescriptorProto) { + SetBoolMessageOption(gogoproto.E_Compare, true)(msg) +} diff --git a/vendor/github.com/hashicorp/go-immutable-radix/LICENSE b/vendor/github.com/hashicorp/go-immutable-radix/LICENSE new file mode 100644 index 00000000000..e87a115e462 --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/LICENSE @@ -0,0 +1,363 @@ +Mozilla Public License, version 2.0 + +1. Definitions + +1.1. "Contributor" + + means each individual or legal entity that creates, contributes to the + creation of, or owns Covered Software. + +1.2. "Contributor Version" + + means the combination of the Contributions of others (if any) used by a + Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + + means Covered Software of a particular Contributor. + +1.4. 
"Covered Software" + + means Source Code Form to which the initial Contributor has attached the + notice in Exhibit A, the Executable Form of such Source Code Form, and + Modifications of such Source Code Form, in each case including portions + thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + a. that the initial Contributor has attached the notice described in + Exhibit B to the Covered Software; or + + b. that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the terms of + a Secondary License. + +1.6. "Executable Form" + + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + + means a work that combines Covered Software with other material, in a + separate file or files, that is not Covered Software. + +1.8. "License" + + means this document. + +1.9. "Licensable" + + means having the right to grant, to the maximum extent possible, whether + at the time of the initial grant or subsequently, any and all of the + rights conveyed by this License. + +1.10. "Modifications" + + means any of the following: + + a. any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered Software; or + + b. any new file in Source Code Form that contains any Covered Software. + +1.11. "Patent Claims" of a Contributor + + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the License, + by the making, using, selling, offering for sale, having made, import, + or transfer of either its Contributions or its Contributor Version. + +1.12. "Secondary License" + + means either the GNU General Public License, Version 2.0, the GNU Lesser + General Public License, Version 2.1, the GNU Affero General Public + License, Version 3.0, or any later versions of those licenses. + +1.13. 
"Source Code Form" + + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that controls, is + controlled by, or is under common control with You. For purposes of this + definition, "control" means (a) the power, direct or indirect, to cause + the direction or management of such entity, whether by contract or + otherwise, or (b) ownership of more than fifty percent (50%) of the + outstanding shares or beneficial ownership of such entity. + + +2. License Grants and Conditions + +2.1. Grants + + Each Contributor hereby grants You a world-wide, royalty-free, + non-exclusive license: + + a. under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + + b. under Patent Claims of such Contributor to make, use, sell, offer for + sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + + The licenses granted in Section 2.1 with respect to any Contribution + become effective for each Contribution on the date the Contributor first + distributes such Contribution. + +2.3. Limitations on Grant Scope + + The licenses granted in this Section 2 are the only rights granted under + this License. No additional rights or licenses will be implied from the + distribution or licensing of Covered Software under this License. + Notwithstanding Section 2.1(b) above, no patent license is granted by a + Contributor: + + a. for any code that a Contributor has removed from Covered Software; or + + b. 
for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + + c. under Patent Claims infringed by Covered Software in the absence of + its Contributions. + + This License does not grant any rights in the trademarks, service marks, + or logos of any Contributor (except as may be necessary to comply with + the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + + No Contributor makes additional grants as a result of Your choice to + distribute the Covered Software under a subsequent version of this + License (see Section 10.2) or under the terms of a Secondary License (if + permitted under the terms of Section 3.3). + +2.5. Representation + + Each Contributor represents that the Contributor believes its + Contributions are its original creation(s) or it has sufficient rights to + grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + + This License is not intended to limit any rights You have under + applicable copyright doctrines of fair use, fair dealing, or other + equivalents. + +2.7. Conditions + + Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in + Section 2.1. + + +3. Responsibilities + +3.1. Distribution of Source Form + + All distribution of Covered Software in Source Code Form, including any + Modifications that You create or to which You contribute, must be under + the terms of this License. You must inform recipients that the Source + Code Form of the Covered Software is governed by the terms of this + License, and how they can obtain a copy of this License. You may not + attempt to alter or restrict the recipients' rights in the Source Code + Form. + +3.2. Distribution of Executable Form + + If You distribute Covered Software in Executable Form then: + + a. 
such Covered Software must also be made available in Source Code Form, + as described in Section 3.1, and You must inform recipients of the + Executable Form how they can obtain a copy of such Source Code Form by + reasonable means in a timely manner, at a charge no more than the cost + of distribution to the recipient; and + + b. You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter the + recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + + You may create and distribute a Larger Work under terms of Your choice, + provided that You also comply with the requirements of this License for + the Covered Software. If the Larger Work is a combination of Covered + Software with a work governed by one or more Secondary Licenses, and the + Covered Software is not Incompatible With Secondary Licenses, this + License permits You to additionally distribute such Covered Software + under the terms of such Secondary License(s), so that the recipient of + the Larger Work may, at their option, further distribute the Covered + Software under the terms of either this License or such Secondary + License(s). + +3.4. Notices + + You may not remove or alter the substance of any license notices + (including copyright notices, patent notices, disclaimers of warranty, or + limitations of liability) contained within the Source Code Form of the + Covered Software, except that You may alter any license notices to the + extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + + You may choose to offer, and to charge a fee for, warranty, support, + indemnity or liability obligations to one or more recipients of Covered + Software. However, You may do so only on Your own behalf, and not on + behalf of any Contributor. 
You must make it absolutely clear that any + such warranty, support, indemnity, or liability obligation is offered by + You alone, and You hereby agree to indemnify every Contributor for any + liability incurred by such Contributor as a result of warranty, support, + indemnity or liability terms You offer. You may include additional + disclaimers of warranty and limitations of liability specific to any + jurisdiction. + +4. Inability to Comply Due to Statute or Regulation + + If it is impossible for You to comply with any of the terms of this License + with respect to some or all of the Covered Software due to statute, + judicial order, or regulation then You must: (a) comply with the terms of + this License to the maximum extent possible; and (b) describe the + limitations and the code they affect. Such description must be placed in a + text file included with all distributions of the Covered Software under + this License. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it. + +5. Termination + +5.1. The rights granted under this License will terminate automatically if You + fail to comply with any of its terms. However, if You become compliant, + then the rights granted under this License from a particular Contributor + are reinstated (a) provisionally, unless and until such Contributor + explicitly and finally terminates Your grants, and (b) on an ongoing + basis, if such Contributor fails to notify You of the non-compliance by + some reasonable means prior to 60 days after You have come back into + compliance. 
Moreover, Your grants from a particular Contributor are + reinstated on an ongoing basis if such Contributor notifies You of the + non-compliance by some reasonable means, this is the first time You have + received notice of non-compliance with this License from such + Contributor, and You become compliant prior to 30 days after Your receipt + of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent + infringement claim (excluding declaratory judgment actions, + counter-claims, and cross-claims) alleging that a Contributor Version + directly or indirectly infringes any patent, then the rights granted to + You by any and all Contributors for the Covered Software under Section + 2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user + license agreements (excluding distributors and resellers) which have been + validly granted by You or Your distributors under this License prior to + termination shall survive termination. + +6. Disclaimer of Warranty + + Covered Software is provided under this License on an "as is" basis, + without warranty of any kind, either expressed, implied, or statutory, + including, without limitation, warranties that the Covered Software is free + of defects, merchantable, fit for a particular purpose or non-infringing. + The entire risk as to the quality and performance of the Covered Software + is with You. Should any Covered Software prove defective in any respect, + You (not any Contributor) assume the cost of any necessary servicing, + repair, or correction. This disclaimer of warranty constitutes an essential + part of this License. No use of any Covered Software is authorized under + this License except under this disclaimer. + +7. 
Limitation of Liability + + Under no circumstances and under no legal theory, whether tort (including + negligence), contract, or otherwise, shall any Contributor, or anyone who + distributes Covered Software as permitted above, be liable to You for any + direct, indirect, special, incidental, or consequential damages of any + character including, without limitation, damages for lost profits, loss of + goodwill, work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses, even if such party shall have been + informed of the possibility of such damages. This limitation of liability + shall not apply to liability for death or personal injury resulting from + such party's negligence to the extent applicable law prohibits such + limitation. Some jurisdictions do not allow the exclusion or limitation of + incidental or consequential damages, so this exclusion and limitation may + not apply to You. + +8. Litigation + + Any litigation relating to this License may be brought only in the courts + of a jurisdiction where the defendant maintains its principal place of + business and such litigation shall be governed by laws of that + jurisdiction, without reference to its conflict-of-law provisions. Nothing + in this Section shall prevent a party's ability to bring cross-claims or + counter-claims. + +9. Miscellaneous + + This License represents the complete agreement concerning the subject + matter hereof. If any provision of this License is held to be + unenforceable, such provision shall be reformed only to the extent + necessary to make it enforceable. Any law or regulation which provides that + the language of a contract shall be construed against the drafter shall not + be used to construe this License against a Contributor. + + +10. Versions of the License + +10.1. New Versions + + Mozilla Foundation is the license steward. 
Except as provided in Section + 10.3, no one other than the license steward has the right to modify or + publish new versions of this License. Each version will be given a + distinguishing version number. + +10.2. Effect of New Versions + + You may distribute the Covered Software under the terms of the version + of the License under which You originally received the Covered Software, + or under the terms of any subsequent version published by the license + steward. + +10.3. Modified Versions + + If you create software not governed by this License, and you want to + create a new license for such software, you may create and use a + modified version of this License if you rename the license and remove + any references to the name of the license steward (except to note that + such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary + Licenses If You choose to distribute Source Code Form that is + Incompatible With Secondary Licenses under the terms of this version of + the License, the notice described in Exhibit B of this License must be + attached. + +Exhibit A - Source Code Form License Notice + + This Source Code Form is subject to the + terms of the Mozilla Public License, v. + 2.0. If a copy of the MPL was not + distributed with this file, You can + obtain one at + http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular file, +then You may include the notice in a location (such as a LICENSE file in a +relevant directory) where a recipient would be likely to look for such a +notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice + + This Source Code Form is "Incompatible + With Secondary Licenses", as defined by + the Mozilla Public License, v. 2.0. 
+ diff --git a/vendor/github.com/hashicorp/go-immutable-radix/edges.go b/vendor/github.com/hashicorp/go-immutable-radix/edges.go new file mode 100644 index 00000000000..a63674775f2 --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/edges.go @@ -0,0 +1,21 @@ +package iradix + +import "sort" + +type edges []edge + +func (e edges) Len() int { + return len(e) +} + +func (e edges) Less(i, j int) bool { + return e[i].label < e[j].label +} + +func (e edges) Swap(i, j int) { + e[i], e[j] = e[j], e[i] +} + +func (e edges) Sort() { + sort.Sort(e) +} diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go new file mode 100644 index 00000000000..e5e6e57f262 --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go @@ -0,0 +1,662 @@ +package iradix + +import ( + "bytes" + "strings" + + "github.com/hashicorp/golang-lru/simplelru" +) + +const ( + // defaultModifiedCache is the default size of the modified node + // cache used per transaction. This is used to cache the updates + // to the nodes near the root, while the leaves do not need to be + // cached. This is important for very large transactions to prevent + // the modified cache from growing to be enormous. This is also used + // to set the max size of the mutation notify maps since those should + // also be bounded in a similar way. + defaultModifiedCache = 8192 +) + +// Tree implements an immutable radix tree. This can be treated as a +// Dictionary abstract data type. The main advantage over a standard +// hash map is prefix-based lookups and ordered iteration. The immutability +// means that it is safe to concurrently read from a Tree without any +// coordination. 
+type Tree struct { + root *Node + size int +} + +// New returns an empty Tree +func New() *Tree { + t := &Tree{ + root: &Node{ + mutateCh: make(chan struct{}), + }, + } + return t +} + +// Len is used to return the number of elements in the tree +func (t *Tree) Len() int { + return t.size +} + +// Txn is a transaction on the tree. This transaction is applied +// atomically and returns a new tree when committed. A transaction +// is not thread safe, and should only be used by a single goroutine. +type Txn struct { + // root is the modified root for the transaction. + root *Node + + // snap is a snapshot of the root node for use if we have to run the + // slow notify algorithm. + snap *Node + + // size tracks the size of the tree as it is modified during the + // transaction. + size int + + // writable is a cache of writable nodes that have been created during + // the course of the transaction. This allows us to re-use the same + // nodes for further writes and avoid unnecessary copies of nodes that + // have never been exposed outside the transaction. This will only hold + // up to defaultModifiedCache number of entries. + writable *simplelru.LRU + + // trackChannels is used to hold channels that need to be notified to + // signal mutation of the tree. This will only hold up to + // defaultModifiedCache number of entries, after which we will set the + // trackOverflow flag, which will cause us to use a more expensive + // algorithm to perform the notifications. Mutation tracking is only + // performed if trackMutate is true. + trackChannels map[chan struct{}]struct{} + trackOverflow bool + trackMutate bool +} + +// Txn starts a new transaction that can be used to mutate the tree +func (t *Tree) Txn() *Txn { + txn := &Txn{ + root: t.root, + snap: t.root, + size: t.size, + } + return txn +} + +// TrackMutate can be used to toggle if mutations are tracked. 
If this is enabled +// then notifications will be issued for affected internal nodes and leaves when +// the transaction is committed. +func (t *Txn) TrackMutate(track bool) { + t.trackMutate = track +} + +// trackChannel safely attempts to track the given mutation channel, setting the +// overflow flag if we can no longer track any more. This limits the amount of +// state that will accumulate during a transaction and we have a slower algorithm +// to switch to if we overflow. +func (t *Txn) trackChannel(ch chan struct{}) { + // In overflow, make sure we don't store any more objects. + if t.trackOverflow { + return + } + + // If this would overflow the state we reject it and set the flag (since + // we aren't tracking everything that's required any longer). + if len(t.trackChannels) >= defaultModifiedCache { + // Mark that we are in the overflow state + t.trackOverflow = true + + // Clear the map so that the channels can be garbage collected. It is + // safe to do this since we have already overflowed and will be using + // the slow notify algorithm. + t.trackChannels = nil + return + } + + // Create the map on the fly when we need it. + if t.trackChannels == nil { + t.trackChannels = make(map[chan struct{}]struct{}) + } + + // Otherwise we are good to track it. + t.trackChannels[ch] = struct{}{} +} + +// writeNode returns a node to be modified, if the current node has already been +// modified during the course of the transaction, it is used in-place. Set +// forLeafUpdate to true if you are getting a write node to update the leaf, +// which will set leaf mutation tracking appropriately as well. +func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node { + // Ensure the writable set exists. + if t.writable == nil { + lru, err := simplelru.NewLRU(defaultModifiedCache, nil) + if err != nil { + panic(err) + } + t.writable = lru + } + + // If this node has already been modified, we can continue to use it + // during this transaction. 
We know that we don't need to track it for + // a node update since the node is writable, but if this is for a leaf + // update we track it, in case the initial write to this node didn't + // update the leaf. + if _, ok := t.writable.Get(n); ok { + if t.trackMutate && forLeafUpdate && n.leaf != nil { + t.trackChannel(n.leaf.mutateCh) + } + return n + } + + // Mark this node as being mutated. + if t.trackMutate { + t.trackChannel(n.mutateCh) + } + + // Mark its leaf as being mutated, if appropriate. + if t.trackMutate && forLeafUpdate && n.leaf != nil { + t.trackChannel(n.leaf.mutateCh) + } + + // Copy the existing node. If you have set forLeafUpdate it will be + // safe to replace this leaf with another after you get your node for + // writing. You MUST replace it, because the channel associated with + // this leaf will be closed when this transaction is committed. + nc := &Node{ + mutateCh: make(chan struct{}), + leaf: n.leaf, + } + if n.prefix != nil { + nc.prefix = make([]byte, len(n.prefix)) + copy(nc.prefix, n.prefix) + } + if len(n.edges) != 0 { + nc.edges = make([]edge, len(n.edges)) + copy(nc.edges, n.edges) + } + + // Mark this node as writable. + t.writable.Add(nc, nil) + return nc +} + +// Visit all the nodes in the tree under n, and add their mutateChannels to the transaction +// Returns the size of the subtree visited +func (t *Txn) trackChannelsAndCount(n *Node) int { + // Count only leaf nodes + leaves := 0 + if n.leaf != nil { + leaves = 1 + } + // Mark this node as being mutated. + if t.trackMutate { + t.trackChannel(n.mutateCh) + } + + // Mark its leaf as being mutated, if appropriate. + if t.trackMutate && n.leaf != nil { + t.trackChannel(n.leaf.mutateCh) + } + + // Recurse on the children + for _, e := range n.edges { + leaves += t.trackChannelsAndCount(e.node) + } + return leaves +} + +// mergeChild is called to collapse the given node with its child. This is only +// called when the given node is not a leaf and has a single edge. 
+func (t *Txn) mergeChild(n *Node) { + // Mark the child node as being mutated since we are about to abandon + // it. We don't need to mark the leaf since we are retaining it if it + // is there. + e := n.edges[0] + child := e.node + if t.trackMutate { + t.trackChannel(child.mutateCh) + } + + // Merge the nodes. + n.prefix = concat(n.prefix, child.prefix) + n.leaf = child.leaf + if len(child.edges) != 0 { + n.edges = make([]edge, len(child.edges)) + copy(n.edges, child.edges) + } else { + n.edges = nil + } +} + +// insert does a recursive insertion +func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface{}, bool) { + // Handle key exhaustion + if len(search) == 0 { + var oldVal interface{} + didUpdate := false + if n.isLeaf() { + oldVal = n.leaf.val + didUpdate = true + } + + nc := t.writeNode(n, true) + nc.leaf = &leafNode{ + mutateCh: make(chan struct{}), + key: k, + val: v, + } + return nc, oldVal, didUpdate + } + + // Look for the edge + idx, child := n.getEdge(search[0]) + + // No edge, create one + if child == nil { + e := edge{ + label: search[0], + node: &Node{ + mutateCh: make(chan struct{}), + leaf: &leafNode{ + mutateCh: make(chan struct{}), + key: k, + val: v, + }, + prefix: search, + }, + } + nc := t.writeNode(n, false) + nc.addEdge(e) + return nc, nil, false + } + + // Determine longest prefix of the search key on match + commonPrefix := longestPrefix(search, child.prefix) + if commonPrefix == len(child.prefix) { + search = search[commonPrefix:] + newChild, oldVal, didUpdate := t.insert(child, k, search, v) + if newChild != nil { + nc := t.writeNode(n, false) + nc.edges[idx].node = newChild + return nc, oldVal, didUpdate + } + return nil, oldVal, didUpdate + } + + // Split the node + nc := t.writeNode(n, false) + splitNode := &Node{ + mutateCh: make(chan struct{}), + prefix: search[:commonPrefix], + } + nc.replaceEdge(edge{ + label: search[0], + node: splitNode, + }) + + // Restore the existing child node + modChild := 
t.writeNode(child, false) + splitNode.addEdge(edge{ + label: modChild.prefix[commonPrefix], + node: modChild, + }) + modChild.prefix = modChild.prefix[commonPrefix:] + + // Create a new leaf node + leaf := &leafNode{ + mutateCh: make(chan struct{}), + key: k, + val: v, + } + + // If the new key is a subset, add to to this node + search = search[commonPrefix:] + if len(search) == 0 { + splitNode.leaf = leaf + return nc, nil, false + } + + // Create a new edge for the node + splitNode.addEdge(edge{ + label: search[0], + node: &Node{ + mutateCh: make(chan struct{}), + leaf: leaf, + prefix: search, + }, + }) + return nc, nil, false +} + +// delete does a recursive deletion +func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { + // Check for key exhaustion + if len(search) == 0 { + if !n.isLeaf() { + return nil, nil + } + // Copy the pointer in case we are in a transaction that already + // modified this node since the node will be reused. Any changes + // made to the node will not affect returning the original leaf + // value. + oldLeaf := n.leaf + + // Remove the leaf node + nc := t.writeNode(n, true) + nc.leaf = nil + + // Check if this node should be merged + if n != t.root && len(nc.edges) == 1 { + t.mergeChild(nc) + } + return nc, oldLeaf + } + + // Look for an edge + label := search[0] + idx, child := n.getEdge(label) + if child == nil || !bytes.HasPrefix(search, child.prefix) { + return nil, nil + } + + // Consume the search prefix + search = search[len(child.prefix):] + newChild, leaf := t.delete(n, child, search) + if newChild == nil { + return nil, nil + } + + // Copy this node. WATCH OUT - it's safe to pass "false" here because we + // will only ADD a leaf via nc.mergeChild() if there isn't one due to + // the !nc.isLeaf() check in the logic just below. This is pretty subtle, + // so be careful if you change any of the logic here. 
+ nc := t.writeNode(n, false) + + // Delete the edge if the node has no edges + if newChild.leaf == nil && len(newChild.edges) == 0 { + nc.delEdge(label) + if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() { + t.mergeChild(nc) + } + } else { + nc.edges[idx].node = newChild + } + return nc, leaf +} + +// delete does a recursive deletion +func (t *Txn) deletePrefix(parent, n *Node, search []byte) (*Node, int) { + // Check for key exhaustion + if len(search) == 0 { + nc := t.writeNode(n, true) + if n.isLeaf() { + nc.leaf = nil + } + nc.edges = nil + return nc, t.trackChannelsAndCount(n) + } + + // Look for an edge + label := search[0] + idx, child := n.getEdge(label) + // We make sure that either the child node's prefix starts with the search term, or the search term starts with the child node's prefix + // Need to do both so that we can delete prefixes that don't correspond to any node in the tree + if child == nil || (!bytes.HasPrefix(child.prefix, search) && !bytes.HasPrefix(search, child.prefix)) { + return nil, 0 + } + + // Consume the search prefix + if len(child.prefix) > len(search) { + search = []byte("") + } else { + search = search[len(child.prefix):] + } + newChild, numDeletions := t.deletePrefix(n, child, search) + if newChild == nil { + return nil, 0 + } + // Copy this node. WATCH OUT - it's safe to pass "false" here because we + // will only ADD a leaf via nc.mergeChild() if there isn't one due to + // the !nc.isLeaf() check in the logic just below. This is pretty subtle, + // so be careful if you change any of the logic here. + + nc := t.writeNode(n, false) + + // Delete the edge if the node has no edges + if newChild.leaf == nil && len(newChild.edges) == 0 { + nc.delEdge(label) + if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() { + t.mergeChild(nc) + } + } else { + nc.edges[idx].node = newChild + } + return nc, numDeletions +} + +// Insert is used to add or update a given key. 
The return provides +// the previous value and a bool indicating if any was set. +func (t *Txn) Insert(k []byte, v interface{}) (interface{}, bool) { + newRoot, oldVal, didUpdate := t.insert(t.root, k, k, v) + if newRoot != nil { + t.root = newRoot + } + if !didUpdate { + t.size++ + } + return oldVal, didUpdate +} + +// Delete is used to delete a given key. Returns the old value if any, +// and a bool indicating if the key was set. +func (t *Txn) Delete(k []byte) (interface{}, bool) { + newRoot, leaf := t.delete(nil, t.root, k) + if newRoot != nil { + t.root = newRoot + } + if leaf != nil { + t.size-- + return leaf.val, true + } + return nil, false +} + +// DeletePrefix is used to delete an entire subtree that matches the prefix +// This will delete all nodes under that prefix +func (t *Txn) DeletePrefix(prefix []byte) bool { + newRoot, numDeletions := t.deletePrefix(nil, t.root, prefix) + if newRoot != nil { + t.root = newRoot + t.size = t.size - numDeletions + return true + } + return false + +} + +// Root returns the current root of the radix tree within this +// transaction. The root is not safe across insert and delete operations, +// but can be used to read the current state during a transaction. +func (t *Txn) Root() *Node { + return t.root +} + +// Get is used to lookup a specific key, returning +// the value and if it was found +func (t *Txn) Get(k []byte) (interface{}, bool) { + return t.root.Get(k) +} + +// GetWatch is used to lookup a specific key, returning +// the watch channel, value and if it was found +func (t *Txn) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { + return t.root.GetWatch(k) +} + +// Commit is used to finalize the transaction and return a new tree. If mutation +// tracking is turned on then notifications will also be issued. 
+func (t *Txn) Commit() *Tree { + nt := t.CommitOnly() + if t.trackMutate { + t.Notify() + } + return nt +} + +// CommitOnly is used to finalize the transaction and return a new tree, but +// does not issue any notifications until Notify is called. +func (t *Txn) CommitOnly() *Tree { + nt := &Tree{t.root, t.size} + t.writable = nil + return nt +} + +// slowNotify does a complete comparison of the before and after trees in order +// to trigger notifications. This doesn't require any additional state but it +// is very expensive to compute. +func (t *Txn) slowNotify() { + snapIter := t.snap.rawIterator() + rootIter := t.root.rawIterator() + for snapIter.Front() != nil || rootIter.Front() != nil { + // If we've exhausted the nodes in the old snapshot, we know + // there's nothing remaining to notify. + if snapIter.Front() == nil { + return + } + snapElem := snapIter.Front() + + // If we've exhausted the nodes in the new root, we know we need + // to invalidate everything that remains in the old snapshot. We + // know from the loop condition there's something in the old + // snapshot. + if rootIter.Front() == nil { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // Do one string compare so we can check the various conditions + // below without repeating the compare. + cmp := strings.Compare(snapIter.Path(), rootIter.Path()) + + // If the snapshot is behind the root, then we must have deleted + // this node during the transaction. + if cmp < 0 { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // If the snapshot is ahead of the root, then we must have added + // this node during the transaction. + if cmp > 0 { + rootIter.Next() + continue + } + + // If we have the same path, then we need to see if we mutated a + // node and possibly the leaf. 
+ rootElem := rootIter.Front() + if snapElem != rootElem { + close(snapElem.mutateCh) + if snapElem.leaf != nil && (snapElem.leaf != rootElem.leaf) { + close(snapElem.leaf.mutateCh) + } + } + snapIter.Next() + rootIter.Next() + } +} + +// Notify is used along with TrackMutate to trigger notifications. This must +// only be done once a transaction is committed via CommitOnly, and it is called +// automatically by Commit. +func (t *Txn) Notify() { + if !t.trackMutate { + return + } + + // If we've overflowed the tracking state we can't use it in any way and + // need to do a full tree compare. + if t.trackOverflow { + t.slowNotify() + } else { + for ch := range t.trackChannels { + close(ch) + } + } + + // Clean up the tracking state so that a re-notify is safe (will trigger + // the else clause above which will be a no-op). + t.trackChannels = nil + t.trackOverflow = false +} + +// Insert is used to add or update a given key. The return provides +// the new tree, previous value and a bool indicating if any was set. +func (t *Tree) Insert(k []byte, v interface{}) (*Tree, interface{}, bool) { + txn := t.Txn() + old, ok := txn.Insert(k, v) + return txn.Commit(), old, ok +} + +// Delete is used to delete a given key. Returns the new tree, +// old value if any, and a bool indicating if the key was set. +func (t *Tree) Delete(k []byte) (*Tree, interface{}, bool) { + txn := t.Txn() + old, ok := txn.Delete(k) + return txn.Commit(), old, ok +} + +// DeletePrefix is used to delete all nodes starting with a given prefix. Returns the new tree, +// and a bool indicating if the prefix matched any nodes +func (t *Tree) DeletePrefix(k []byte) (*Tree, bool) { + txn := t.Txn() + ok := txn.DeletePrefix(k) + return txn.Commit(), ok +} + +// Root returns the root node of the tree which can be used for richer +// query operations. 
+func (t *Tree) Root() *Node { + return t.root +} + +// Get is used to lookup a specific key, returning +// the value and if it was found +func (t *Tree) Get(k []byte) (interface{}, bool) { + return t.root.Get(k) +} + +// longestPrefix finds the length of the shared prefix +// of two strings +func longestPrefix(k1, k2 []byte) int { + max := len(k1) + if l := len(k2); l < max { + max = l + } + var i int + for i = 0; i < max; i++ { + if k1[i] != k2[i] { + break + } + } + return i +} + +// concat two byte slices, returning a third new copy +func concat(a, b []byte) []byte { + c := make([]byte, len(a)+len(b)) + copy(c, a) + copy(c[len(a):], b) + return c +} diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iter.go b/vendor/github.com/hashicorp/go-immutable-radix/iter.go new file mode 100644 index 00000000000..9815e02538e --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/iter.go @@ -0,0 +1,91 @@ +package iradix + +import "bytes" + +// Iterator is used to iterate over a set of nodes +// in pre-order +type Iterator struct { + node *Node + stack []edges +} + +// SeekPrefixWatch is used to seek the iterator to a given prefix +// and returns the watch channel of the finest granularity +func (i *Iterator) SeekPrefixWatch(prefix []byte) (watch <-chan struct{}) { + // Wipe the stack + i.stack = nil + n := i.node + watch = n.mutateCh + search := prefix + for { + // Check for key exhaution + if len(search) == 0 { + i.node = n + return + } + + // Look for an edge + _, n = n.getEdge(search[0]) + if n == nil { + i.node = nil + return + } + + // Update to the finest granularity as the search makes progress + watch = n.mutateCh + + // Consume the search prefix + if bytes.HasPrefix(search, n.prefix) { + search = search[len(n.prefix):] + + } else if bytes.HasPrefix(n.prefix, search) { + i.node = n + return + } else { + i.node = nil + return + } + } +} + +// SeekPrefix is used to seek the iterator to a given prefix +func (i *Iterator) SeekPrefix(prefix []byte) 
{ + i.SeekPrefixWatch(prefix) +} + +// Next returns the next node in order +func (i *Iterator) Next() ([]byte, interface{}, bool) { + // Initialize our stack if needed + if i.stack == nil && i.node != nil { + i.stack = []edges{ + edges{ + edge{node: i.node}, + }, + } + } + + for len(i.stack) > 0 { + // Inspect the last element of the stack + n := len(i.stack) + last := i.stack[n-1] + elem := last[0].node + + // Update the stack + if len(last) > 1 { + i.stack[n-1] = last[1:] + } else { + i.stack = i.stack[:n-1] + } + + // Push the edges onto the frontier + if len(elem.edges) > 0 { + i.stack = append(i.stack, elem.edges) + } + + // Return the leaf values if any + if elem.leaf != nil { + return elem.leaf.key, elem.leaf.val, true + } + } + return nil, nil, false +} diff --git a/vendor/github.com/hashicorp/go-immutable-radix/node.go b/vendor/github.com/hashicorp/go-immutable-radix/node.go new file mode 100644 index 00000000000..7a065e7a09e --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/node.go @@ -0,0 +1,292 @@ +package iradix + +import ( + "bytes" + "sort" +) + +// WalkFn is used when walking the tree. Takes a +// key and value, returning if iteration should +// be terminated. +type WalkFn func(k []byte, v interface{}) bool + +// leafNode is used to represent a value +type leafNode struct { + mutateCh chan struct{} + key []byte + val interface{} +} + +// edge is used to represent an edge node +type edge struct { + label byte + node *Node +} + +// Node is an immutable node in the radix tree +type Node struct { + // mutateCh is closed if this node is modified + mutateCh chan struct{} + + // leaf is used to store possible leaf + leaf *leafNode + + // prefix is the common prefix we ignore + prefix []byte + + // Edges should be stored in-order for iteration. 
+ // We avoid a fully materialized slice to save memory, + // since in most cases we expect to be sparse + edges edges +} + +func (n *Node) isLeaf() bool { + return n.leaf != nil +} + +func (n *Node) addEdge(e edge) { + num := len(n.edges) + idx := sort.Search(num, func(i int) bool { + return n.edges[i].label >= e.label + }) + n.edges = append(n.edges, e) + if idx != num { + copy(n.edges[idx+1:], n.edges[idx:num]) + n.edges[idx] = e + } +} + +func (n *Node) replaceEdge(e edge) { + num := len(n.edges) + idx := sort.Search(num, func(i int) bool { + return n.edges[i].label >= e.label + }) + if idx < num && n.edges[idx].label == e.label { + n.edges[idx].node = e.node + return + } + panic("replacing missing edge") +} + +func (n *Node) getEdge(label byte) (int, *Node) { + num := len(n.edges) + idx := sort.Search(num, func(i int) bool { + return n.edges[i].label >= label + }) + if idx < num && n.edges[idx].label == label { + return idx, n.edges[idx].node + } + return -1, nil +} + +func (n *Node) delEdge(label byte) { + num := len(n.edges) + idx := sort.Search(num, func(i int) bool { + return n.edges[i].label >= label + }) + if idx < num && n.edges[idx].label == label { + copy(n.edges[idx:], n.edges[idx+1:]) + n.edges[len(n.edges)-1] = edge{} + n.edges = n.edges[:len(n.edges)-1] + } +} + +func (n *Node) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { + search := k + watch := n.mutateCh + for { + // Check for key exhaustion + if len(search) == 0 { + if n.isLeaf() { + return n.leaf.mutateCh, n.leaf.val, true + } + break + } + + // Look for an edge + _, n = n.getEdge(search[0]) + if n == nil { + break + } + + // Update to the finest granularity as the search makes progress + watch = n.mutateCh + + // Consume the search prefix + if bytes.HasPrefix(search, n.prefix) { + search = search[len(n.prefix):] + } else { + break + } + } + return watch, nil, false +} + +func (n *Node) Get(k []byte) (interface{}, bool) { + _, val, ok := n.GetWatch(k) + return val, ok +} + +// 
LongestPrefix is like Get, but instead of an +// exact match, it will return the longest prefix match. +func (n *Node) LongestPrefix(k []byte) ([]byte, interface{}, bool) { + var last *leafNode + search := k + for { + // Look for a leaf node + if n.isLeaf() { + last = n.leaf + } + + // Check for key exhaution + if len(search) == 0 { + break + } + + // Look for an edge + _, n = n.getEdge(search[0]) + if n == nil { + break + } + + // Consume the search prefix + if bytes.HasPrefix(search, n.prefix) { + search = search[len(n.prefix):] + } else { + break + } + } + if last != nil { + return last.key, last.val, true + } + return nil, nil, false +} + +// Minimum is used to return the minimum value in the tree +func (n *Node) Minimum() ([]byte, interface{}, bool) { + for { + if n.isLeaf() { + return n.leaf.key, n.leaf.val, true + } + if len(n.edges) > 0 { + n = n.edges[0].node + } else { + break + } + } + return nil, nil, false +} + +// Maximum is used to return the maximum value in the tree +func (n *Node) Maximum() ([]byte, interface{}, bool) { + for { + if num := len(n.edges); num > 0 { + n = n.edges[num-1].node + continue + } + if n.isLeaf() { + return n.leaf.key, n.leaf.val, true + } else { + break + } + } + return nil, nil, false +} + +// Iterator is used to return an iterator at +// the given node to walk the tree +func (n *Node) Iterator() *Iterator { + return &Iterator{node: n} +} + +// rawIterator is used to return a raw iterator at the given node to walk the +// tree. 
+func (n *Node) rawIterator() *rawIterator { + iter := &rawIterator{node: n} + iter.Next() + return iter +} + +// Walk is used to walk the tree +func (n *Node) Walk(fn WalkFn) { + recursiveWalk(n, fn) +} + +// WalkPrefix is used to walk the tree under a prefix +func (n *Node) WalkPrefix(prefix []byte, fn WalkFn) { + search := prefix + for { + // Check for key exhaution + if len(search) == 0 { + recursiveWalk(n, fn) + return + } + + // Look for an edge + _, n = n.getEdge(search[0]) + if n == nil { + break + } + + // Consume the search prefix + if bytes.HasPrefix(search, n.prefix) { + search = search[len(n.prefix):] + + } else if bytes.HasPrefix(n.prefix, search) { + // Child may be under our search prefix + recursiveWalk(n, fn) + return + } else { + break + } + } +} + +// WalkPath is used to walk the tree, but only visiting nodes +// from the root down to a given leaf. Where WalkPrefix walks +// all the entries *under* the given prefix, this walks the +// entries *above* the given prefix. +func (n *Node) WalkPath(path []byte, fn WalkFn) { + search := path + for { + // Visit the leaf values if any + if n.leaf != nil && fn(n.leaf.key, n.leaf.val) { + return + } + + // Check for key exhaution + if len(search) == 0 { + return + } + + // Look for an edge + _, n = n.getEdge(search[0]) + if n == nil { + return + } + + // Consume the search prefix + if bytes.HasPrefix(search, n.prefix) { + search = search[len(n.prefix):] + } else { + break + } + } +} + +// recursiveWalk is used to do a pre-order walk of a node +// recursively. 
Returns true if the walk should be aborted +func recursiveWalk(n *Node, fn WalkFn) bool { + // Visit the leaf values if any + if n.leaf != nil && fn(n.leaf.key, n.leaf.val) { + return true + } + + // Recurse on the children + for _, e := range n.edges { + if recursiveWalk(e.node, fn) { + return true + } + } + return false +} diff --git a/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go new file mode 100644 index 00000000000..04814c1323f --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go @@ -0,0 +1,78 @@ +package iradix + +// rawIterator visits each of the nodes in the tree, even the ones that are not +// leaves. It keeps track of the effective path (what a leaf at a given node +// would be called), which is useful for comparing trees. +type rawIterator struct { + // node is the starting node in the tree for the iterator. + node *Node + + // stack keeps track of edges in the frontier. + stack []rawStackEntry + + // pos is the current position of the iterator. + pos *Node + + // path is the effective path of the current iterator position, + // regardless of whether the current node is a leaf. + path string +} + +// rawStackEntry is used to keep track of the cumulative common path as well as +// its associated edges in the frontier. +type rawStackEntry struct { + path string + edges edges +} + +// Front returns the current node that has been iterated to. +func (i *rawIterator) Front() *Node { + return i.pos +} + +// Path returns the effective path of the current node, even if it's not actually +// a leaf. +func (i *rawIterator) Path() string { + return i.path +} + +// Next advances the iterator to the next node. +func (i *rawIterator) Next() { + // Initialize our stack if needed. 
+ if i.stack == nil && i.node != nil { + i.stack = []rawStackEntry{ + rawStackEntry{ + edges: edges{ + edge{node: i.node}, + }, + }, + } + } + + for len(i.stack) > 0 { + // Inspect the last element of the stack. + n := len(i.stack) + last := i.stack[n-1] + elem := last.edges[0].node + + // Update the stack. + if len(last.edges) > 1 { + i.stack[n-1].edges = last.edges[1:] + } else { + i.stack = i.stack[:n-1] + } + + // Push the edges onto the frontier. + if len(elem.edges) > 0 { + path := last.path + string(elem.prefix) + i.stack = append(i.stack, rawStackEntry{path, elem.edges}) + } + + i.pos = elem + i.path = last.path + string(elem.prefix) + return + } + + i.pos = nil + i.path = "" +} diff --git a/vendor/github.com/hashicorp/go-msgpack/LICENSE b/vendor/github.com/hashicorp/go-msgpack/LICENSE new file mode 100644 index 00000000000..ccae99f6a9a --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2012, 2013 Ugorji Nwoke. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +* Neither the name of the author nor the names of its contributors may be used + to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/0doc.go b/vendor/github.com/hashicorp/go-msgpack/codec/0doc.go new file mode 100644 index 00000000000..c14d810a73e --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/0doc.go @@ -0,0 +1,143 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +/* +High Performance, Feature-Rich Idiomatic Go encoding library for msgpack and binc . + +Supported Serialization formats are: + + - msgpack: [https://github.com/msgpack/msgpack] + - binc: [http://github.com/ugorji/binc] + +To install: + + go get github.com/ugorji/go/codec + +The idiomatic Go support is as seen in other encoding packages in +the standard library (ie json, xml, gob, etc). + +Rich Feature Set includes: + + - Simple but extremely powerful and feature-rich API + - Very High Performance. + Our extensive benchmarks show us outperforming Gob, Json and Bson by 2-4X. + This was achieved by taking extreme care on: + - managing allocation + - function frame size (important due to Go's use of split stacks), + - reflection use (and by-passing reflection for common types) + - recursion implications + - zero-copy mode (encoding/decoding to byte slice without using temp buffers) + - Correct. + Care was taken to precisely handle corner cases like: + overflows, nil maps and slices, nil value in stream, etc. 
+ - Efficient zero-copying into temporary byte buffers + when encoding into or decoding from a byte slice. + - Standard field renaming via tags + - Encoding from any value + (struct, slice, map, primitives, pointers, interface{}, etc) + - Decoding into pointer to any non-nil typed value + (struct, slice, map, int, float32, bool, string, reflect.Value, etc) + - Supports extension functions to handle the encode/decode of custom types + - Support Go 1.2 encoding.BinaryMarshaler/BinaryUnmarshaler + - Schema-less decoding + (decode into a pointer to a nil interface{} as opposed to a typed non-nil value). + Includes Options to configure what specific map or slice type to use + when decoding an encoded list or map into a nil interface{} + - Provides a RPC Server and Client Codec for net/rpc communication protocol. + - Msgpack Specific: + - Provides extension functions to handle spec-defined extensions (binary, timestamp) + - Options to resolve ambiguities in handling raw bytes (as string or []byte) + during schema-less decoding (decoding into a nil interface{}) + - RPC Server/Client Codec for msgpack-rpc protocol defined at: + https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md + - Fast Paths for some container types: + For some container types, we circumvent reflection and its associated overhead + and allocation costs, and encode/decode directly. These types are: + []interface{} + []int + []string + map[interface{}]interface{} + map[int]interface{} + map[string]interface{} + +Extension Support + +Users can register a function to handle the encoding or decoding of +their custom types. + +There are no restrictions on what the custom type can be. Some examples: + + type BisSet []int + type BitSet64 uint64 + type UUID string + type MyStructWithUnexportedFields struct { a int; b bool; c []int; } + type GifImage struct { ... 
} + +As an illustration, MyStructWithUnexportedFields would normally be +encoded as an empty map because it has no exported fields, while UUID +would be encoded as a string. However, with extension support, you can +encode any of these however you like. + +RPC + +RPC Client and Server Codecs are implemented, so the codecs can be used +with the standard net/rpc package. + +Usage + +Typical usage model: + + // create and configure Handle + var ( + bh codec.BincHandle + mh codec.MsgpackHandle + ) + + mh.MapType = reflect.TypeOf(map[string]interface{}(nil)) + + // configure extensions + // e.g. for msgpack, define functions and enable Time support for tag 1 + // mh.AddExt(reflect.TypeOf(time.Time{}), 1, myMsgpackTimeEncodeExtFn, myMsgpackTimeDecodeExtFn) + + // create and use decoder/encoder + var ( + r io.Reader + w io.Writer + b []byte + h = &bh // or mh to use msgpack + ) + + dec = codec.NewDecoder(r, h) + dec = codec.NewDecoderBytes(b, h) + err = dec.Decode(&v) + + enc = codec.NewEncoder(w, h) + enc = codec.NewEncoderBytes(&b, h) + err = enc.Encode(v) + + //RPC Server + go func() { + for { + conn, err := listener.Accept() + rpcCodec := codec.GoRpc.ServerCodec(conn, h) + //OR rpcCodec := codec.MsgpackSpecRpc.ServerCodec(conn, h) + rpc.ServeCodec(rpcCodec) + } + }() + + //RPC Communication (client side) + conn, err = net.Dial("tcp", "localhost:5555") + rpcCodec := codec.GoRpc.ClientCodec(conn, h) + //OR rpcCodec := codec.MsgpackSpecRpc.ClientCodec(conn, h) + client := rpc.NewClientWithCodec(rpcCodec) + +Representative Benchmark Results + +Run the benchmark suite using: + go test -bi -bench=. 
-benchmem + +To run full benchmark suite (including against vmsgpack and bson), +see notes in ext_dep_test.go + +*/ +package codec diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/binc.go b/vendor/github.com/hashicorp/go-msgpack/codec/binc.go new file mode 100644 index 00000000000..2bb5e8fee85 --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/binc.go @@ -0,0 +1,786 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +package codec + +import ( + "math" + // "reflect" + // "sync/atomic" + "time" + //"fmt" +) + +const bincDoPrune = true // No longer needed. Needed before as C lib did not support pruning. + +//var _ = fmt.Printf + +// vd as low 4 bits (there are 16 slots) +const ( + bincVdSpecial byte = iota + bincVdPosInt + bincVdNegInt + bincVdFloat + + bincVdString + bincVdByteArray + bincVdArray + bincVdMap + + bincVdTimestamp + bincVdSmallInt + bincVdUnicodeOther + bincVdSymbol + + bincVdDecimal + _ // open slot + _ // open slot + bincVdCustomExt = 0x0f +) + +const ( + bincSpNil byte = iota + bincSpFalse + bincSpTrue + bincSpNan + bincSpPosInf + bincSpNegInf + bincSpZeroFloat + bincSpZero + bincSpNegOne +) + +const ( + bincFlBin16 byte = iota + bincFlBin32 + _ // bincFlBin32e + bincFlBin64 + _ // bincFlBin64e + // others not currently supported +) + +type bincEncDriver struct { + w encWriter + m map[string]uint16 // symbols + s uint32 // symbols sequencer + b [8]byte +} + +func (e *bincEncDriver) isBuiltinType(rt uintptr) bool { + return rt == timeTypId +} + +func (e *bincEncDriver) encodeBuiltin(rt uintptr, v interface{}) { + switch rt { + case timeTypId: + bs := encodeTime(v.(time.Time)) + e.w.writen1(bincVdTimestamp<<4 | uint8(len(bs))) + e.w.writeb(bs) + } +} + +func (e *bincEncDriver) encodeNil() { + e.w.writen1(bincVdSpecial<<4 | bincSpNil) +} + +func (e *bincEncDriver) encodeBool(b bool) { + if b { + e.w.writen1(bincVdSpecial<<4 
| bincSpTrue) + } else { + e.w.writen1(bincVdSpecial<<4 | bincSpFalse) + } +} + +func (e *bincEncDriver) encodeFloat32(f float32) { + if f == 0 { + e.w.writen1(bincVdSpecial<<4 | bincSpZeroFloat) + return + } + e.w.writen1(bincVdFloat<<4 | bincFlBin32) + e.w.writeUint32(math.Float32bits(f)) +} + +func (e *bincEncDriver) encodeFloat64(f float64) { + if f == 0 { + e.w.writen1(bincVdSpecial<<4 | bincSpZeroFloat) + return + } + bigen.PutUint64(e.b[:], math.Float64bits(f)) + if bincDoPrune { + i := 7 + for ; i >= 0 && (e.b[i] == 0); i-- { + } + i++ + if i <= 6 { + e.w.writen1(bincVdFloat<<4 | 0x8 | bincFlBin64) + e.w.writen1(byte(i)) + e.w.writeb(e.b[:i]) + return + } + } + e.w.writen1(bincVdFloat<<4 | bincFlBin64) + e.w.writeb(e.b[:]) +} + +func (e *bincEncDriver) encIntegerPrune(bd byte, pos bool, v uint64, lim uint8) { + if lim == 4 { + bigen.PutUint32(e.b[:lim], uint32(v)) + } else { + bigen.PutUint64(e.b[:lim], v) + } + if bincDoPrune { + i := pruneSignExt(e.b[:lim], pos) + e.w.writen1(bd | lim - 1 - byte(i)) + e.w.writeb(e.b[i:lim]) + } else { + e.w.writen1(bd | lim - 1) + e.w.writeb(e.b[:lim]) + } +} + +func (e *bincEncDriver) encodeInt(v int64) { + const nbd byte = bincVdNegInt << 4 + switch { + case v >= 0: + e.encUint(bincVdPosInt<<4, true, uint64(v)) + case v == -1: + e.w.writen1(bincVdSpecial<<4 | bincSpNegOne) + default: + e.encUint(bincVdNegInt<<4, false, uint64(-v)) + } +} + +func (e *bincEncDriver) encodeUint(v uint64) { + e.encUint(bincVdPosInt<<4, true, v) +} + +func (e *bincEncDriver) encUint(bd byte, pos bool, v uint64) { + switch { + case v == 0: + e.w.writen1(bincVdSpecial<<4 | bincSpZero) + case pos && v >= 1 && v <= 16: + e.w.writen1(bincVdSmallInt<<4 | byte(v-1)) + case v <= math.MaxUint8: + e.w.writen2(bd|0x0, byte(v)) + case v <= math.MaxUint16: + e.w.writen1(bd | 0x01) + e.w.writeUint16(uint16(v)) + case v <= math.MaxUint32: + e.encIntegerPrune(bd, pos, v, 4) + default: + e.encIntegerPrune(bd, pos, v, 8) + } +} + +func (e *bincEncDriver) 
encodeExtPreamble(xtag byte, length int) { + e.encLen(bincVdCustomExt<<4, uint64(length)) + e.w.writen1(xtag) +} + +func (e *bincEncDriver) encodeArrayPreamble(length int) { + e.encLen(bincVdArray<<4, uint64(length)) +} + +func (e *bincEncDriver) encodeMapPreamble(length int) { + e.encLen(bincVdMap<<4, uint64(length)) +} + +func (e *bincEncDriver) encodeString(c charEncoding, v string) { + l := uint64(len(v)) + e.encBytesLen(c, l) + if l > 0 { + e.w.writestr(v) + } +} + +func (e *bincEncDriver) encodeSymbol(v string) { + // if WriteSymbolsNoRefs { + // e.encodeString(c_UTF8, v) + // return + // } + + //symbols only offer benefit when string length > 1. + //This is because strings with length 1 take only 2 bytes to store + //(bd with embedded length, and single byte for string val). + + l := len(v) + switch l { + case 0: + e.encBytesLen(c_UTF8, 0) + return + case 1: + e.encBytesLen(c_UTF8, 1) + e.w.writen1(v[0]) + return + } + if e.m == nil { + e.m = make(map[string]uint16, 16) + } + ui, ok := e.m[v] + if ok { + if ui <= math.MaxUint8 { + e.w.writen2(bincVdSymbol<<4, byte(ui)) + } else { + e.w.writen1(bincVdSymbol<<4 | 0x8) + e.w.writeUint16(ui) + } + } else { + e.s++ + ui = uint16(e.s) + //ui = uint16(atomic.AddUint32(&e.s, 1)) + e.m[v] = ui + var lenprec uint8 + switch { + case l <= math.MaxUint8: + // lenprec = 0 + case l <= math.MaxUint16: + lenprec = 1 + case int64(l) <= math.MaxUint32: + lenprec = 2 + default: + lenprec = 3 + } + if ui <= math.MaxUint8 { + e.w.writen2(bincVdSymbol<<4|0x0|0x4|lenprec, byte(ui)) + } else { + e.w.writen1(bincVdSymbol<<4 | 0x8 | 0x4 | lenprec) + e.w.writeUint16(ui) + } + switch lenprec { + case 0: + e.w.writen1(byte(l)) + case 1: + e.w.writeUint16(uint16(l)) + case 2: + e.w.writeUint32(uint32(l)) + default: + e.w.writeUint64(uint64(l)) + } + e.w.writestr(v) + } +} + +func (e *bincEncDriver) encodeStringBytes(c charEncoding, v []byte) { + l := uint64(len(v)) + e.encBytesLen(c, l) + if l > 0 { + e.w.writeb(v) + } +} + +func (e 
*bincEncDriver) encBytesLen(c charEncoding, length uint64) { + //TODO: support bincUnicodeOther (for now, just use string or bytearray) + if c == c_RAW { + e.encLen(bincVdByteArray<<4, length) + } else { + e.encLen(bincVdString<<4, length) + } +} + +func (e *bincEncDriver) encLen(bd byte, l uint64) { + if l < 12 { + e.w.writen1(bd | uint8(l+4)) + } else { + e.encLenNumber(bd, l) + } +} + +func (e *bincEncDriver) encLenNumber(bd byte, v uint64) { + switch { + case v <= math.MaxUint8: + e.w.writen2(bd, byte(v)) + case v <= math.MaxUint16: + e.w.writen1(bd | 0x01) + e.w.writeUint16(uint16(v)) + case v <= math.MaxUint32: + e.w.writen1(bd | 0x02) + e.w.writeUint32(uint32(v)) + default: + e.w.writen1(bd | 0x03) + e.w.writeUint64(uint64(v)) + } +} + +//------------------------------------ + +type bincDecDriver struct { + r decReader + bdRead bool + bdType valueType + bd byte + vd byte + vs byte + b [8]byte + m map[uint32]string // symbols (use uint32 as key, as map optimizes for it) +} + +func (d *bincDecDriver) initReadNext() { + if d.bdRead { + return + } + d.bd = d.r.readn1() + d.vd = d.bd >> 4 + d.vs = d.bd & 0x0f + d.bdRead = true + d.bdType = valueTypeUnset +} + +func (d *bincDecDriver) currentEncodedType() valueType { + if d.bdType == valueTypeUnset { + switch d.vd { + case bincVdSpecial: + switch d.vs { + case bincSpNil: + d.bdType = valueTypeNil + case bincSpFalse, bincSpTrue: + d.bdType = valueTypeBool + case bincSpNan, bincSpNegInf, bincSpPosInf, bincSpZeroFloat: + d.bdType = valueTypeFloat + case bincSpZero: + d.bdType = valueTypeUint + case bincSpNegOne: + d.bdType = valueTypeInt + default: + decErr("currentEncodedType: Unrecognized special value 0x%x", d.vs) + } + case bincVdSmallInt: + d.bdType = valueTypeUint + case bincVdPosInt: + d.bdType = valueTypeUint + case bincVdNegInt: + d.bdType = valueTypeInt + case bincVdFloat: + d.bdType = valueTypeFloat + case bincVdString: + d.bdType = valueTypeString + case bincVdSymbol: + d.bdType = valueTypeSymbol + case 
bincVdByteArray: + d.bdType = valueTypeBytes + case bincVdTimestamp: + d.bdType = valueTypeTimestamp + case bincVdCustomExt: + d.bdType = valueTypeExt + case bincVdArray: + d.bdType = valueTypeArray + case bincVdMap: + d.bdType = valueTypeMap + default: + decErr("currentEncodedType: Unrecognized d.vd: 0x%x", d.vd) + } + } + return d.bdType +} + +func (d *bincDecDriver) tryDecodeAsNil() bool { + if d.bd == bincVdSpecial<<4|bincSpNil { + d.bdRead = false + return true + } + return false +} + +func (d *bincDecDriver) isBuiltinType(rt uintptr) bool { + return rt == timeTypId +} + +func (d *bincDecDriver) decodeBuiltin(rt uintptr, v interface{}) { + switch rt { + case timeTypId: + if d.vd != bincVdTimestamp { + decErr("Invalid d.vd. Expecting 0x%x. Received: 0x%x", bincVdTimestamp, d.vd) + } + tt, err := decodeTime(d.r.readn(int(d.vs))) + if err != nil { + panic(err) + } + var vt *time.Time = v.(*time.Time) + *vt = tt + d.bdRead = false + } +} + +func (d *bincDecDriver) decFloatPre(vs, defaultLen byte) { + if vs&0x8 == 0 { + d.r.readb(d.b[0:defaultLen]) + } else { + l := d.r.readn1() + if l > 8 { + decErr("At most 8 bytes used to represent float. Received: %v bytes", l) + } + for i := l; i < 8; i++ { + d.b[i] = 0 + } + d.r.readb(d.b[0:l]) + } +} + +func (d *bincDecDriver) decFloat() (f float64) { + //if true { f = math.Float64frombits(d.r.readUint64()); break; } + switch vs := d.vs; vs & 0x7 { + case bincFlBin32: + d.decFloatPre(vs, 4) + f = float64(math.Float32frombits(bigen.Uint32(d.b[0:4]))) + case bincFlBin64: + d.decFloatPre(vs, 8) + f = math.Float64frombits(bigen.Uint64(d.b[0:8])) + default: + decErr("only float32 and float64 are supported. 
d.vd: 0x%x, d.vs: 0x%x", d.vd, d.vs) + } + return +} + +func (d *bincDecDriver) decUint() (v uint64) { + // need to inline the code (interface conversion and type assertion expensive) + switch d.vs { + case 0: + v = uint64(d.r.readn1()) + case 1: + d.r.readb(d.b[6:]) + v = uint64(bigen.Uint16(d.b[6:])) + case 2: + d.b[4] = 0 + d.r.readb(d.b[5:]) + v = uint64(bigen.Uint32(d.b[4:])) + case 3: + d.r.readb(d.b[4:]) + v = uint64(bigen.Uint32(d.b[4:])) + case 4, 5, 6: + lim := int(7 - d.vs) + d.r.readb(d.b[lim:]) + for i := 0; i < lim; i++ { + d.b[i] = 0 + } + v = uint64(bigen.Uint64(d.b[:])) + case 7: + d.r.readb(d.b[:]) + v = uint64(bigen.Uint64(d.b[:])) + default: + decErr("unsigned integers with greater than 64 bits of precision not supported") + } + return +} + +func (d *bincDecDriver) decIntAny() (ui uint64, i int64, neg bool) { + switch d.vd { + case bincVdPosInt: + ui = d.decUint() + i = int64(ui) + case bincVdNegInt: + ui = d.decUint() + i = -(int64(ui)) + neg = true + case bincVdSmallInt: + i = int64(d.vs) + 1 + ui = uint64(d.vs) + 1 + case bincVdSpecial: + switch d.vs { + case bincSpZero: + //i = 0 + case bincSpNegOne: + neg = true + ui = 1 + i = -1 + default: + decErr("numeric decode fails for special value: d.vs: 0x%x", d.vs) + } + default: + decErr("number can only be decoded from uint or int values. 
d.bd: 0x%x, d.vd: 0x%x", d.bd, d.vd) + } + return +} + +func (d *bincDecDriver) decodeInt(bitsize uint8) (i int64) { + _, i, _ = d.decIntAny() + checkOverflow(0, i, bitsize) + d.bdRead = false + return +} + +func (d *bincDecDriver) decodeUint(bitsize uint8) (ui uint64) { + ui, i, neg := d.decIntAny() + if neg { + decErr("Assigning negative signed value: %v, to unsigned type", i) + } + checkOverflow(ui, 0, bitsize) + d.bdRead = false + return +} + +func (d *bincDecDriver) decodeFloat(chkOverflow32 bool) (f float64) { + switch d.vd { + case bincVdSpecial: + d.bdRead = false + switch d.vs { + case bincSpNan: + return math.NaN() + case bincSpPosInf: + return math.Inf(1) + case bincSpZeroFloat, bincSpZero: + return + case bincSpNegInf: + return math.Inf(-1) + default: + decErr("Invalid d.vs decoding float where d.vd=bincVdSpecial: %v", d.vs) + } + case bincVdFloat: + f = d.decFloat() + default: + _, i, _ := d.decIntAny() + f = float64(i) + } + checkOverflowFloat32(f, chkOverflow32) + d.bdRead = false + return +} + +// bool can be decoded from bool only (single byte). +func (d *bincDecDriver) decodeBool() (b bool) { + switch d.bd { + case (bincVdSpecial | bincSpFalse): + // b = false + case (bincVdSpecial | bincSpTrue): + b = true + default: + decErr("Invalid single-byte value for bool: %s: %x", msgBadDesc, d.bd) + } + d.bdRead = false + return +} + +func (d *bincDecDriver) readMapLen() (length int) { + if d.vd != bincVdMap { + decErr("Invalid d.vd for map. Expecting 0x%x. Got: 0x%x", bincVdMap, d.vd) + } + length = d.decLen() + d.bdRead = false + return +} + +func (d *bincDecDriver) readArrayLen() (length int) { + if d.vd != bincVdArray { + decErr("Invalid d.vd for array. Expecting 0x%x. 
Got: 0x%x", bincVdArray, d.vd) + } + length = d.decLen() + d.bdRead = false + return +} + +func (d *bincDecDriver) decLen() int { + if d.vs <= 3 { + return int(d.decUint()) + } + return int(d.vs - 4) +} + +func (d *bincDecDriver) decodeString() (s string) { + switch d.vd { + case bincVdString, bincVdByteArray: + if length := d.decLen(); length > 0 { + s = string(d.r.readn(length)) + } + case bincVdSymbol: + //from vs: extract numSymbolBytes, containsStringVal, strLenPrecision, + //extract symbol + //if containsStringVal, read it and put in map + //else look in map for string value + var symbol uint32 + vs := d.vs + //fmt.Printf(">>>> d.vs: 0b%b, & 0x8: %v, & 0x4: %v\n", d.vs, vs & 0x8, vs & 0x4) + if vs&0x8 == 0 { + symbol = uint32(d.r.readn1()) + } else { + symbol = uint32(d.r.readUint16()) + } + if d.m == nil { + d.m = make(map[uint32]string, 16) + } + + if vs&0x4 == 0 { + s = d.m[symbol] + } else { + var slen int + switch vs & 0x3 { + case 0: + slen = int(d.r.readn1()) + case 1: + slen = int(d.r.readUint16()) + case 2: + slen = int(d.r.readUint32()) + case 3: + slen = int(d.r.readUint64()) + } + s = string(d.r.readn(slen)) + d.m[symbol] = s + } + default: + decErr("Invalid d.vd for string. Expecting string:0x%x, bytearray:0x%x or symbol: 0x%x. Got: 0x%x", + bincVdString, bincVdByteArray, bincVdSymbol, d.vd) + } + d.bdRead = false + return +} + +func (d *bincDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) { + var clen int + switch d.vd { + case bincVdString, bincVdByteArray: + clen = d.decLen() + default: + decErr("Invalid d.vd for bytes. Expecting string:0x%x or bytearray:0x%x. 
Got: 0x%x", + bincVdString, bincVdByteArray, d.vd) + } + if clen > 0 { + // if no contents in stream, don't update the passed byteslice + if len(bs) != clen { + if len(bs) > clen { + bs = bs[:clen] + } else { + bs = make([]byte, clen) + } + bsOut = bs + changed = true + } + d.r.readb(bs) + } + d.bdRead = false + return +} + +func (d *bincDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) { + switch d.vd { + case bincVdCustomExt: + l := d.decLen() + xtag = d.r.readn1() + if verifyTag && xtag != tag { + decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag) + } + xbs = d.r.readn(l) + case bincVdByteArray: + xbs, _ = d.decodeBytes(nil) + default: + decErr("Invalid d.vd for extensions (Expecting extensions or byte array). Got: 0x%x", d.vd) + } + d.bdRead = false + return +} + +func (d *bincDecDriver) decodeNaked() (v interface{}, vt valueType, decodeFurther bool) { + d.initReadNext() + + switch d.vd { + case bincVdSpecial: + switch d.vs { + case bincSpNil: + vt = valueTypeNil + case bincSpFalse: + vt = valueTypeBool + v = false + case bincSpTrue: + vt = valueTypeBool + v = true + case bincSpNan: + vt = valueTypeFloat + v = math.NaN() + case bincSpPosInf: + vt = valueTypeFloat + v = math.Inf(1) + case bincSpNegInf: + vt = valueTypeFloat + v = math.Inf(-1) + case bincSpZeroFloat: + vt = valueTypeFloat + v = float64(0) + case bincSpZero: + vt = valueTypeUint + v = int64(0) // int8(0) + case bincSpNegOne: + vt = valueTypeInt + v = int64(-1) // int8(-1) + default: + decErr("decodeNaked: Unrecognized special value 0x%x", d.vs) + } + case bincVdSmallInt: + vt = valueTypeUint + v = uint64(int8(d.vs)) + 1 // int8(d.vs) + 1 + case bincVdPosInt: + vt = valueTypeUint + v = d.decUint() + case bincVdNegInt: + vt = valueTypeInt + v = -(int64(d.decUint())) + case bincVdFloat: + vt = valueTypeFloat + v = d.decFloat() + case bincVdSymbol: + vt = valueTypeSymbol + v = d.decodeString() + case bincVdString: + vt = valueTypeString + v = d.decodeString() + case 
bincVdByteArray: + vt = valueTypeBytes + v, _ = d.decodeBytes(nil) + case bincVdTimestamp: + vt = valueTypeTimestamp + tt, err := decodeTime(d.r.readn(int(d.vs))) + if err != nil { + panic(err) + } + v = tt + case bincVdCustomExt: + vt = valueTypeExt + l := d.decLen() + var re RawExt + re.Tag = d.r.readn1() + re.Data = d.r.readn(l) + v = &re + vt = valueTypeExt + case bincVdArray: + vt = valueTypeArray + decodeFurther = true + case bincVdMap: + vt = valueTypeMap + decodeFurther = true + default: + decErr("decodeNaked: Unrecognized d.vd: 0x%x", d.vd) + } + + if !decodeFurther { + d.bdRead = false + } + return +} + +//------------------------------------ + +//BincHandle is a Handle for the Binc Schema-Free Encoding Format +//defined at https://github.com/ugorji/binc . +// +//BincHandle currently supports all Binc features with the following EXCEPTIONS: +// - only integers up to 64 bits of precision are supported. +// big integers are unsupported. +// - Only IEEE 754 binary32 and binary64 floats are supported (ie Go float32 and float64 types). +// extended precision and decimal IEEE 754 floats are unsupported. +// - Only UTF-8 strings supported. +// Unicode_Other Binc types (UTF16, UTF32) are currently unsupported. +//Note that these EXCEPTIONS are temporary and full support is possible and may happen soon. 
+type BincHandle struct { + BasicHandle +} + +func (h *BincHandle) newEncDriver(w encWriter) encDriver { + return &bincEncDriver{w: w} +} + +func (h *BincHandle) newDecDriver(r decReader) decDriver { + return &bincDecDriver{r: r} +} + +func (_ *BincHandle) writeExt() bool { + return true +} + +func (h *BincHandle) getBasicHandle() *BasicHandle { + return &h.BasicHandle +} diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/decode.go b/vendor/github.com/hashicorp/go-msgpack/codec/decode.go new file mode 100644 index 00000000000..87bef2b9358 --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/decode.go @@ -0,0 +1,1048 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +package codec + +import ( + "io" + "reflect" + // "runtime/debug" +) + +// Some tagging information for error messages. +const ( + msgTagDec = "codec.decoder" + msgBadDesc = "Unrecognized descriptor byte" + msgDecCannotExpandArr = "cannot expand go array from %v to stream length: %v" +) + +// decReader abstracts the reading source, allowing implementations that can +// read from an io.Reader or directly off a byte slice with zero-copying. +type decReader interface { + readn(n int) []byte + readb([]byte) + readn1() uint8 + readUint16() uint16 + readUint32() uint32 + readUint64() uint64 +} + +type decDriver interface { + initReadNext() + tryDecodeAsNil() bool + currentEncodedType() valueType + isBuiltinType(rt uintptr) bool + decodeBuiltin(rt uintptr, v interface{}) + //decodeNaked: Numbers are decoded as int64, uint64, float64 only (no smaller sized number types). 
+ decodeNaked() (v interface{}, vt valueType, decodeFurther bool) + decodeInt(bitsize uint8) (i int64) + decodeUint(bitsize uint8) (ui uint64) + decodeFloat(chkOverflow32 bool) (f float64) + decodeBool() (b bool) + // decodeString can also decode symbols + decodeString() (s string) + decodeBytes(bs []byte) (bsOut []byte, changed bool) + decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) + readMapLen() int + readArrayLen() int +} + +type DecodeOptions struct { + // An instance of MapType is used during schema-less decoding of a map in the stream. + // If nil, we use map[interface{}]interface{} + MapType reflect.Type + // An instance of SliceType is used during schema-less decoding of an array in the stream. + // If nil, we use []interface{} + SliceType reflect.Type + // ErrorIfNoField controls whether an error is returned when decoding a map + // from a codec stream into a struct, and no matching struct field is found. + ErrorIfNoField bool +} + +// ------------------------------------ + +// ioDecReader is a decReader that reads off an io.Reader +type ioDecReader struct { + r io.Reader + br io.ByteReader + x [8]byte //temp byte array re-used internally for efficiency +} + +func (z *ioDecReader) readn(n int) (bs []byte) { + if n <= 0 { + return + } + bs = make([]byte, n) + if _, err := io.ReadAtLeast(z.r, bs, n); err != nil { + panic(err) + } + return +} + +func (z *ioDecReader) readb(bs []byte) { + if _, err := io.ReadAtLeast(z.r, bs, len(bs)); err != nil { + panic(err) + } +} + +func (z *ioDecReader) readn1() uint8 { + if z.br != nil { + b, err := z.br.ReadByte() + if err != nil { + panic(err) + } + return b + } + z.readb(z.x[:1]) + return z.x[0] +} + +func (z *ioDecReader) readUint16() uint16 { + z.readb(z.x[:2]) + return bigen.Uint16(z.x[:2]) +} + +func (z *ioDecReader) readUint32() uint32 { + z.readb(z.x[:4]) + return bigen.Uint32(z.x[:4]) +} + +func (z *ioDecReader) readUint64() uint64 { + z.readb(z.x[:8]) + return bigen.Uint64(z.x[:8]) +} + +// 
------------------------------------ + +// bytesDecReader is a decReader that reads off a byte slice with zero copying +type bytesDecReader struct { + b []byte // data + c int // cursor + a int // available +} + +func (z *bytesDecReader) consume(n int) (oldcursor int) { + if z.a == 0 { + panic(io.EOF) + } + if n > z.a { + decErr("Trying to read %v bytes. Only %v available", n, z.a) + } + // z.checkAvailable(n) + oldcursor = z.c + z.c = oldcursor + n + z.a = z.a - n + return +} + +func (z *bytesDecReader) readn(n int) (bs []byte) { + if n <= 0 { + return + } + c0 := z.consume(n) + bs = z.b[c0:z.c] + return +} + +func (z *bytesDecReader) readb(bs []byte) { + copy(bs, z.readn(len(bs))) +} + +func (z *bytesDecReader) readn1() uint8 { + c0 := z.consume(1) + return z.b[c0] +} + +// Use binaryEncoding helper for 4 and 8 bits, but inline it for 2 bits +// creating temp slice variable and copying it to helper function is expensive +// for just 2 bits. + +func (z *bytesDecReader) readUint16() uint16 { + c0 := z.consume(2) + return uint16(z.b[c0+1]) | uint16(z.b[c0])<<8 +} + +func (z *bytesDecReader) readUint32() uint32 { + c0 := z.consume(4) + return bigen.Uint32(z.b[c0:z.c]) +} + +func (z *bytesDecReader) readUint64() uint64 { + c0 := z.consume(8) + return bigen.Uint64(z.b[c0:z.c]) +} + +// ------------------------------------ + +// decFnInfo has methods for registering handling decoding of a specific type +// based on some characteristics (builtin, extension, reflect Kind, etc) +type decFnInfo struct { + ti *typeInfo + d *Decoder + dd decDriver + xfFn func(reflect.Value, []byte) error + xfTag byte + array bool +} + +func (f *decFnInfo) builtin(rv reflect.Value) { + f.dd.decodeBuiltin(f.ti.rtid, rv.Addr().Interface()) +} + +func (f *decFnInfo) rawExt(rv reflect.Value) { + xtag, xbs := f.dd.decodeExt(false, 0) + rv.Field(0).SetUint(uint64(xtag)) + rv.Field(1).SetBytes(xbs) +} + +func (f *decFnInfo) ext(rv reflect.Value) { + _, xbs := f.dd.decodeExt(true, f.xfTag) + if fnerr 
:= f.xfFn(rv, xbs); fnerr != nil { + panic(fnerr) + } +} + +func (f *decFnInfo) binaryMarshal(rv reflect.Value) { + var bm binaryUnmarshaler + if f.ti.unmIndir == -1 { + bm = rv.Addr().Interface().(binaryUnmarshaler) + } else if f.ti.unmIndir == 0 { + bm = rv.Interface().(binaryUnmarshaler) + } else { + for j, k := int8(0), f.ti.unmIndir; j < k; j++ { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + bm = rv.Interface().(binaryUnmarshaler) + } + xbs, _ := f.dd.decodeBytes(nil) + if fnerr := bm.UnmarshalBinary(xbs); fnerr != nil { + panic(fnerr) + } +} + +func (f *decFnInfo) kErr(rv reflect.Value) { + decErr("Unhandled value for kind: %v: %s", rv.Kind(), msgBadDesc) +} + +func (f *decFnInfo) kString(rv reflect.Value) { + rv.SetString(f.dd.decodeString()) +} + +func (f *decFnInfo) kBool(rv reflect.Value) { + rv.SetBool(f.dd.decodeBool()) +} + +func (f *decFnInfo) kInt(rv reflect.Value) { + rv.SetInt(f.dd.decodeInt(intBitsize)) +} + +func (f *decFnInfo) kInt64(rv reflect.Value) { + rv.SetInt(f.dd.decodeInt(64)) +} + +func (f *decFnInfo) kInt32(rv reflect.Value) { + rv.SetInt(f.dd.decodeInt(32)) +} + +func (f *decFnInfo) kInt8(rv reflect.Value) { + rv.SetInt(f.dd.decodeInt(8)) +} + +func (f *decFnInfo) kInt16(rv reflect.Value) { + rv.SetInt(f.dd.decodeInt(16)) +} + +func (f *decFnInfo) kFloat32(rv reflect.Value) { + rv.SetFloat(f.dd.decodeFloat(true)) +} + +func (f *decFnInfo) kFloat64(rv reflect.Value) { + rv.SetFloat(f.dd.decodeFloat(false)) +} + +func (f *decFnInfo) kUint8(rv reflect.Value) { + rv.SetUint(f.dd.decodeUint(8)) +} + +func (f *decFnInfo) kUint64(rv reflect.Value) { + rv.SetUint(f.dd.decodeUint(64)) +} + +func (f *decFnInfo) kUint(rv reflect.Value) { + rv.SetUint(f.dd.decodeUint(uintBitsize)) +} + +func (f *decFnInfo) kUint32(rv reflect.Value) { + rv.SetUint(f.dd.decodeUint(32)) +} + +func (f *decFnInfo) kUint16(rv reflect.Value) { + rv.SetUint(f.dd.decodeUint(16)) +} + +// func (f *decFnInfo) kPtr(rv reflect.Value) { 
+// debugf(">>>>>>> ??? decode kPtr called - shouldn't get called") +// if rv.IsNil() { +// rv.Set(reflect.New(rv.Type().Elem())) +// } +// f.d.decodeValue(rv.Elem()) +// } + +func (f *decFnInfo) kInterface(rv reflect.Value) { + // debugf("\t===> kInterface") + if !rv.IsNil() { + f.d.decodeValue(rv.Elem()) + return + } + // nil interface: + // use some hieristics to set the nil interface to an + // appropriate value based on the first byte read (byte descriptor bd) + v, vt, decodeFurther := f.dd.decodeNaked() + if vt == valueTypeNil { + return + } + // Cannot decode into nil interface with methods (e.g. error, io.Reader, etc) + // if non-nil value in stream. + if num := f.ti.rt.NumMethod(); num > 0 { + decErr("decodeValue: Cannot decode non-nil codec value into nil %v (%v methods)", + f.ti.rt, num) + } + var rvn reflect.Value + var useRvn bool + switch vt { + case valueTypeMap: + if f.d.h.MapType == nil { + var m2 map[interface{}]interface{} + v = &m2 + } else { + rvn = reflect.New(f.d.h.MapType).Elem() + useRvn = true + } + case valueTypeArray: + if f.d.h.SliceType == nil { + var m2 []interface{} + v = &m2 + } else { + rvn = reflect.New(f.d.h.SliceType).Elem() + useRvn = true + } + case valueTypeExt: + re := v.(*RawExt) + var bfn func(reflect.Value, []byte) error + rvn, bfn = f.d.h.getDecodeExtForTag(re.Tag) + if bfn == nil { + rvn = reflect.ValueOf(*re) + } else if fnerr := bfn(rvn, re.Data); fnerr != nil { + panic(fnerr) + } + rv.Set(rvn) + return + } + if decodeFurther { + if useRvn { + f.d.decodeValue(rvn) + } else if v != nil { + // this v is a pointer, so we need to dereference it when done + f.d.decode(v) + rvn = reflect.ValueOf(v).Elem() + useRvn = true + } + } + if useRvn { + rv.Set(rvn) + } else if v != nil { + rv.Set(reflect.ValueOf(v)) + } +} + +func (f *decFnInfo) kStruct(rv reflect.Value) { + fti := f.ti + if currEncodedType := f.dd.currentEncodedType(); currEncodedType == valueTypeMap { + containerLen := f.dd.readMapLen() + if containerLen == 0 { + 
return + } + tisfi := fti.sfi + for j := 0; j < containerLen; j++ { + // var rvkencname string + // ddecode(&rvkencname) + f.dd.initReadNext() + rvkencname := f.dd.decodeString() + // rvksi := ti.getForEncName(rvkencname) + if k := fti.indexForEncName(rvkencname); k > -1 { + sfik := tisfi[k] + if sfik.i != -1 { + f.d.decodeValue(rv.Field(int(sfik.i))) + } else { + f.d.decEmbeddedField(rv, sfik.is) + } + // f.d.decodeValue(ti.field(k, rv)) + } else { + if f.d.h.ErrorIfNoField { + decErr("No matching struct field found when decoding stream map with key: %v", + rvkencname) + } else { + var nilintf0 interface{} + f.d.decodeValue(reflect.ValueOf(&nilintf0).Elem()) + } + } + } + } else if currEncodedType == valueTypeArray { + containerLen := f.dd.readArrayLen() + if containerLen == 0 { + return + } + for j, si := range fti.sfip { + if j == containerLen { + break + } + if si.i != -1 { + f.d.decodeValue(rv.Field(int(si.i))) + } else { + f.d.decEmbeddedField(rv, si.is) + } + } + if containerLen > len(fti.sfip) { + // read remaining values and throw away + for j := len(fti.sfip); j < containerLen; j++ { + var nilintf0 interface{} + f.d.decodeValue(reflect.ValueOf(&nilintf0).Elem()) + } + } + } else { + decErr("Only encoded map or array can be decoded into a struct. (valueType: %x)", + currEncodedType) + } +} + +func (f *decFnInfo) kSlice(rv reflect.Value) { + // A slice can be set from a map or array in stream. 
+ currEncodedType := f.dd.currentEncodedType() + + switch currEncodedType { + case valueTypeBytes, valueTypeString: + if f.ti.rtid == uint8SliceTypId || f.ti.rt.Elem().Kind() == reflect.Uint8 { + if bs2, changed2 := f.dd.decodeBytes(rv.Bytes()); changed2 { + rv.SetBytes(bs2) + } + return + } + } + + if shortCircuitReflectToFastPath && rv.CanAddr() { + switch f.ti.rtid { + case intfSliceTypId: + f.d.decSliceIntf(rv.Addr().Interface().(*[]interface{}), currEncodedType, f.array) + return + case uint64SliceTypId: + f.d.decSliceUint64(rv.Addr().Interface().(*[]uint64), currEncodedType, f.array) + return + case int64SliceTypId: + f.d.decSliceInt64(rv.Addr().Interface().(*[]int64), currEncodedType, f.array) + return + case strSliceTypId: + f.d.decSliceStr(rv.Addr().Interface().(*[]string), currEncodedType, f.array) + return + } + } + + containerLen, containerLenS := decContLens(f.dd, currEncodedType) + + // an array can never return a nil slice. so no need to check f.array here. + + if rv.IsNil() { + rv.Set(reflect.MakeSlice(f.ti.rt, containerLenS, containerLenS)) + } + + if containerLen == 0 { + return + } + + if rvcap, rvlen := rv.Len(), rv.Cap(); containerLenS > rvcap { + if f.array { // !rv.CanSet() + decErr(msgDecCannotExpandArr, rvcap, containerLenS) + } + rvn := reflect.MakeSlice(f.ti.rt, containerLenS, containerLenS) + if rvlen > 0 { + reflect.Copy(rvn, rv) + } + rv.Set(rvn) + } else if containerLenS > rvlen { + rv.SetLen(containerLenS) + } + + for j := 0; j < containerLenS; j++ { + f.d.decodeValue(rv.Index(j)) + } +} + +func (f *decFnInfo) kArray(rv reflect.Value) { + // f.d.decodeValue(rv.Slice(0, rv.Len())) + f.kSlice(rv.Slice(0, rv.Len())) +} + +func (f *decFnInfo) kMap(rv reflect.Value) { + if shortCircuitReflectToFastPath && rv.CanAddr() { + switch f.ti.rtid { + case mapStrIntfTypId: + f.d.decMapStrIntf(rv.Addr().Interface().(*map[string]interface{})) + return + case mapIntfIntfTypId: + f.d.decMapIntfIntf(rv.Addr().Interface().(*map[interface{}]interface{})) 
+ return + case mapInt64IntfTypId: + f.d.decMapInt64Intf(rv.Addr().Interface().(*map[int64]interface{})) + return + case mapUint64IntfTypId: + f.d.decMapUint64Intf(rv.Addr().Interface().(*map[uint64]interface{})) + return + } + } + + containerLen := f.dd.readMapLen() + + if rv.IsNil() { + rv.Set(reflect.MakeMap(f.ti.rt)) + } + + if containerLen == 0 { + return + } + + ktype, vtype := f.ti.rt.Key(), f.ti.rt.Elem() + ktypeId := reflect.ValueOf(ktype).Pointer() + for j := 0; j < containerLen; j++ { + rvk := reflect.New(ktype).Elem() + f.d.decodeValue(rvk) + + // special case if a byte array. + // if ktype == intfTyp { + if ktypeId == intfTypId { + rvk = rvk.Elem() + if rvk.Type() == uint8SliceTyp { + rvk = reflect.ValueOf(string(rvk.Bytes())) + } + } + rvv := rv.MapIndex(rvk) + if !rvv.IsValid() { + rvv = reflect.New(vtype).Elem() + } + + f.d.decodeValue(rvv) + rv.SetMapIndex(rvk, rvv) + } +} + +// ---------------------------------------- + +type decFn struct { + i *decFnInfo + f func(*decFnInfo, reflect.Value) +} + +// A Decoder reads and decodes an object from an input stream in the codec format. +type Decoder struct { + r decReader + d decDriver + h *BasicHandle + f map[uintptr]decFn + x []uintptr + s []decFn +} + +// NewDecoder returns a Decoder for decoding a stream of bytes from an io.Reader. +// +// For efficiency, Users are encouraged to pass in a memory buffered writer +// (eg bufio.Reader, bytes.Buffer). +func NewDecoder(r io.Reader, h Handle) *Decoder { + z := ioDecReader{ + r: r, + } + z.br, _ = r.(io.ByteReader) + return &Decoder{r: &z, d: h.newDecDriver(&z), h: h.getBasicHandle()} +} + +// NewDecoderBytes returns a Decoder which efficiently decodes directly +// from a byte slice with zero copying. 
+func NewDecoderBytes(in []byte, h Handle) *Decoder { + z := bytesDecReader{ + b: in, + a: len(in), + } + return &Decoder{r: &z, d: h.newDecDriver(&z), h: h.getBasicHandle()} +} + +// Decode decodes the stream from reader and stores the result in the +// value pointed to by v. v cannot be a nil pointer. v can also be +// a reflect.Value of a pointer. +// +// Note that a pointer to a nil interface is not a nil pointer. +// If you do not know what type of stream it is, pass in a pointer to a nil interface. +// We will decode and store a value in that nil interface. +// +// Sample usages: +// // Decoding into a non-nil typed value +// var f float32 +// err = codec.NewDecoder(r, handle).Decode(&f) +// +// // Decoding into nil interface +// var v interface{} +// dec := codec.NewDecoder(r, handle) +// err = dec.Decode(&v) +// +// When decoding into a nil interface{}, we will decode into an appropriate value based +// on the contents of the stream: +// - Numbers are decoded as float64, int64 or uint64. +// - Other values are decoded appropriately depending on the type: +// bool, string, []byte, time.Time, etc +// - Extensions are decoded as RawExt (if no ext function registered for the tag) +// Configurations exist on the Handle to override defaults +// (e.g. for MapType, SliceType and how to decode raw bytes). +// +// When decoding into a non-nil interface{} value, the mode of encoding is based on the +// type of the value. When a value is seen: +// - If an extension is registered for it, call that extension function +// - If it implements BinaryUnmarshaler, call its UnmarshalBinary(data []byte) error +// - Else decode it based on its reflect.Kind +// +// There are some special rules when decoding into containers (slice/array/map/struct). +// Decode will typically use the stream contents to UPDATE the container. +// - A map can be decoded from a stream map, by updating matching keys. 
+// - A slice can be decoded from a stream array, +// by updating the first n elements, where n is length of the stream. +// - A slice can be decoded from a stream map, by decoding as if +// it contains a sequence of key-value pairs. +// - A struct can be decoded from a stream map, by updating matching fields. +// - A struct can be decoded from a stream array, +// by updating fields as they occur in the struct (by index). +// +// When decoding a stream map or array with length of 0 into a nil map or slice, +// we reset the destination map or slice to a zero-length value. +// +// However, when decoding a stream nil, we reset the destination container +// to its "zero" value (e.g. nil for slice/map, etc). +// +func (d *Decoder) Decode(v interface{}) (err error) { + defer panicToErr(&err) + d.decode(v) + return +} + +func (d *Decoder) decode(iv interface{}) { + d.d.initReadNext() + + switch v := iv.(type) { + case nil: + decErr("Cannot decode into nil.") + + case reflect.Value: + d.chkPtrValue(v) + d.decodeValue(v.Elem()) + + case *string: + *v = d.d.decodeString() + case *bool: + *v = d.d.decodeBool() + case *int: + *v = int(d.d.decodeInt(intBitsize)) + case *int8: + *v = int8(d.d.decodeInt(8)) + case *int16: + *v = int16(d.d.decodeInt(16)) + case *int32: + *v = int32(d.d.decodeInt(32)) + case *int64: + *v = d.d.decodeInt(64) + case *uint: + *v = uint(d.d.decodeUint(uintBitsize)) + case *uint8: + *v = uint8(d.d.decodeUint(8)) + case *uint16: + *v = uint16(d.d.decodeUint(16)) + case *uint32: + *v = uint32(d.d.decodeUint(32)) + case *uint64: + *v = d.d.decodeUint(64) + case *float32: + *v = float32(d.d.decodeFloat(true)) + case *float64: + *v = d.d.decodeFloat(false) + case *[]byte: + *v, _ = d.d.decodeBytes(*v) + + case *[]interface{}: + d.decSliceIntf(v, valueTypeInvalid, false) + case *[]uint64: + d.decSliceUint64(v, valueTypeInvalid, false) + case *[]int64: + d.decSliceInt64(v, valueTypeInvalid, false) + case *[]string: + d.decSliceStr(v, valueTypeInvalid, false) + 
case *map[string]interface{}: + d.decMapStrIntf(v) + case *map[interface{}]interface{}: + d.decMapIntfIntf(v) + case *map[uint64]interface{}: + d.decMapUint64Intf(v) + case *map[int64]interface{}: + d.decMapInt64Intf(v) + + case *interface{}: + d.decodeValue(reflect.ValueOf(iv).Elem()) + + default: + rv := reflect.ValueOf(iv) + d.chkPtrValue(rv) + d.decodeValue(rv.Elem()) + } +} + +func (d *Decoder) decodeValue(rv reflect.Value) { + d.d.initReadNext() + + if d.d.tryDecodeAsNil() { + // If value in stream is nil, set the dereferenced value to its "zero" value (if settable) + if rv.Kind() == reflect.Ptr { + if !rv.IsNil() { + rv.Set(reflect.Zero(rv.Type())) + } + return + } + // for rv.Kind() == reflect.Ptr { + // rv = rv.Elem() + // } + if rv.IsValid() { // rv.CanSet() // always settable, except it's invalid + rv.Set(reflect.Zero(rv.Type())) + } + return + } + + // If stream is not containing a nil value, then we can deref to the base + // non-pointer value, and decode into that. + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + rt := rv.Type() + rtid := reflect.ValueOf(rt).Pointer() + + // retrieve or register a focus'ed function for this type + // to eliminate need to do the retrieval multiple times + + // if d.f == nil && d.s == nil { debugf("---->Creating new dec f map for type: %v\n", rt) } + var fn decFn + var ok bool + if useMapForCodecCache { + fn, ok = d.f[rtid] + } else { + for i, v := range d.x { + if v == rtid { + fn, ok = d.s[i], true + break + } + } + } + if !ok { + // debugf("\tCreating new dec fn for type: %v\n", rt) + fi := decFnInfo{ti: getTypeInfo(rtid, rt), d: d, dd: d.d} + fn.i = &fi + // An extension can be registered for any type, regardless of the Kind + // (e.g. type BitSet int64, type MyStruct { / * unexported fields * / }, type X []int, etc. 
+ // + // We can't check if it's an extension byte here first, because the user may have + // registered a pointer or non-pointer type, meaning we may have to recurse first + // before matching a mapped type, even though the extension byte is already detected. + // + // NOTE: if decoding into a nil interface{}, we return a non-nil + // value except even if the container registers a length of 0. + if rtid == rawExtTypId { + fn.f = (*decFnInfo).rawExt + } else if d.d.isBuiltinType(rtid) { + fn.f = (*decFnInfo).builtin + } else if xfTag, xfFn := d.h.getDecodeExt(rtid); xfFn != nil { + fi.xfTag, fi.xfFn = xfTag, xfFn + fn.f = (*decFnInfo).ext + } else if supportBinaryMarshal && fi.ti.unm { + fn.f = (*decFnInfo).binaryMarshal + } else { + switch rk := rt.Kind(); rk { + case reflect.String: + fn.f = (*decFnInfo).kString + case reflect.Bool: + fn.f = (*decFnInfo).kBool + case reflect.Int: + fn.f = (*decFnInfo).kInt + case reflect.Int64: + fn.f = (*decFnInfo).kInt64 + case reflect.Int32: + fn.f = (*decFnInfo).kInt32 + case reflect.Int8: + fn.f = (*decFnInfo).kInt8 + case reflect.Int16: + fn.f = (*decFnInfo).kInt16 + case reflect.Float32: + fn.f = (*decFnInfo).kFloat32 + case reflect.Float64: + fn.f = (*decFnInfo).kFloat64 + case reflect.Uint8: + fn.f = (*decFnInfo).kUint8 + case reflect.Uint64: + fn.f = (*decFnInfo).kUint64 + case reflect.Uint: + fn.f = (*decFnInfo).kUint + case reflect.Uint32: + fn.f = (*decFnInfo).kUint32 + case reflect.Uint16: + fn.f = (*decFnInfo).kUint16 + // case reflect.Ptr: + // fn.f = (*decFnInfo).kPtr + case reflect.Interface: + fn.f = (*decFnInfo).kInterface + case reflect.Struct: + fn.f = (*decFnInfo).kStruct + case reflect.Slice: + fn.f = (*decFnInfo).kSlice + case reflect.Array: + fi.array = true + fn.f = (*decFnInfo).kArray + case reflect.Map: + fn.f = (*decFnInfo).kMap + default: + fn.f = (*decFnInfo).kErr + } + } + if useMapForCodecCache { + if d.f == nil { + d.f = make(map[uintptr]decFn, 16) + } + d.f[rtid] = fn + } else { + d.s = 
append(d.s, fn) + d.x = append(d.x, rtid) + } + } + + fn.f(fn.i, rv) + + return +} + +func (d *Decoder) chkPtrValue(rv reflect.Value) { + // We can only decode into a non-nil pointer + if rv.Kind() == reflect.Ptr && !rv.IsNil() { + return + } + if !rv.IsValid() { + decErr("Cannot decode into a zero (ie invalid) reflect.Value") + } + if !rv.CanInterface() { + decErr("Cannot decode into a value without an interface: %v", rv) + } + rvi := rv.Interface() + decErr("Cannot decode into non-pointer or nil pointer. Got: %v, %T, %v", + rv.Kind(), rvi, rvi) +} + +func (d *Decoder) decEmbeddedField(rv reflect.Value, index []int) { + // d.decodeValue(rv.FieldByIndex(index)) + // nil pointers may be here; so reproduce FieldByIndex logic + enhancements + for _, j := range index { + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + // If a pointer, it must be a pointer to struct (based on typeInfo contract) + rv = rv.Elem() + } + rv = rv.Field(j) + } + d.decodeValue(rv) +} + +// -------------------------------------------------- + +// short circuit functions for common maps and slices + +func (d *Decoder) decSliceIntf(v *[]interface{}, currEncodedType valueType, doNotReset bool) { + _, containerLenS := decContLens(d.d, currEncodedType) + s := *v + if s == nil { + s = make([]interface{}, containerLenS, containerLenS) + } else if containerLenS > cap(s) { + if doNotReset { + decErr(msgDecCannotExpandArr, cap(s), containerLenS) + } + s = make([]interface{}, containerLenS, containerLenS) + copy(s, *v) + } else if containerLenS > len(s) { + s = s[:containerLenS] + } + for j := 0; j < containerLenS; j++ { + d.decode(&s[j]) + } + *v = s +} + +func (d *Decoder) decSliceInt64(v *[]int64, currEncodedType valueType, doNotReset bool) { + _, containerLenS := decContLens(d.d, currEncodedType) + s := *v + if s == nil { + s = make([]int64, containerLenS, containerLenS) + } else if containerLenS > cap(s) { + if doNotReset { + decErr(msgDecCannotExpandArr, 
cap(s), containerLenS) + } + s = make([]int64, containerLenS, containerLenS) + copy(s, *v) + } else if containerLenS > len(s) { + s = s[:containerLenS] + } + for j := 0; j < containerLenS; j++ { + // d.decode(&s[j]) + d.d.initReadNext() + s[j] = d.d.decodeInt(intBitsize) + } + *v = s +} + +func (d *Decoder) decSliceUint64(v *[]uint64, currEncodedType valueType, doNotReset bool) { + _, containerLenS := decContLens(d.d, currEncodedType) + s := *v + if s == nil { + s = make([]uint64, containerLenS, containerLenS) + } else if containerLenS > cap(s) { + if doNotReset { + decErr(msgDecCannotExpandArr, cap(s), containerLenS) + } + s = make([]uint64, containerLenS, containerLenS) + copy(s, *v) + } else if containerLenS > len(s) { + s = s[:containerLenS] + } + for j := 0; j < containerLenS; j++ { + // d.decode(&s[j]) + d.d.initReadNext() + s[j] = d.d.decodeUint(intBitsize) + } + *v = s +} + +func (d *Decoder) decSliceStr(v *[]string, currEncodedType valueType, doNotReset bool) { + _, containerLenS := decContLens(d.d, currEncodedType) + s := *v + if s == nil { + s = make([]string, containerLenS, containerLenS) + } else if containerLenS > cap(s) { + if doNotReset { + decErr(msgDecCannotExpandArr, cap(s), containerLenS) + } + s = make([]string, containerLenS, containerLenS) + copy(s, *v) + } else if containerLenS > len(s) { + s = s[:containerLenS] + } + for j := 0; j < containerLenS; j++ { + // d.decode(&s[j]) + d.d.initReadNext() + s[j] = d.d.decodeString() + } + *v = s +} + +func (d *Decoder) decMapIntfIntf(v *map[interface{}]interface{}) { + containerLen := d.d.readMapLen() + m := *v + if m == nil { + m = make(map[interface{}]interface{}, containerLen) + *v = m + } + for j := 0; j < containerLen; j++ { + var mk interface{} + d.decode(&mk) + // special case if a byte array. 
+ if bv, bok := mk.([]byte); bok { + mk = string(bv) + } + mv := m[mk] + d.decode(&mv) + m[mk] = mv + } +} + +func (d *Decoder) decMapInt64Intf(v *map[int64]interface{}) { + containerLen := d.d.readMapLen() + m := *v + if m == nil { + m = make(map[int64]interface{}, containerLen) + *v = m + } + for j := 0; j < containerLen; j++ { + d.d.initReadNext() + mk := d.d.decodeInt(intBitsize) + mv := m[mk] + d.decode(&mv) + m[mk] = mv + } +} + +func (d *Decoder) decMapUint64Intf(v *map[uint64]interface{}) { + containerLen := d.d.readMapLen() + m := *v + if m == nil { + m = make(map[uint64]interface{}, containerLen) + *v = m + } + for j := 0; j < containerLen; j++ { + d.d.initReadNext() + mk := d.d.decodeUint(intBitsize) + mv := m[mk] + d.decode(&mv) + m[mk] = mv + } +} + +func (d *Decoder) decMapStrIntf(v *map[string]interface{}) { + containerLen := d.d.readMapLen() + m := *v + if m == nil { + m = make(map[string]interface{}, containerLen) + *v = m + } + for j := 0; j < containerLen; j++ { + d.d.initReadNext() + mk := d.d.decodeString() + mv := m[mk] + d.decode(&mv) + m[mk] = mv + } +} + +// ---------------------------------------- + +func decContLens(dd decDriver, currEncodedType valueType) (containerLen, containerLenS int) { + if currEncodedType == valueTypeInvalid { + currEncodedType = dd.currentEncodedType() + } + switch currEncodedType { + case valueTypeArray: + containerLen = dd.readArrayLen() + containerLenS = containerLen + case valueTypeMap: + containerLen = dd.readMapLen() + containerLenS = containerLen * 2 + default: + decErr("Only encoded map or array can be decoded into a slice. (valueType: %0x)", + currEncodedType) + } + return +} + +func decErr(format string, params ...interface{}) { + doPanic(msgTagDec, format, params...) 
+} diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/encode.go b/vendor/github.com/hashicorp/go-msgpack/codec/encode.go new file mode 100644 index 00000000000..4914be0c748 --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/encode.go @@ -0,0 +1,1001 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +package codec + +import ( + "io" + "reflect" +) + +const ( + // Some tagging information for error messages. + msgTagEnc = "codec.encoder" + defEncByteBufSize = 1 << 6 // 4:16, 6:64, 8:256, 10:1024 + // maxTimeSecs32 = math.MaxInt32 / 60 / 24 / 366 +) + +// AsSymbolFlag defines what should be encoded as symbols. +type AsSymbolFlag uint8 + +const ( + // AsSymbolDefault is default. + // Currently, this means only encode struct field names as symbols. + // The default is subject to change. + AsSymbolDefault AsSymbolFlag = iota + + // AsSymbolAll means encode anything which could be a symbol as a symbol. + AsSymbolAll = 0xfe + + // AsSymbolNone means do not encode anything as a symbol. + AsSymbolNone = 1 << iota + + // AsSymbolMapStringKeys means encode keys in map[string]XXX as symbols. + AsSymbolMapStringKeysFlag + + // AsSymbolStructFieldName means encode struct field names as symbols. + AsSymbolStructFieldNameFlag +) + +// encWriter abstracting writing to a byte array or to an io.Writer. 
+type encWriter interface { + writeUint16(uint16) + writeUint32(uint32) + writeUint64(uint64) + writeb([]byte) + writestr(string) + writen1(byte) + writen2(byte, byte) + atEndOfEncode() +} + +// encDriver abstracts the actual codec (binc vs msgpack, etc) +type encDriver interface { + isBuiltinType(rt uintptr) bool + encodeBuiltin(rt uintptr, v interface{}) + encodeNil() + encodeInt(i int64) + encodeUint(i uint64) + encodeBool(b bool) + encodeFloat32(f float32) + encodeFloat64(f float64) + encodeExtPreamble(xtag byte, length int) + encodeArrayPreamble(length int) + encodeMapPreamble(length int) + encodeString(c charEncoding, v string) + encodeSymbol(v string) + encodeStringBytes(c charEncoding, v []byte) + //TODO + //encBignum(f *big.Int) + //encStringRunes(c charEncoding, v []rune) +} + +type ioEncWriterWriter interface { + WriteByte(c byte) error + WriteString(s string) (n int, err error) + Write(p []byte) (n int, err error) +} + +type ioEncStringWriter interface { + WriteString(s string) (n int, err error) +} + +type EncodeOptions struct { + // Encode a struct as an array, and not as a map. + StructToArray bool + + // AsSymbols defines what should be encoded as symbols. + // + // Encoding as symbols can reduce the encoded size significantly. + // + // However, during decoding, each string to be encoded as a symbol must + // be checked to see if it has been seen before. Consequently, encoding time + // will increase if using symbols, because string comparisons has a clear cost. 
+ // + // Sample values: + // AsSymbolNone + // AsSymbolAll + // AsSymbolMapStringKeys + // AsSymbolMapStringKeysFlag | AsSymbolStructFieldNameFlag + AsSymbols AsSymbolFlag +} + +// --------------------------------------------- + +type simpleIoEncWriterWriter struct { + w io.Writer + bw io.ByteWriter + sw ioEncStringWriter +} + +func (o *simpleIoEncWriterWriter) WriteByte(c byte) (err error) { + if o.bw != nil { + return o.bw.WriteByte(c) + } + _, err = o.w.Write([]byte{c}) + return +} + +func (o *simpleIoEncWriterWriter) WriteString(s string) (n int, err error) { + if o.sw != nil { + return o.sw.WriteString(s) + } + return o.w.Write([]byte(s)) +} + +func (o *simpleIoEncWriterWriter) Write(p []byte) (n int, err error) { + return o.w.Write(p) +} + +// ---------------------------------------- + +// ioEncWriter implements encWriter and can write to an io.Writer implementation +type ioEncWriter struct { + w ioEncWriterWriter + x [8]byte // temp byte array re-used internally for efficiency +} + +func (z *ioEncWriter) writeUint16(v uint16) { + bigen.PutUint16(z.x[:2], v) + z.writeb(z.x[:2]) +} + +func (z *ioEncWriter) writeUint32(v uint32) { + bigen.PutUint32(z.x[:4], v) + z.writeb(z.x[:4]) +} + +func (z *ioEncWriter) writeUint64(v uint64) { + bigen.PutUint64(z.x[:8], v) + z.writeb(z.x[:8]) +} + +func (z *ioEncWriter) writeb(bs []byte) { + if len(bs) == 0 { + return + } + n, err := z.w.Write(bs) + if err != nil { + panic(err) + } + if n != len(bs) { + encErr("write: Incorrect num bytes written. Expecting: %v, Wrote: %v", len(bs), n) + } +} + +func (z *ioEncWriter) writestr(s string) { + n, err := z.w.WriteString(s) + if err != nil { + panic(err) + } + if n != len(s) { + encErr("write: Incorrect num bytes written. 
Expecting: %v, Wrote: %v", len(s), n) + } +} + +func (z *ioEncWriter) writen1(b byte) { + if err := z.w.WriteByte(b); err != nil { + panic(err) + } +} + +func (z *ioEncWriter) writen2(b1 byte, b2 byte) { + z.writen1(b1) + z.writen1(b2) +} + +func (z *ioEncWriter) atEndOfEncode() {} + +// ---------------------------------------- + +// bytesEncWriter implements encWriter and can write to an byte slice. +// It is used by Marshal function. +type bytesEncWriter struct { + b []byte + c int // cursor + out *[]byte // write out on atEndOfEncode +} + +func (z *bytesEncWriter) writeUint16(v uint16) { + c := z.grow(2) + z.b[c] = byte(v >> 8) + z.b[c+1] = byte(v) +} + +func (z *bytesEncWriter) writeUint32(v uint32) { + c := z.grow(4) + z.b[c] = byte(v >> 24) + z.b[c+1] = byte(v >> 16) + z.b[c+2] = byte(v >> 8) + z.b[c+3] = byte(v) +} + +func (z *bytesEncWriter) writeUint64(v uint64) { + c := z.grow(8) + z.b[c] = byte(v >> 56) + z.b[c+1] = byte(v >> 48) + z.b[c+2] = byte(v >> 40) + z.b[c+3] = byte(v >> 32) + z.b[c+4] = byte(v >> 24) + z.b[c+5] = byte(v >> 16) + z.b[c+6] = byte(v >> 8) + z.b[c+7] = byte(v) +} + +func (z *bytesEncWriter) writeb(s []byte) { + if len(s) == 0 { + return + } + c := z.grow(len(s)) + copy(z.b[c:], s) +} + +func (z *bytesEncWriter) writestr(s string) { + c := z.grow(len(s)) + copy(z.b[c:], s) +} + +func (z *bytesEncWriter) writen1(b1 byte) { + c := z.grow(1) + z.b[c] = b1 +} + +func (z *bytesEncWriter) writen2(b1 byte, b2 byte) { + c := z.grow(2) + z.b[c] = b1 + z.b[c+1] = b2 +} + +func (z *bytesEncWriter) atEndOfEncode() { + *(z.out) = z.b[:z.c] +} + +func (z *bytesEncWriter) grow(n int) (oldcursor int) { + oldcursor = z.c + z.c = oldcursor + n + if z.c > cap(z.b) { + // Tried using appendslice logic: (if cap < 1024, *2, else *1.25). + // However, it was too expensive, causing too many iterations of copy. 
+ // Using bytes.Buffer model was much better (2*cap + n) + bs := make([]byte, 2*cap(z.b)+n) + copy(bs, z.b[:oldcursor]) + z.b = bs + } else if z.c > len(z.b) { + z.b = z.b[:cap(z.b)] + } + return +} + +// --------------------------------------------- + +type encFnInfo struct { + ti *typeInfo + e *Encoder + ee encDriver + xfFn func(reflect.Value) ([]byte, error) + xfTag byte +} + +func (f *encFnInfo) builtin(rv reflect.Value) { + f.ee.encodeBuiltin(f.ti.rtid, rv.Interface()) +} + +func (f *encFnInfo) rawExt(rv reflect.Value) { + f.e.encRawExt(rv.Interface().(RawExt)) +} + +func (f *encFnInfo) ext(rv reflect.Value) { + bs, fnerr := f.xfFn(rv) + if fnerr != nil { + panic(fnerr) + } + if bs == nil { + f.ee.encodeNil() + return + } + if f.e.hh.writeExt() { + f.ee.encodeExtPreamble(f.xfTag, len(bs)) + f.e.w.writeb(bs) + } else { + f.ee.encodeStringBytes(c_RAW, bs) + } + +} + +func (f *encFnInfo) binaryMarshal(rv reflect.Value) { + var bm binaryMarshaler + if f.ti.mIndir == 0 { + bm = rv.Interface().(binaryMarshaler) + } else if f.ti.mIndir == -1 { + bm = rv.Addr().Interface().(binaryMarshaler) + } else { + for j, k := int8(0), f.ti.mIndir; j < k; j++ { + if rv.IsNil() { + f.ee.encodeNil() + return + } + rv = rv.Elem() + } + bm = rv.Interface().(binaryMarshaler) + } + // debugf(">>>> binaryMarshaler: %T", rv.Interface()) + bs, fnerr := bm.MarshalBinary() + if fnerr != nil { + panic(fnerr) + } + if bs == nil { + f.ee.encodeNil() + } else { + f.ee.encodeStringBytes(c_RAW, bs) + } +} + +func (f *encFnInfo) kBool(rv reflect.Value) { + f.ee.encodeBool(rv.Bool()) +} + +func (f *encFnInfo) kString(rv reflect.Value) { + f.ee.encodeString(c_UTF8, rv.String()) +} + +func (f *encFnInfo) kFloat64(rv reflect.Value) { + f.ee.encodeFloat64(rv.Float()) +} + +func (f *encFnInfo) kFloat32(rv reflect.Value) { + f.ee.encodeFloat32(float32(rv.Float())) +} + +func (f *encFnInfo) kInt(rv reflect.Value) { + f.ee.encodeInt(rv.Int()) +} + +func (f *encFnInfo) kUint(rv reflect.Value) { + 
f.ee.encodeUint(rv.Uint()) +} + +func (f *encFnInfo) kInvalid(rv reflect.Value) { + f.ee.encodeNil() +} + +func (f *encFnInfo) kErr(rv reflect.Value) { + encErr("Unsupported kind: %s, for: %#v", rv.Kind(), rv) +} + +func (f *encFnInfo) kSlice(rv reflect.Value) { + if rv.IsNil() { + f.ee.encodeNil() + return + } + + if shortCircuitReflectToFastPath { + switch f.ti.rtid { + case intfSliceTypId: + f.e.encSliceIntf(rv.Interface().([]interface{})) + return + case strSliceTypId: + f.e.encSliceStr(rv.Interface().([]string)) + return + case uint64SliceTypId: + f.e.encSliceUint64(rv.Interface().([]uint64)) + return + case int64SliceTypId: + f.e.encSliceInt64(rv.Interface().([]int64)) + return + } + } + + // If in this method, then there was no extension function defined. + // So it's okay to treat as []byte. + if f.ti.rtid == uint8SliceTypId || f.ti.rt.Elem().Kind() == reflect.Uint8 { + f.ee.encodeStringBytes(c_RAW, rv.Bytes()) + return + } + + l := rv.Len() + if f.ti.mbs { + if l%2 == 1 { + encErr("mapBySlice: invalid length (must be divisible by 2): %v", l) + } + f.ee.encodeMapPreamble(l / 2) + } else { + f.ee.encodeArrayPreamble(l) + } + if l == 0 { + return + } + for j := 0; j < l; j++ { + // TODO: Consider perf implication of encoding odd index values as symbols if type is string + f.e.encodeValue(rv.Index(j)) + } +} + +func (f *encFnInfo) kArray(rv reflect.Value) { + // We cannot share kSlice method, because the array may be non-addressable. + // E.g. type struct S{B [2]byte}; Encode(S{}) will bomb on "panic: slice of unaddressable array". + // So we have to duplicate the functionality here. 
+ // f.e.encodeValue(rv.Slice(0, rv.Len())) + // f.kSlice(rv.Slice(0, rv.Len())) + + l := rv.Len() + // Handle an array of bytes specially (in line with what is done for slices) + if f.ti.rt.Elem().Kind() == reflect.Uint8 { + if l == 0 { + f.ee.encodeStringBytes(c_RAW, nil) + return + } + var bs []byte + if rv.CanAddr() { + bs = rv.Slice(0, l).Bytes() + } else { + bs = make([]byte, l) + for i := 0; i < l; i++ { + bs[i] = byte(rv.Index(i).Uint()) + } + } + f.ee.encodeStringBytes(c_RAW, bs) + return + } + + if f.ti.mbs { + if l%2 == 1 { + encErr("mapBySlice: invalid length (must be divisible by 2): %v", l) + } + f.ee.encodeMapPreamble(l / 2) + } else { + f.ee.encodeArrayPreamble(l) + } + if l == 0 { + return + } + for j := 0; j < l; j++ { + // TODO: Consider perf implication of encoding odd index values as symbols if type is string + f.e.encodeValue(rv.Index(j)) + } +} + +func (f *encFnInfo) kStruct(rv reflect.Value) { + fti := f.ti + newlen := len(fti.sfi) + rvals := make([]reflect.Value, newlen) + var encnames []string + e := f.e + tisfi := fti.sfip + toMap := !(fti.toArray || e.h.StructToArray) + // if toMap, use the sorted array. 
If toArray, use unsorted array (to match sequence in struct) + if toMap { + tisfi = fti.sfi + encnames = make([]string, newlen) + } + newlen = 0 + for _, si := range tisfi { + if si.i != -1 { + rvals[newlen] = rv.Field(int(si.i)) + } else { + rvals[newlen] = rv.FieldByIndex(si.is) + } + if toMap { + if si.omitEmpty && isEmptyValue(rvals[newlen]) { + continue + } + encnames[newlen] = si.encName + } else { + if si.omitEmpty && isEmptyValue(rvals[newlen]) { + rvals[newlen] = reflect.Value{} //encode as nil + } + } + newlen++ + } + + // debugf(">>>> kStruct: newlen: %v", newlen) + if toMap { + ee := f.ee //don't dereference everytime + ee.encodeMapPreamble(newlen) + // asSymbols := e.h.AsSymbols&AsSymbolStructFieldNameFlag != 0 + asSymbols := e.h.AsSymbols == AsSymbolDefault || e.h.AsSymbols&AsSymbolStructFieldNameFlag != 0 + for j := 0; j < newlen; j++ { + if asSymbols { + ee.encodeSymbol(encnames[j]) + } else { + ee.encodeString(c_UTF8, encnames[j]) + } + e.encodeValue(rvals[j]) + } + } else { + f.ee.encodeArrayPreamble(newlen) + for j := 0; j < newlen; j++ { + e.encodeValue(rvals[j]) + } + } +} + +// func (f *encFnInfo) kPtr(rv reflect.Value) { +// debugf(">>>>>>> ??? 
encode kPtr called - shouldn't get called") +// if rv.IsNil() { +// f.ee.encodeNil() +// return +// } +// f.e.encodeValue(rv.Elem()) +// } + +func (f *encFnInfo) kInterface(rv reflect.Value) { + if rv.IsNil() { + f.ee.encodeNil() + return + } + f.e.encodeValue(rv.Elem()) +} + +func (f *encFnInfo) kMap(rv reflect.Value) { + if rv.IsNil() { + f.ee.encodeNil() + return + } + + if shortCircuitReflectToFastPath { + switch f.ti.rtid { + case mapIntfIntfTypId: + f.e.encMapIntfIntf(rv.Interface().(map[interface{}]interface{})) + return + case mapStrIntfTypId: + f.e.encMapStrIntf(rv.Interface().(map[string]interface{})) + return + case mapStrStrTypId: + f.e.encMapStrStr(rv.Interface().(map[string]string)) + return + case mapInt64IntfTypId: + f.e.encMapInt64Intf(rv.Interface().(map[int64]interface{})) + return + case mapUint64IntfTypId: + f.e.encMapUint64Intf(rv.Interface().(map[uint64]interface{})) + return + } + } + + l := rv.Len() + f.ee.encodeMapPreamble(l) + if l == 0 { + return + } + // keyTypeIsString := f.ti.rt.Key().Kind() == reflect.String + keyTypeIsString := f.ti.rt.Key() == stringTyp + var asSymbols bool + if keyTypeIsString { + asSymbols = f.e.h.AsSymbols&AsSymbolMapStringKeysFlag != 0 + } + mks := rv.MapKeys() + // for j, lmks := 0, len(mks); j < lmks; j++ { + for j := range mks { + if keyTypeIsString { + if asSymbols { + f.ee.encodeSymbol(mks[j].String()) + } else { + f.ee.encodeString(c_UTF8, mks[j].String()) + } + } else { + f.e.encodeValue(mks[j]) + } + f.e.encodeValue(rv.MapIndex(mks[j])) + } + +} + +// -------------------------------------------------- + +// encFn encapsulates the captured variables and the encode function. +// This way, we only do some calculations one times, and pass to the +// code block that should be called (encapsulated in a function) +// instead of executing the checks every time. 
+type encFn struct { + i *encFnInfo + f func(*encFnInfo, reflect.Value) +} + +// -------------------------------------------------- + +// An Encoder writes an object to an output stream in the codec format. +type Encoder struct { + w encWriter + e encDriver + h *BasicHandle + hh Handle + f map[uintptr]encFn + x []uintptr + s []encFn +} + +// NewEncoder returns an Encoder for encoding into an io.Writer. +// +// For efficiency, Users are encouraged to pass in a memory buffered writer +// (eg bufio.Writer, bytes.Buffer). +func NewEncoder(w io.Writer, h Handle) *Encoder { + ww, ok := w.(ioEncWriterWriter) + if !ok { + sww := simpleIoEncWriterWriter{w: w} + sww.bw, _ = w.(io.ByteWriter) + sww.sw, _ = w.(ioEncStringWriter) + ww = &sww + //ww = bufio.NewWriterSize(w, defEncByteBufSize) + } + z := ioEncWriter{ + w: ww, + } + return &Encoder{w: &z, hh: h, h: h.getBasicHandle(), e: h.newEncDriver(&z)} +} + +// NewEncoderBytes returns an encoder for encoding directly and efficiently +// into a byte slice, using zero-copying to temporary slices. +// +// It will potentially replace the output byte slice pointed to. +// After encoding, the out parameter contains the encoded contents. +func NewEncoderBytes(out *[]byte, h Handle) *Encoder { + in := *out + if in == nil { + in = make([]byte, defEncByteBufSize) + } + z := bytesEncWriter{ + b: in, + out: out, + } + return &Encoder{w: &z, hh: h, h: h.getBasicHandle(), e: h.newEncDriver(&z)} +} + +// Encode writes an object into a stream in the codec format. +// +// Encoding can be configured via the "codec" struct tag for the fields. +// +// The "codec" key in struct field's tag value is the key name, +// followed by an optional comma and options. +// +// To set an option on all fields (e.g. omitempty on all fields), you +// can create a field called _struct, and set flags on it. +// +// Struct values "usually" encode as maps. 
Each exported struct field is encoded unless: +// - the field's codec tag is "-", OR +// - the field is empty and its codec tag specifies the "omitempty" option. +// +// When encoding as a map, the first string in the tag (before the comma) +// is the map key string to use when encoding. +// +// However, struct values may encode as arrays. This happens when: +// - StructToArray Encode option is set, OR +// - the codec tag on the _struct field sets the "toarray" option +// +// Values with types that implement MapBySlice are encoded as stream maps. +// +// The empty values (for omitempty option) are false, 0, any nil pointer +// or interface value, and any array, slice, map, or string of length zero. +// +// Anonymous fields are encoded inline if no struct tag is present. +// Else they are encoded as regular fields. +// +// Examples: +// +// type MyStruct struct { +// _struct bool `codec:",omitempty"` //set omitempty for every field +// Field1 string `codec:"-"` //skip this field +// Field2 int `codec:"myName"` //Use key "myName" in encode stream +// Field3 int32 `codec:",omitempty"` //use key "Field3". Omit if empty. +// Field4 bool `codec:"f4,omitempty"` //use key "f4". Omit if empty. +// ... +// } +// +// type MyStruct struct { +// _struct bool `codec:",omitempty,toarray"` //set omitempty for every field +// //and encode struct as an array +// } +// +// The mode of encoding is based on the type of the value. When a value is seen: +// - If an extension is registered for it, call that extension function +// - If it implements BinaryMarshaler, call its MarshalBinary() (data []byte, err error) +// - Else encode it based on its reflect.Kind +// +// Note that struct field names and keys in map[string]XXX will be treated as symbols. +// Some formats support symbols (e.g. binc) and will properly encode the string +// only once in the stream, and use a tag to refer to it thereafter. 
+func (e *Encoder) Encode(v interface{}) (err error) { + defer panicToErr(&err) + e.encode(v) + e.w.atEndOfEncode() + return +} + +func (e *Encoder) encode(iv interface{}) { + switch v := iv.(type) { + case nil: + e.e.encodeNil() + + case reflect.Value: + e.encodeValue(v) + + case string: + e.e.encodeString(c_UTF8, v) + case bool: + e.e.encodeBool(v) + case int: + e.e.encodeInt(int64(v)) + case int8: + e.e.encodeInt(int64(v)) + case int16: + e.e.encodeInt(int64(v)) + case int32: + e.e.encodeInt(int64(v)) + case int64: + e.e.encodeInt(v) + case uint: + e.e.encodeUint(uint64(v)) + case uint8: + e.e.encodeUint(uint64(v)) + case uint16: + e.e.encodeUint(uint64(v)) + case uint32: + e.e.encodeUint(uint64(v)) + case uint64: + e.e.encodeUint(v) + case float32: + e.e.encodeFloat32(v) + case float64: + e.e.encodeFloat64(v) + + case []interface{}: + e.encSliceIntf(v) + case []string: + e.encSliceStr(v) + case []int64: + e.encSliceInt64(v) + case []uint64: + e.encSliceUint64(v) + case []uint8: + e.e.encodeStringBytes(c_RAW, v) + + case map[interface{}]interface{}: + e.encMapIntfIntf(v) + case map[string]interface{}: + e.encMapStrIntf(v) + case map[string]string: + e.encMapStrStr(v) + case map[int64]interface{}: + e.encMapInt64Intf(v) + case map[uint64]interface{}: + e.encMapUint64Intf(v) + + case *string: + e.e.encodeString(c_UTF8, *v) + case *bool: + e.e.encodeBool(*v) + case *int: + e.e.encodeInt(int64(*v)) + case *int8: + e.e.encodeInt(int64(*v)) + case *int16: + e.e.encodeInt(int64(*v)) + case *int32: + e.e.encodeInt(int64(*v)) + case *int64: + e.e.encodeInt(*v) + case *uint: + e.e.encodeUint(uint64(*v)) + case *uint8: + e.e.encodeUint(uint64(*v)) + case *uint16: + e.e.encodeUint(uint64(*v)) + case *uint32: + e.e.encodeUint(uint64(*v)) + case *uint64: + e.e.encodeUint(*v) + case *float32: + e.e.encodeFloat32(*v) + case *float64: + e.e.encodeFloat64(*v) + + case *[]interface{}: + e.encSliceIntf(*v) + case *[]string: + e.encSliceStr(*v) + case *[]int64: + e.encSliceInt64(*v) 
+ case *[]uint64: + e.encSliceUint64(*v) + case *[]uint8: + e.e.encodeStringBytes(c_RAW, *v) + + case *map[interface{}]interface{}: + e.encMapIntfIntf(*v) + case *map[string]interface{}: + e.encMapStrIntf(*v) + case *map[string]string: + e.encMapStrStr(*v) + case *map[int64]interface{}: + e.encMapInt64Intf(*v) + case *map[uint64]interface{}: + e.encMapUint64Intf(*v) + + default: + e.encodeValue(reflect.ValueOf(iv)) + } +} + +func (e *Encoder) encodeValue(rv reflect.Value) { + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + e.e.encodeNil() + return + } + rv = rv.Elem() + } + + rt := rv.Type() + rtid := reflect.ValueOf(rt).Pointer() + + // if e.f == nil && e.s == nil { debugf("---->Creating new enc f map for type: %v\n", rt) } + var fn encFn + var ok bool + if useMapForCodecCache { + fn, ok = e.f[rtid] + } else { + for i, v := range e.x { + if v == rtid { + fn, ok = e.s[i], true + break + } + } + } + if !ok { + // debugf("\tCreating new enc fn for type: %v\n", rt) + fi := encFnInfo{ti: getTypeInfo(rtid, rt), e: e, ee: e.e} + fn.i = &fi + if rtid == rawExtTypId { + fn.f = (*encFnInfo).rawExt + } else if e.e.isBuiltinType(rtid) { + fn.f = (*encFnInfo).builtin + } else if xfTag, xfFn := e.h.getEncodeExt(rtid); xfFn != nil { + fi.xfTag, fi.xfFn = xfTag, xfFn + fn.f = (*encFnInfo).ext + } else if supportBinaryMarshal && fi.ti.m { + fn.f = (*encFnInfo).binaryMarshal + } else { + switch rk := rt.Kind(); rk { + case reflect.Bool: + fn.f = (*encFnInfo).kBool + case reflect.String: + fn.f = (*encFnInfo).kString + case reflect.Float64: + fn.f = (*encFnInfo).kFloat64 + case reflect.Float32: + fn.f = (*encFnInfo).kFloat32 + case reflect.Int, reflect.Int8, reflect.Int64, reflect.Int32, reflect.Int16: + fn.f = (*encFnInfo).kInt + case reflect.Uint8, reflect.Uint64, reflect.Uint, reflect.Uint32, reflect.Uint16: + fn.f = (*encFnInfo).kUint + case reflect.Invalid: + fn.f = (*encFnInfo).kInvalid + case reflect.Slice: + fn.f = (*encFnInfo).kSlice + case reflect.Array: + fn.f = 
(*encFnInfo).kArray + case reflect.Struct: + fn.f = (*encFnInfo).kStruct + // case reflect.Ptr: + // fn.f = (*encFnInfo).kPtr + case reflect.Interface: + fn.f = (*encFnInfo).kInterface + case reflect.Map: + fn.f = (*encFnInfo).kMap + default: + fn.f = (*encFnInfo).kErr + } + } + if useMapForCodecCache { + if e.f == nil { + e.f = make(map[uintptr]encFn, 16) + } + e.f[rtid] = fn + } else { + e.s = append(e.s, fn) + e.x = append(e.x, rtid) + } + } + + fn.f(fn.i, rv) + +} + +func (e *Encoder) encRawExt(re RawExt) { + if re.Data == nil { + e.e.encodeNil() + return + } + if e.hh.writeExt() { + e.e.encodeExtPreamble(re.Tag, len(re.Data)) + e.w.writeb(re.Data) + } else { + e.e.encodeStringBytes(c_RAW, re.Data) + } +} + +// --------------------------------------------- +// short circuit functions for common maps and slices + +func (e *Encoder) encSliceIntf(v []interface{}) { + e.e.encodeArrayPreamble(len(v)) + for _, v2 := range v { + e.encode(v2) + } +} + +func (e *Encoder) encSliceStr(v []string) { + e.e.encodeArrayPreamble(len(v)) + for _, v2 := range v { + e.e.encodeString(c_UTF8, v2) + } +} + +func (e *Encoder) encSliceInt64(v []int64) { + e.e.encodeArrayPreamble(len(v)) + for _, v2 := range v { + e.e.encodeInt(v2) + } +} + +func (e *Encoder) encSliceUint64(v []uint64) { + e.e.encodeArrayPreamble(len(v)) + for _, v2 := range v { + e.e.encodeUint(v2) + } +} + +func (e *Encoder) encMapStrStr(v map[string]string) { + e.e.encodeMapPreamble(len(v)) + asSymbols := e.h.AsSymbols&AsSymbolMapStringKeysFlag != 0 + for k2, v2 := range v { + if asSymbols { + e.e.encodeSymbol(k2) + } else { + e.e.encodeString(c_UTF8, k2) + } + e.e.encodeString(c_UTF8, v2) + } +} + +func (e *Encoder) encMapStrIntf(v map[string]interface{}) { + e.e.encodeMapPreamble(len(v)) + asSymbols := e.h.AsSymbols&AsSymbolMapStringKeysFlag != 0 + for k2, v2 := range v { + if asSymbols { + e.e.encodeSymbol(k2) + } else { + e.e.encodeString(c_UTF8, k2) + } + e.encode(v2) + } +} + +func (e *Encoder) 
encMapInt64Intf(v map[int64]interface{}) { + e.e.encodeMapPreamble(len(v)) + for k2, v2 := range v { + e.e.encodeInt(k2) + e.encode(v2) + } +} + +func (e *Encoder) encMapUint64Intf(v map[uint64]interface{}) { + e.e.encodeMapPreamble(len(v)) + for k2, v2 := range v { + e.e.encodeUint(uint64(k2)) + e.encode(v2) + } +} + +func (e *Encoder) encMapIntfIntf(v map[interface{}]interface{}) { + e.e.encodeMapPreamble(len(v)) + for k2, v2 := range v { + e.encode(k2) + e.encode(v2) + } +} + +// ---------------------------------------- + +func encErr(format string, params ...interface{}) { + doPanic(msgTagEnc, format, params...) +} diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/helper.go b/vendor/github.com/hashicorp/go-msgpack/codec/helper.go new file mode 100644 index 00000000000..e6dc0563f09 --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/helper.go @@ -0,0 +1,589 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +package codec + +// Contains code shared by both encode and decode. + +import ( + "encoding/binary" + "fmt" + "math" + "reflect" + "sort" + "strings" + "sync" + "time" + "unicode" + "unicode/utf8" +) + +const ( + structTagName = "codec" + + // Support + // encoding.BinaryMarshaler: MarshalBinary() (data []byte, err error) + // encoding.BinaryUnmarshaler: UnmarshalBinary(data []byte) error + // This constant flag will enable or disable it. + supportBinaryMarshal = true + + // Each Encoder or Decoder uses a cache of functions based on conditionals, + // so that the conditionals are not run every time. + // + // Either a map or a slice is used to keep track of the functions. + // The map is more natural, but has a higher cost than a slice/array. + // This flag (useMapForCodecCache) controls which is used. 
+ useMapForCodecCache = false + + // For some common container types, we can short-circuit an elaborate + // reflection dance and call encode/decode directly. + // The currently supported types are: + // - slices of strings, or id's (int64,uint64) or interfaces. + // - maps of str->str, str->intf, id(int64,uint64)->intf, intf->intf + shortCircuitReflectToFastPath = true + + // for debugging, set this to false, to catch panic traces. + // Note that this will always cause rpc tests to fail, since they need io.EOF sent via panic. + recoverPanicToErr = true +) + +type charEncoding uint8 + +const ( + c_RAW charEncoding = iota + c_UTF8 + c_UTF16LE + c_UTF16BE + c_UTF32LE + c_UTF32BE +) + +// valueType is the stream type +type valueType uint8 + +const ( + valueTypeUnset valueType = iota + valueTypeNil + valueTypeInt + valueTypeUint + valueTypeFloat + valueTypeBool + valueTypeString + valueTypeSymbol + valueTypeBytes + valueTypeMap + valueTypeArray + valueTypeTimestamp + valueTypeExt + + valueTypeInvalid = 0xff +) + +var ( + bigen = binary.BigEndian + structInfoFieldName = "_struct" + + cachedTypeInfo = make(map[uintptr]*typeInfo, 4) + cachedTypeInfoMutex sync.RWMutex + + intfSliceTyp = reflect.TypeOf([]interface{}(nil)) + intfTyp = intfSliceTyp.Elem() + + strSliceTyp = reflect.TypeOf([]string(nil)) + boolSliceTyp = reflect.TypeOf([]bool(nil)) + uintSliceTyp = reflect.TypeOf([]uint(nil)) + uint8SliceTyp = reflect.TypeOf([]uint8(nil)) + uint16SliceTyp = reflect.TypeOf([]uint16(nil)) + uint32SliceTyp = reflect.TypeOf([]uint32(nil)) + uint64SliceTyp = reflect.TypeOf([]uint64(nil)) + intSliceTyp = reflect.TypeOf([]int(nil)) + int8SliceTyp = reflect.TypeOf([]int8(nil)) + int16SliceTyp = reflect.TypeOf([]int16(nil)) + int32SliceTyp = reflect.TypeOf([]int32(nil)) + int64SliceTyp = reflect.TypeOf([]int64(nil)) + float32SliceTyp = reflect.TypeOf([]float32(nil)) + float64SliceTyp = reflect.TypeOf([]float64(nil)) + + mapIntfIntfTyp = reflect.TypeOf(map[interface{}]interface{}(nil)) + 
mapStrIntfTyp = reflect.TypeOf(map[string]interface{}(nil)) + mapStrStrTyp = reflect.TypeOf(map[string]string(nil)) + + mapIntIntfTyp = reflect.TypeOf(map[int]interface{}(nil)) + mapInt64IntfTyp = reflect.TypeOf(map[int64]interface{}(nil)) + mapUintIntfTyp = reflect.TypeOf(map[uint]interface{}(nil)) + mapUint64IntfTyp = reflect.TypeOf(map[uint64]interface{}(nil)) + + stringTyp = reflect.TypeOf("") + timeTyp = reflect.TypeOf(time.Time{}) + rawExtTyp = reflect.TypeOf(RawExt{}) + + mapBySliceTyp = reflect.TypeOf((*MapBySlice)(nil)).Elem() + binaryMarshalerTyp = reflect.TypeOf((*binaryMarshaler)(nil)).Elem() + binaryUnmarshalerTyp = reflect.TypeOf((*binaryUnmarshaler)(nil)).Elem() + + rawExtTypId = reflect.ValueOf(rawExtTyp).Pointer() + intfTypId = reflect.ValueOf(intfTyp).Pointer() + timeTypId = reflect.ValueOf(timeTyp).Pointer() + + intfSliceTypId = reflect.ValueOf(intfSliceTyp).Pointer() + strSliceTypId = reflect.ValueOf(strSliceTyp).Pointer() + + boolSliceTypId = reflect.ValueOf(boolSliceTyp).Pointer() + uintSliceTypId = reflect.ValueOf(uintSliceTyp).Pointer() + uint8SliceTypId = reflect.ValueOf(uint8SliceTyp).Pointer() + uint16SliceTypId = reflect.ValueOf(uint16SliceTyp).Pointer() + uint32SliceTypId = reflect.ValueOf(uint32SliceTyp).Pointer() + uint64SliceTypId = reflect.ValueOf(uint64SliceTyp).Pointer() + intSliceTypId = reflect.ValueOf(intSliceTyp).Pointer() + int8SliceTypId = reflect.ValueOf(int8SliceTyp).Pointer() + int16SliceTypId = reflect.ValueOf(int16SliceTyp).Pointer() + int32SliceTypId = reflect.ValueOf(int32SliceTyp).Pointer() + int64SliceTypId = reflect.ValueOf(int64SliceTyp).Pointer() + float32SliceTypId = reflect.ValueOf(float32SliceTyp).Pointer() + float64SliceTypId = reflect.ValueOf(float64SliceTyp).Pointer() + + mapStrStrTypId = reflect.ValueOf(mapStrStrTyp).Pointer() + mapIntfIntfTypId = reflect.ValueOf(mapIntfIntfTyp).Pointer() + mapStrIntfTypId = reflect.ValueOf(mapStrIntfTyp).Pointer() + mapIntIntfTypId = 
reflect.ValueOf(mapIntIntfTyp).Pointer() + mapInt64IntfTypId = reflect.ValueOf(mapInt64IntfTyp).Pointer() + mapUintIntfTypId = reflect.ValueOf(mapUintIntfTyp).Pointer() + mapUint64IntfTypId = reflect.ValueOf(mapUint64IntfTyp).Pointer() + // Id = reflect.ValueOf().Pointer() + // mapBySliceTypId = reflect.ValueOf(mapBySliceTyp).Pointer() + + binaryMarshalerTypId = reflect.ValueOf(binaryMarshalerTyp).Pointer() + binaryUnmarshalerTypId = reflect.ValueOf(binaryUnmarshalerTyp).Pointer() + + intBitsize uint8 = uint8(reflect.TypeOf(int(0)).Bits()) + uintBitsize uint8 = uint8(reflect.TypeOf(uint(0)).Bits()) + + bsAll0x00 = []byte{0, 0, 0, 0, 0, 0, 0, 0} + bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} +) + +type binaryUnmarshaler interface { + UnmarshalBinary(data []byte) error +} + +type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) +} + +// MapBySlice represents a slice which should be encoded as a map in the stream. +// The slice contains a sequence of key-value pairs. +type MapBySlice interface { + MapBySlice() +} + +// WARNING: DO NOT USE DIRECTLY. EXPORTED FOR GODOC BENEFIT. WILL BE REMOVED. +// +// BasicHandle encapsulates the common options and extension functions. +type BasicHandle struct { + extHandle + EncodeOptions + DecodeOptions +} + +// Handle is the interface for a specific encoding format. +// +// Typically, a Handle is pre-configured before first time use, +// and not modified while in use. Such a pre-configured Handle +// is safe for concurrent access. +type Handle interface { + writeExt() bool + getBasicHandle() *BasicHandle + newEncDriver(w encWriter) encDriver + newDecDriver(r decReader) decDriver +} + +// RawExt represents raw unprocessed extension data. 
+type RawExt struct { + Tag byte + Data []byte +} + +type extTypeTagFn struct { + rtid uintptr + rt reflect.Type + tag byte + encFn func(reflect.Value) ([]byte, error) + decFn func(reflect.Value, []byte) error +} + +type extHandle []*extTypeTagFn + +// AddExt registers an encode and decode function for a reflect.Type. +// Note that the type must be a named type, and specifically not +// a pointer or Interface. An error is returned if that is not honored. +// +// To Deregister an ext, call AddExt with 0 tag, nil encfn and nil decfn. +func (o *extHandle) AddExt( + rt reflect.Type, + tag byte, + encfn func(reflect.Value) ([]byte, error), + decfn func(reflect.Value, []byte) error, +) (err error) { + // o is a pointer, because we may need to initialize it + if rt.PkgPath() == "" || rt.Kind() == reflect.Interface { + err = fmt.Errorf("codec.Handle.AddExt: Takes named type, especially not a pointer or interface: %T", + reflect.Zero(rt).Interface()) + return + } + + // o cannot be nil, since it is always embedded in a Handle. + // if nil, let it panic. 
+ // if o == nil { + // err = errors.New("codec.Handle.AddExt: extHandle cannot be a nil pointer.") + // return + // } + + rtid := reflect.ValueOf(rt).Pointer() + for _, v := range *o { + if v.rtid == rtid { + v.tag, v.encFn, v.decFn = tag, encfn, decfn + return + } + } + + *o = append(*o, &extTypeTagFn{rtid, rt, tag, encfn, decfn}) + return +} + +func (o extHandle) getExt(rtid uintptr) *extTypeTagFn { + for _, v := range o { + if v.rtid == rtid { + return v + } + } + return nil +} + +func (o extHandle) getExtForTag(tag byte) *extTypeTagFn { + for _, v := range o { + if v.tag == tag { + return v + } + } + return nil +} + +func (o extHandle) getDecodeExtForTag(tag byte) ( + rv reflect.Value, fn func(reflect.Value, []byte) error) { + if x := o.getExtForTag(tag); x != nil { + // ext is only registered for base + rv = reflect.New(x.rt).Elem() + fn = x.decFn + } + return +} + +func (o extHandle) getDecodeExt(rtid uintptr) (tag byte, fn func(reflect.Value, []byte) error) { + if x := o.getExt(rtid); x != nil { + tag = x.tag + fn = x.decFn + } + return +} + +func (o extHandle) getEncodeExt(rtid uintptr) (tag byte, fn func(reflect.Value) ([]byte, error)) { + if x := o.getExt(rtid); x != nil { + tag = x.tag + fn = x.encFn + } + return +} + +type structFieldInfo struct { + encName string // encode name + + // only one of 'i' or 'is' can be set. If 'i' is -1, then 'is' has been set. + + is []int // (recursive/embedded) field index in struct + i int16 // field index in struct + omitEmpty bool + toArray bool // if field is _struct, is the toArray set? + + // tag string // tag + // name string // field name + // encNameBs []byte // encoded name as byte stream + // ikind int // kind of the field as an int i.e. 
int(reflect.Kind) +} + +func parseStructFieldInfo(fname string, stag string) *structFieldInfo { + if fname == "" { + panic("parseStructFieldInfo: No Field Name") + } + si := structFieldInfo{ + // name: fname, + encName: fname, + // tag: stag, + } + + if stag != "" { + for i, s := range strings.Split(stag, ",") { + if i == 0 { + if s != "" { + si.encName = s + } + } else { + switch s { + case "omitempty": + si.omitEmpty = true + case "toarray": + si.toArray = true + } + } + } + } + // si.encNameBs = []byte(si.encName) + return &si +} + +type sfiSortedByEncName []*structFieldInfo + +func (p sfiSortedByEncName) Len() int { + return len(p) +} + +func (p sfiSortedByEncName) Less(i, j int) bool { + return p[i].encName < p[j].encName +} + +func (p sfiSortedByEncName) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +// typeInfo keeps information about each type referenced in the encode/decode sequence. +// +// During an encode/decode sequence, we work as below: +// - If base is a built in type, en/decode base value +// - If base is registered as an extension, en/decode base value +// - If type is binary(M/Unm)arshaler, call Binary(M/Unm)arshal method +// - Else decode appropriately based on the reflect.Kind +type typeInfo struct { + sfi []*structFieldInfo // sorted. Used when enc/dec struct to map. + sfip []*structFieldInfo // unsorted. Used when enc/dec struct to array. + + rt reflect.Type + rtid uintptr + + // baseId gives pointer to the base reflect.Type, after deferencing + // the pointers. E.g. base type of ***time.Time is time.Time. 
+ base reflect.Type + baseId uintptr + baseIndir int8 // number of indirections to get to base + + mbs bool // base type (T or *T) is a MapBySlice + + m bool // base type (T or *T) is a binaryMarshaler + unm bool // base type (T or *T) is a binaryUnmarshaler + mIndir int8 // number of indirections to get to binaryMarshaler type + unmIndir int8 // number of indirections to get to binaryUnmarshaler type + toArray bool // whether this (struct) type should be encoded as an array +} + +func (ti *typeInfo) indexForEncName(name string) int { + //tisfi := ti.sfi + const binarySearchThreshold = 16 + if sfilen := len(ti.sfi); sfilen < binarySearchThreshold { + // linear search. faster than binary search in my testing up to 16-field structs. + for i, si := range ti.sfi { + if si.encName == name { + return i + } + } + } else { + // binary search. adapted from sort/search.go. + h, i, j := 0, 0, sfilen + for i < j { + h = i + (j-i)/2 + if ti.sfi[h].encName < name { + i = h + 1 + } else { + j = h + } + } + if i < sfilen && ti.sfi[i].encName == name { + return i + } + } + return -1 +} + +func getTypeInfo(rtid uintptr, rt reflect.Type) (pti *typeInfo) { + var ok bool + cachedTypeInfoMutex.RLock() + pti, ok = cachedTypeInfo[rtid] + cachedTypeInfoMutex.RUnlock() + if ok { + return + } + + cachedTypeInfoMutex.Lock() + defer cachedTypeInfoMutex.Unlock() + if pti, ok = cachedTypeInfo[rtid]; ok { + return + } + + ti := typeInfo{rt: rt, rtid: rtid} + pti = &ti + + var indir int8 + if ok, indir = implementsIntf(rt, binaryMarshalerTyp); ok { + ti.m, ti.mIndir = true, indir + } + if ok, indir = implementsIntf(rt, binaryUnmarshalerTyp); ok { + ti.unm, ti.unmIndir = true, indir + } + if ok, _ = implementsIntf(rt, mapBySliceTyp); ok { + ti.mbs = true + } + + pt := rt + var ptIndir int8 + // for ; pt.Kind() == reflect.Ptr; pt, ptIndir = pt.Elem(), ptIndir+1 { } + for pt.Kind() == reflect.Ptr { + pt = pt.Elem() + ptIndir++ + } + if ptIndir == 0 { + ti.base = rt + ti.baseId = rtid + } else { + 
ti.base = pt + ti.baseId = reflect.ValueOf(pt).Pointer() + ti.baseIndir = ptIndir + } + + if rt.Kind() == reflect.Struct { + var siInfo *structFieldInfo + if f, ok := rt.FieldByName(structInfoFieldName); ok { + siInfo = parseStructFieldInfo(structInfoFieldName, f.Tag.Get(structTagName)) + ti.toArray = siInfo.toArray + } + sfip := make([]*structFieldInfo, 0, rt.NumField()) + rgetTypeInfo(rt, nil, make(map[string]bool), &sfip, siInfo) + + // // try to put all si close together + // const tryToPutAllStructFieldInfoTogether = true + // if tryToPutAllStructFieldInfoTogether { + // sfip2 := make([]structFieldInfo, len(sfip)) + // for i, si := range sfip { + // sfip2[i] = *si + // } + // for i := range sfip { + // sfip[i] = &sfip2[i] + // } + // } + + ti.sfip = make([]*structFieldInfo, len(sfip)) + ti.sfi = make([]*structFieldInfo, len(sfip)) + copy(ti.sfip, sfip) + sort.Sort(sfiSortedByEncName(sfip)) + copy(ti.sfi, sfip) + } + // sfi = sfip + cachedTypeInfo[rtid] = pti + return +} + +func rgetTypeInfo(rt reflect.Type, indexstack []int, fnameToHastag map[string]bool, + sfi *[]*structFieldInfo, siInfo *structFieldInfo, +) { + // for rt.Kind() == reflect.Ptr { + // // indexstack = append(indexstack, 0) + // rt = rt.Elem() + // } + for j := 0; j < rt.NumField(); j++ { + f := rt.Field(j) + stag := f.Tag.Get(structTagName) + if stag == "-" { + continue + } + if r1, _ := utf8.DecodeRuneInString(f.Name); r1 == utf8.RuneError || !unicode.IsUpper(r1) { + continue + } + // if anonymous and there is no struct tag and its a struct (or pointer to struct), inline it. + if f.Anonymous && stag == "" { + ft := f.Type + for ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if ft.Kind() == reflect.Struct { + indexstack2 := append(append(make([]int, 0, len(indexstack)+4), indexstack...), j) + rgetTypeInfo(ft, indexstack2, fnameToHastag, sfi, siInfo) + continue + } + } + // do not let fields with same name in embedded structs override field at higher level. 
+ // this must be done after anonymous check, to allow anonymous field + // still include their child fields + if _, ok := fnameToHastag[f.Name]; ok { + continue + } + si := parseStructFieldInfo(f.Name, stag) + // si.ikind = int(f.Type.Kind()) + if len(indexstack) == 0 { + si.i = int16(j) + } else { + si.i = -1 + si.is = append(append(make([]int, 0, len(indexstack)+4), indexstack...), j) + } + + if siInfo != nil { + if siInfo.omitEmpty { + si.omitEmpty = true + } + } + *sfi = append(*sfi, si) + fnameToHastag[f.Name] = stag != "" + } +} + +func panicToErr(err *error) { + if recoverPanicToErr { + if x := recover(); x != nil { + //debug.PrintStack() + panicValToErr(x, err) + } + } +} + +func doPanic(tag string, format string, params ...interface{}) { + params2 := make([]interface{}, len(params)+1) + params2[0] = tag + copy(params2[1:], params) + panic(fmt.Errorf("%s: "+format, params2...)) +} + +func checkOverflowFloat32(f float64, doCheck bool) { + if !doCheck { + return + } + // check overflow (logic adapted from std pkg reflect/value.go OverflowFloat() + f2 := f + if f2 < 0 { + f2 = -f + } + if math.MaxFloat32 < f2 && f2 <= math.MaxFloat64 { + decErr("Overflow float32 value: %v", f2) + } +} + +func checkOverflow(ui uint64, i int64, bitsize uint8) { + // check overflow (logic adapted from std pkg reflect/value.go OverflowUint() + if bitsize == 0 { + return + } + if i != 0 { + if trunc := (i << (64 - bitsize)) >> (64 - bitsize); i != trunc { + decErr("Overflow int value: %v", i) + } + } + if ui != 0 { + if trunc := (ui << (64 - bitsize)) >> (64 - bitsize); ui != trunc { + decErr("Overflow uint value: %v", ui) + } + } +} diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/helper_internal.go b/vendor/github.com/hashicorp/go-msgpack/codec/helper_internal.go new file mode 100644 index 00000000000..58417da958f --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/helper_internal.go @@ -0,0 +1,127 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. 
All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +package codec + +// All non-std package dependencies live in this file, +// so porting to different environment is easy (just update functions). + +import ( + "errors" + "fmt" + "math" + "reflect" +) + +var ( + raisePanicAfterRecover = false + debugging = true +) + +func panicValToErr(panicVal interface{}, err *error) { + switch xerr := panicVal.(type) { + case error: + *err = xerr + case string: + *err = errors.New(xerr) + default: + *err = fmt.Errorf("%v", panicVal) + } + if raisePanicAfterRecover { + panic(panicVal) + } + return +} + +func isEmptyValueDeref(v reflect.Value, deref bool) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + if deref { + if v.IsNil() { + return true + } + return isEmptyValueDeref(v.Elem(), deref) + } else { + return v.IsNil() + } + case reflect.Struct: + // return true if all fields are empty. else return false. + + // we cannot use equality check, because some fields may be maps/slices/etc + // and consequently the structs are not comparable. 
+ // return v.Interface() == reflect.Zero(v.Type()).Interface() + for i, n := 0, v.NumField(); i < n; i++ { + if !isEmptyValueDeref(v.Field(i), deref) { + return false + } + } + return true + } + return false +} + +func isEmptyValue(v reflect.Value) bool { + return isEmptyValueDeref(v, true) +} + +func debugf(format string, args ...interface{}) { + if debugging { + if len(format) == 0 || format[len(format)-1] != '\n' { + format = format + "\n" + } + fmt.Printf(format, args...) + } +} + +func pruneSignExt(v []byte, pos bool) (n int) { + if len(v) < 2 { + } else if pos && v[0] == 0 { + for ; v[n] == 0 && n+1 < len(v) && (v[n+1]&(1<<7) == 0); n++ { + } + } else if !pos && v[0] == 0xff { + for ; v[n] == 0xff && n+1 < len(v) && (v[n+1]&(1<<7) != 0); n++ { + } + } + return +} + +func implementsIntf(typ, iTyp reflect.Type) (success bool, indir int8) { + if typ == nil { + return + } + rt := typ + // The type might be a pointer and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if rt.Implements(iTyp) { + return true, indir + } + if p := rt; p.Kind() == reflect.Ptr { + indir++ + if indir >= math.MaxInt8 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. + if typ.Kind() != reflect.Ptr { + // Not a pointer, but does the pointer work? + if reflect.PtrTo(typ).Implements(iTyp) { + return true, -1 + } + } + return false, 0 +} diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/msgpack.go b/vendor/github.com/hashicorp/go-msgpack/codec/msgpack.go new file mode 100644 index 00000000000..da0500d1922 --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/msgpack.go @@ -0,0 +1,816 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. 
+ +/* +MSGPACK + +Msgpack-c implementation powers the c, c++, python, ruby, etc libraries. +We need to maintain compatibility with it and how it encodes integer values +without caring about the type. + +For compatibility with behaviour of msgpack-c reference implementation: + - Go intX (>0) and uintX + IS ENCODED AS + msgpack +ve fixnum, unsigned + - Go intX (<0) + IS ENCODED AS + msgpack -ve fixnum, signed + +*/ +package codec + +import ( + "fmt" + "io" + "math" + "net/rpc" +) + +const ( + mpPosFixNumMin byte = 0x00 + mpPosFixNumMax = 0x7f + mpFixMapMin = 0x80 + mpFixMapMax = 0x8f + mpFixArrayMin = 0x90 + mpFixArrayMax = 0x9f + mpFixStrMin = 0xa0 + mpFixStrMax = 0xbf + mpNil = 0xc0 + _ = 0xc1 + mpFalse = 0xc2 + mpTrue = 0xc3 + mpFloat = 0xca + mpDouble = 0xcb + mpUint8 = 0xcc + mpUint16 = 0xcd + mpUint32 = 0xce + mpUint64 = 0xcf + mpInt8 = 0xd0 + mpInt16 = 0xd1 + mpInt32 = 0xd2 + mpInt64 = 0xd3 + + // extensions below + mpBin8 = 0xc4 + mpBin16 = 0xc5 + mpBin32 = 0xc6 + mpExt8 = 0xc7 + mpExt16 = 0xc8 + mpExt32 = 0xc9 + mpFixExt1 = 0xd4 + mpFixExt2 = 0xd5 + mpFixExt4 = 0xd6 + mpFixExt8 = 0xd7 + mpFixExt16 = 0xd8 + + mpStr8 = 0xd9 // new + mpStr16 = 0xda + mpStr32 = 0xdb + + mpArray16 = 0xdc + mpArray32 = 0xdd + + mpMap16 = 0xde + mpMap32 = 0xdf + + mpNegFixNumMin = 0xe0 + mpNegFixNumMax = 0xff +) + +// MsgpackSpecRpcMultiArgs is a special type which signifies to the MsgpackSpecRpcCodec +// that the backend RPC service takes multiple arguments, which have been arranged +// in sequence in the slice. +// +// The Codec then passes it AS-IS to the rpc service (without wrapping it in an +// array of 1 element). +type MsgpackSpecRpcMultiArgs []interface{} + +// A MsgpackContainer type specifies the different types of msgpackContainers. 
+type msgpackContainerType struct { + fixCutoff int + bFixMin, b8, b16, b32 byte + hasFixMin, has8, has8Always bool +} + +var ( + msgpackContainerStr = msgpackContainerType{32, mpFixStrMin, mpStr8, mpStr16, mpStr32, true, true, false} + msgpackContainerBin = msgpackContainerType{0, 0, mpBin8, mpBin16, mpBin32, false, true, true} + msgpackContainerList = msgpackContainerType{16, mpFixArrayMin, 0, mpArray16, mpArray32, true, false, false} + msgpackContainerMap = msgpackContainerType{16, mpFixMapMin, 0, mpMap16, mpMap32, true, false, false} +) + +//--------------------------------------------- + +type msgpackEncDriver struct { + w encWriter + h *MsgpackHandle +} + +func (e *msgpackEncDriver) isBuiltinType(rt uintptr) bool { + //no builtin types. All encodings are based on kinds. Types supported as extensions. + return false +} + +func (e *msgpackEncDriver) encodeBuiltin(rt uintptr, v interface{}) {} + +func (e *msgpackEncDriver) encodeNil() { + e.w.writen1(mpNil) +} + +func (e *msgpackEncDriver) encodeInt(i int64) { + + switch { + case i >= 0: + e.encodeUint(uint64(i)) + case i >= -32: + e.w.writen1(byte(i)) + case i >= math.MinInt8: + e.w.writen2(mpInt8, byte(i)) + case i >= math.MinInt16: + e.w.writen1(mpInt16) + e.w.writeUint16(uint16(i)) + case i >= math.MinInt32: + e.w.writen1(mpInt32) + e.w.writeUint32(uint32(i)) + default: + e.w.writen1(mpInt64) + e.w.writeUint64(uint64(i)) + } +} + +func (e *msgpackEncDriver) encodeUint(i uint64) { + switch { + case i <= math.MaxInt8: + e.w.writen1(byte(i)) + case i <= math.MaxUint8: + e.w.writen2(mpUint8, byte(i)) + case i <= math.MaxUint16: + e.w.writen1(mpUint16) + e.w.writeUint16(uint16(i)) + case i <= math.MaxUint32: + e.w.writen1(mpUint32) + e.w.writeUint32(uint32(i)) + default: + e.w.writen1(mpUint64) + e.w.writeUint64(uint64(i)) + } +} + +func (e *msgpackEncDriver) encodeBool(b bool) { + if b { + e.w.writen1(mpTrue) + } else { + e.w.writen1(mpFalse) + } +} + +func (e *msgpackEncDriver) encodeFloat32(f float32) { + 
e.w.writen1(mpFloat) + e.w.writeUint32(math.Float32bits(f)) +} + +func (e *msgpackEncDriver) encodeFloat64(f float64) { + e.w.writen1(mpDouble) + e.w.writeUint64(math.Float64bits(f)) +} + +func (e *msgpackEncDriver) encodeExtPreamble(xtag byte, l int) { + switch { + case l == 1: + e.w.writen2(mpFixExt1, xtag) + case l == 2: + e.w.writen2(mpFixExt2, xtag) + case l == 4: + e.w.writen2(mpFixExt4, xtag) + case l == 8: + e.w.writen2(mpFixExt8, xtag) + case l == 16: + e.w.writen2(mpFixExt16, xtag) + case l < 256: + e.w.writen2(mpExt8, byte(l)) + e.w.writen1(xtag) + case l < 65536: + e.w.writen1(mpExt16) + e.w.writeUint16(uint16(l)) + e.w.writen1(xtag) + default: + e.w.writen1(mpExt32) + e.w.writeUint32(uint32(l)) + e.w.writen1(xtag) + } +} + +func (e *msgpackEncDriver) encodeArrayPreamble(length int) { + e.writeContainerLen(msgpackContainerList, length) +} + +func (e *msgpackEncDriver) encodeMapPreamble(length int) { + e.writeContainerLen(msgpackContainerMap, length) +} + +func (e *msgpackEncDriver) encodeString(c charEncoding, s string) { + if c == c_RAW && e.h.WriteExt { + e.writeContainerLen(msgpackContainerBin, len(s)) + } else { + e.writeContainerLen(msgpackContainerStr, len(s)) + } + if len(s) > 0 { + e.w.writestr(s) + } +} + +func (e *msgpackEncDriver) encodeSymbol(v string) { + e.encodeString(c_UTF8, v) +} + +func (e *msgpackEncDriver) encodeStringBytes(c charEncoding, bs []byte) { + if c == c_RAW && e.h.WriteExt { + e.writeContainerLen(msgpackContainerBin, len(bs)) + } else { + e.writeContainerLen(msgpackContainerStr, len(bs)) + } + if len(bs) > 0 { + e.w.writeb(bs) + } +} + +func (e *msgpackEncDriver) writeContainerLen(ct msgpackContainerType, l int) { + switch { + case ct.hasFixMin && l < ct.fixCutoff: + e.w.writen1(ct.bFixMin | byte(l)) + case ct.has8 && l < 256 && (ct.has8Always || e.h.WriteExt): + e.w.writen2(ct.b8, uint8(l)) + case l < 65536: + e.w.writen1(ct.b16) + e.w.writeUint16(uint16(l)) + default: + e.w.writen1(ct.b32) + e.w.writeUint32(uint32(l)) + 
} +} + +//--------------------------------------------- + +type msgpackDecDriver struct { + r decReader + h *MsgpackHandle + bd byte + bdRead bool + bdType valueType +} + +func (d *msgpackDecDriver) isBuiltinType(rt uintptr) bool { + //no builtin types. All encodings are based on kinds. Types supported as extensions. + return false +} + +func (d *msgpackDecDriver) decodeBuiltin(rt uintptr, v interface{}) {} + +// Note: This returns either a primitive (int, bool, etc) for non-containers, +// or a containerType, or a specific type denoting nil or extension. +// It is called when a nil interface{} is passed, leaving it up to the DecDriver +// to introspect the stream and decide how best to decode. +// It deciphers the value by looking at the stream first. +func (d *msgpackDecDriver) decodeNaked() (v interface{}, vt valueType, decodeFurther bool) { + d.initReadNext() + bd := d.bd + + switch bd { + case mpNil: + vt = valueTypeNil + d.bdRead = false + case mpFalse: + vt = valueTypeBool + v = false + case mpTrue: + vt = valueTypeBool + v = true + + case mpFloat: + vt = valueTypeFloat + v = float64(math.Float32frombits(d.r.readUint32())) + case mpDouble: + vt = valueTypeFloat + v = math.Float64frombits(d.r.readUint64()) + + case mpUint8: + vt = valueTypeUint + v = uint64(d.r.readn1()) + case mpUint16: + vt = valueTypeUint + v = uint64(d.r.readUint16()) + case mpUint32: + vt = valueTypeUint + v = uint64(d.r.readUint32()) + case mpUint64: + vt = valueTypeUint + v = uint64(d.r.readUint64()) + + case mpInt8: + vt = valueTypeInt + v = int64(int8(d.r.readn1())) + case mpInt16: + vt = valueTypeInt + v = int64(int16(d.r.readUint16())) + case mpInt32: + vt = valueTypeInt + v = int64(int32(d.r.readUint32())) + case mpInt64: + vt = valueTypeInt + v = int64(int64(d.r.readUint64())) + + default: + switch { + case bd >= mpPosFixNumMin && bd <= mpPosFixNumMax: + // positive fixnum (always signed) + vt = valueTypeInt + v = int64(int8(bd)) + case bd >= mpNegFixNumMin && bd <= 
mpNegFixNumMax: + // negative fixnum + vt = valueTypeInt + v = int64(int8(bd)) + case bd == mpStr8, bd == mpStr16, bd == mpStr32, bd >= mpFixStrMin && bd <= mpFixStrMax: + if d.h.RawToString { + var rvm string + vt = valueTypeString + v = &rvm + } else { + var rvm = []byte{} + vt = valueTypeBytes + v = &rvm + } + decodeFurther = true + case bd == mpBin8, bd == mpBin16, bd == mpBin32: + var rvm = []byte{} + vt = valueTypeBytes + v = &rvm + decodeFurther = true + case bd == mpArray16, bd == mpArray32, bd >= mpFixArrayMin && bd <= mpFixArrayMax: + vt = valueTypeArray + decodeFurther = true + case bd == mpMap16, bd == mpMap32, bd >= mpFixMapMin && bd <= mpFixMapMax: + vt = valueTypeMap + decodeFurther = true + case bd >= mpFixExt1 && bd <= mpFixExt16, bd >= mpExt8 && bd <= mpExt32: + clen := d.readExtLen() + var re RawExt + re.Tag = d.r.readn1() + re.Data = d.r.readn(clen) + v = &re + vt = valueTypeExt + default: + decErr("Nil-Deciphered DecodeValue: %s: hex: %x, dec: %d", msgBadDesc, bd, bd) + } + } + if !decodeFurther { + d.bdRead = false + } + return +} + +// int can be decoded from msgpack type: intXXX or uintXXX +func (d *msgpackDecDriver) decodeInt(bitsize uint8) (i int64) { + switch d.bd { + case mpUint8: + i = int64(uint64(d.r.readn1())) + case mpUint16: + i = int64(uint64(d.r.readUint16())) + case mpUint32: + i = int64(uint64(d.r.readUint32())) + case mpUint64: + i = int64(d.r.readUint64()) + case mpInt8: + i = int64(int8(d.r.readn1())) + case mpInt16: + i = int64(int16(d.r.readUint16())) + case mpInt32: + i = int64(int32(d.r.readUint32())) + case mpInt64: + i = int64(d.r.readUint64()) + default: + switch { + case d.bd >= mpPosFixNumMin && d.bd <= mpPosFixNumMax: + i = int64(int8(d.bd)) + case d.bd >= mpNegFixNumMin && d.bd <= mpNegFixNumMax: + i = int64(int8(d.bd)) + default: + decErr("Unhandled single-byte unsigned integer value: %s: %x", msgBadDesc, d.bd) + } + } + // check overflow (logic adapted from std pkg reflect/value.go OverflowUint() + if bitsize > 
0 { + if trunc := (i << (64 - bitsize)) >> (64 - bitsize); i != trunc { + decErr("Overflow int value: %v", i) + } + } + d.bdRead = false + return +} + +// uint can be decoded from msgpack type: intXXX or uintXXX +func (d *msgpackDecDriver) decodeUint(bitsize uint8) (ui uint64) { + switch d.bd { + case mpUint8: + ui = uint64(d.r.readn1()) + case mpUint16: + ui = uint64(d.r.readUint16()) + case mpUint32: + ui = uint64(d.r.readUint32()) + case mpUint64: + ui = d.r.readUint64() + case mpInt8: + if i := int64(int8(d.r.readn1())); i >= 0 { + ui = uint64(i) + } else { + decErr("Assigning negative signed value: %v, to unsigned type", i) + } + case mpInt16: + if i := int64(int16(d.r.readUint16())); i >= 0 { + ui = uint64(i) + } else { + decErr("Assigning negative signed value: %v, to unsigned type", i) + } + case mpInt32: + if i := int64(int32(d.r.readUint32())); i >= 0 { + ui = uint64(i) + } else { + decErr("Assigning negative signed value: %v, to unsigned type", i) + } + case mpInt64: + if i := int64(d.r.readUint64()); i >= 0 { + ui = uint64(i) + } else { + decErr("Assigning negative signed value: %v, to unsigned type", i) + } + default: + switch { + case d.bd >= mpPosFixNumMin && d.bd <= mpPosFixNumMax: + ui = uint64(d.bd) + case d.bd >= mpNegFixNumMin && d.bd <= mpNegFixNumMax: + decErr("Assigning negative signed value: %v, to unsigned type", int(d.bd)) + default: + decErr("Unhandled single-byte unsigned integer value: %s: %x", msgBadDesc, d.bd) + } + } + // check overflow (logic adapted from std pkg reflect/value.go OverflowUint() + if bitsize > 0 { + if trunc := (ui << (64 - bitsize)) >> (64 - bitsize); ui != trunc { + decErr("Overflow uint value: %v", ui) + } + } + d.bdRead = false + return +} + +// float can either be decoded from msgpack type: float, double or intX +func (d *msgpackDecDriver) decodeFloat(chkOverflow32 bool) (f float64) { + switch d.bd { + case mpFloat: + f = float64(math.Float32frombits(d.r.readUint32())) + case mpDouble: + f = 
math.Float64frombits(d.r.readUint64()) + default: + f = float64(d.decodeInt(0)) + } + checkOverflowFloat32(f, chkOverflow32) + d.bdRead = false + return +} + +// bool can be decoded from bool, fixnum 0 or 1. +func (d *msgpackDecDriver) decodeBool() (b bool) { + switch d.bd { + case mpFalse, 0: + // b = false + case mpTrue, 1: + b = true + default: + decErr("Invalid single-byte value for bool: %s: %x", msgBadDesc, d.bd) + } + d.bdRead = false + return +} + +func (d *msgpackDecDriver) decodeString() (s string) { + clen := d.readContainerLen(msgpackContainerStr) + if clen > 0 { + s = string(d.r.readn(clen)) + } + d.bdRead = false + return +} + +// Callers must check if changed=true (to decide whether to replace the one they have) +func (d *msgpackDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) { + // bytes can be decoded from msgpackContainerStr or msgpackContainerBin + var clen int + switch d.bd { + case mpBin8, mpBin16, mpBin32: + clen = d.readContainerLen(msgpackContainerBin) + default: + clen = d.readContainerLen(msgpackContainerStr) + } + // if clen < 0 { + // changed = true + // panic("length cannot be zero. this cannot be nil.") + // } + if clen > 0 { + // if no contents in stream, don't update the passed byteslice + if len(bs) != clen { + // Return changed=true if length of passed slice diff from length of bytes in stream + if len(bs) > clen { + bs = bs[:clen] + } else { + bs = make([]byte, clen) + } + bsOut = bs + changed = true + } + d.r.readb(bs) + } + d.bdRead = false + return +} + +// Every top-level decode funcs (i.e. decodeValue, decode) must call this first. 
+func (d *msgpackDecDriver) initReadNext() { + if d.bdRead { + return + } + d.bd = d.r.readn1() + d.bdRead = true + d.bdType = valueTypeUnset +} + +func (d *msgpackDecDriver) currentEncodedType() valueType { + if d.bdType == valueTypeUnset { + bd := d.bd + switch bd { + case mpNil: + d.bdType = valueTypeNil + case mpFalse, mpTrue: + d.bdType = valueTypeBool + case mpFloat, mpDouble: + d.bdType = valueTypeFloat + case mpUint8, mpUint16, mpUint32, mpUint64: + d.bdType = valueTypeUint + case mpInt8, mpInt16, mpInt32, mpInt64: + d.bdType = valueTypeInt + default: + switch { + case bd >= mpPosFixNumMin && bd <= mpPosFixNumMax: + d.bdType = valueTypeInt + case bd >= mpNegFixNumMin && bd <= mpNegFixNumMax: + d.bdType = valueTypeInt + case bd == mpStr8, bd == mpStr16, bd == mpStr32, bd >= mpFixStrMin && bd <= mpFixStrMax: + if d.h.RawToString { + d.bdType = valueTypeString + } else { + d.bdType = valueTypeBytes + } + case bd == mpBin8, bd == mpBin16, bd == mpBin32: + d.bdType = valueTypeBytes + case bd == mpArray16, bd == mpArray32, bd >= mpFixArrayMin && bd <= mpFixArrayMax: + d.bdType = valueTypeArray + case bd == mpMap16, bd == mpMap32, bd >= mpFixMapMin && bd <= mpFixMapMax: + d.bdType = valueTypeMap + case bd >= mpFixExt1 && bd <= mpFixExt16, bd >= mpExt8 && bd <= mpExt32: + d.bdType = valueTypeExt + default: + decErr("currentEncodedType: Undeciphered descriptor: %s: hex: %x, dec: %d", msgBadDesc, bd, bd) + } + } + } + return d.bdType +} + +func (d *msgpackDecDriver) tryDecodeAsNil() bool { + if d.bd == mpNil { + d.bdRead = false + return true + } + return false +} + +func (d *msgpackDecDriver) readContainerLen(ct msgpackContainerType) (clen int) { + bd := d.bd + switch { + case bd == mpNil: + clen = -1 // to represent nil + case bd == ct.b8: + clen = int(d.r.readn1()) + case bd == ct.b16: + clen = int(d.r.readUint16()) + case bd == ct.b32: + clen = int(d.r.readUint32()) + case (ct.bFixMin & bd) == ct.bFixMin: + clen = int(ct.bFixMin ^ bd) + default: + 
decErr("readContainerLen: %s: hex: %x, dec: %d", msgBadDesc, bd, bd) + } + d.bdRead = false + return +} + +func (d *msgpackDecDriver) readMapLen() int { + return d.readContainerLen(msgpackContainerMap) +} + +func (d *msgpackDecDriver) readArrayLen() int { + return d.readContainerLen(msgpackContainerList) +} + +func (d *msgpackDecDriver) readExtLen() (clen int) { + switch d.bd { + case mpNil: + clen = -1 // to represent nil + case mpFixExt1: + clen = 1 + case mpFixExt2: + clen = 2 + case mpFixExt4: + clen = 4 + case mpFixExt8: + clen = 8 + case mpFixExt16: + clen = 16 + case mpExt8: + clen = int(d.r.readn1()) + case mpExt16: + clen = int(d.r.readUint16()) + case mpExt32: + clen = int(d.r.readUint32()) + default: + decErr("decoding ext bytes: found unexpected byte: %x", d.bd) + } + return +} + +func (d *msgpackDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) { + xbd := d.bd + switch { + case xbd == mpBin8, xbd == mpBin16, xbd == mpBin32: + xbs, _ = d.decodeBytes(nil) + case xbd == mpStr8, xbd == mpStr16, xbd == mpStr32, + xbd >= mpFixStrMin && xbd <= mpFixStrMax: + xbs = []byte(d.decodeString()) + default: + clen := d.readExtLen() + xtag = d.r.readn1() + if verifyTag && xtag != tag { + decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag) + } + xbs = d.r.readn(clen) + } + d.bdRead = false + return +} + +//-------------------------------------------------- + +//MsgpackHandle is a Handle for the Msgpack Schema-Free Encoding Format. +type MsgpackHandle struct { + BasicHandle + + // RawToString controls how raw bytes are decoded into a nil interface{}. + RawToString bool + // WriteExt flag supports encoding configured extensions with extension tags. + // It also controls whether other elements of the new spec are encoded (ie Str8). + // + // With WriteExt=false, configured extensions are serialized as raw bytes + // and Str8 is not encoded. 
+ // + // A stream can still be decoded into a typed value, provided an appropriate value + // is provided, but the type cannot be inferred from the stream. If no appropriate + // type is provided (e.g. decoding into a nil interface{}), you get back + // a []byte or string based on the setting of RawToString. + WriteExt bool +} + +func (h *MsgpackHandle) newEncDriver(w encWriter) encDriver { + return &msgpackEncDriver{w: w, h: h} +} + +func (h *MsgpackHandle) newDecDriver(r decReader) decDriver { + return &msgpackDecDriver{r: r, h: h} +} + +func (h *MsgpackHandle) writeExt() bool { + return h.WriteExt +} + +func (h *MsgpackHandle) getBasicHandle() *BasicHandle { + return &h.BasicHandle +} + +//-------------------------------------------------- + +type msgpackSpecRpcCodec struct { + rpcCodec +} + +// /////////////// Spec RPC Codec /////////////////// +func (c *msgpackSpecRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error { + // WriteRequest can write to both a Go service, and other services that do + // not abide by the 1 argument rule of a Go service. 
+ // We discriminate based on if the body is a MsgpackSpecRpcMultiArgs + var bodyArr []interface{} + if m, ok := body.(MsgpackSpecRpcMultiArgs); ok { + bodyArr = ([]interface{})(m) + } else { + bodyArr = []interface{}{body} + } + r2 := []interface{}{0, uint32(r.Seq), r.ServiceMethod, bodyArr} + return c.write(r2, nil, false, true) +} + +func (c *msgpackSpecRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error { + var moe interface{} + if r.Error != "" { + moe = r.Error + } + if moe != nil && body != nil { + body = nil + } + r2 := []interface{}{1, uint32(r.Seq), moe, body} + return c.write(r2, nil, false, true) +} + +func (c *msgpackSpecRpcCodec) ReadResponseHeader(r *rpc.Response) error { + return c.parseCustomHeader(1, &r.Seq, &r.Error) +} + +func (c *msgpackSpecRpcCodec) ReadRequestHeader(r *rpc.Request) error { + return c.parseCustomHeader(0, &r.Seq, &r.ServiceMethod) +} + +func (c *msgpackSpecRpcCodec) ReadRequestBody(body interface{}) error { + if body == nil { // read and discard + return c.read(nil) + } + bodyArr := []interface{}{body} + return c.read(&bodyArr) +} + +func (c *msgpackSpecRpcCodec) parseCustomHeader(expectTypeByte byte, msgid *uint64, methodOrError *string) (err error) { + + if c.cls { + return io.EOF + } + + // We read the response header by hand + // so that the body can be decoded on its own from the stream at a later time. + + const fia byte = 0x94 //four item array descriptor value + // Not sure why the panic of EOF is swallowed above. + // if bs1 := c.dec.r.readn1(); bs1 != fia { + // err = fmt.Errorf("Unexpected value for array descriptor: Expecting %v. Received %v", fia, bs1) + // return + // } + var b byte + b, err = c.br.ReadByte() + if err != nil { + return + } + if b != fia { + err = fmt.Errorf("Unexpected value for array descriptor: Expecting %v. Received %v", fia, b) + return + } + + if err = c.read(&b); err != nil { + return + } + if b != expectTypeByte { + err = fmt.Errorf("Unexpected byte descriptor in header. 
Expecting %v. Received %v", expectTypeByte, b) + return + } + if err = c.read(msgid); err != nil { + return + } + if err = c.read(methodOrError); err != nil { + return + } + return +} + +//-------------------------------------------------- + +// msgpackSpecRpc is the implementation of Rpc that uses custom communication protocol +// as defined in the msgpack spec at https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md +type msgpackSpecRpc struct{} + +// MsgpackSpecRpc implements Rpc using the communication protocol defined in +// the msgpack spec at https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md . +// Its methods (ServerCodec and ClientCodec) return values that implement RpcCodecBuffered. +var MsgpackSpecRpc msgpackSpecRpc + +func (x msgpackSpecRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec { + return &msgpackSpecRpcCodec{newRPCCodec(conn, h)} +} + +func (x msgpackSpecRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec { + return &msgpackSpecRpcCodec{newRPCCodec(conn, h)} +} + +var _ decDriver = (*msgpackDecDriver)(nil) +var _ encDriver = (*msgpackEncDriver)(nil) diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/rpc.go b/vendor/github.com/hashicorp/go-msgpack/codec/rpc.go new file mode 100644 index 00000000000..d014dbdcc7d --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/rpc.go @@ -0,0 +1,152 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +package codec + +import ( + "bufio" + "io" + "net/rpc" + "sync" +) + +// Rpc provides a rpc Server or Client Codec for rpc communication. +type Rpc interface { + ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec + ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec +} + +// RpcCodecBuffered allows access to the underlying bufio.Reader/Writer +// used by the rpc connection. 
It accomodates use-cases where the connection +// should be used by rpc and non-rpc functions, e.g. streaming a file after +// sending an rpc response. +type RpcCodecBuffered interface { + BufferedReader() *bufio.Reader + BufferedWriter() *bufio.Writer +} + +// ------------------------------------- + +// rpcCodec defines the struct members and common methods. +type rpcCodec struct { + rwc io.ReadWriteCloser + dec *Decoder + enc *Encoder + bw *bufio.Writer + br *bufio.Reader + mu sync.Mutex + cls bool +} + +func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec { + bw := bufio.NewWriter(conn) + br := bufio.NewReader(conn) + return rpcCodec{ + rwc: conn, + bw: bw, + br: br, + enc: NewEncoder(bw, h), + dec: NewDecoder(br, h), + } +} + +func (c *rpcCodec) BufferedReader() *bufio.Reader { + return c.br +} + +func (c *rpcCodec) BufferedWriter() *bufio.Writer { + return c.bw +} + +func (c *rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err error) { + if c.cls { + return io.EOF + } + if err = c.enc.Encode(obj1); err != nil { + return + } + if writeObj2 { + if err = c.enc.Encode(obj2); err != nil { + return + } + } + if doFlush && c.bw != nil { + return c.bw.Flush() + } + return +} + +func (c *rpcCodec) read(obj interface{}) (err error) { + if c.cls { + return io.EOF + } + //If nil is passed in, we should still attempt to read content to nowhere. 
+ if obj == nil { + var obj2 interface{} + return c.dec.Decode(&obj2) + } + return c.dec.Decode(obj) +} + +func (c *rpcCodec) Close() error { + if c.cls { + return io.EOF + } + c.cls = true + return c.rwc.Close() +} + +func (c *rpcCodec) ReadResponseBody(body interface{}) error { + return c.read(body) +} + +// ------------------------------------- + +type goRpcCodec struct { + rpcCodec +} + +func (c *goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error { + // Must protect for concurrent access as per API + c.mu.Lock() + defer c.mu.Unlock() + return c.write(r, body, true, true) +} + +func (c *goRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + return c.write(r, body, true, true) +} + +func (c *goRpcCodec) ReadResponseHeader(r *rpc.Response) error { + return c.read(r) +} + +func (c *goRpcCodec) ReadRequestHeader(r *rpc.Request) error { + return c.read(r) +} + +func (c *goRpcCodec) ReadRequestBody(body interface{}) error { + return c.read(body) +} + +// ------------------------------------- + +// goRpc is the implementation of Rpc that uses the communication protocol +// as defined in net/rpc package. +type goRpc struct{} + +// GoRpc implements Rpc using the communication protocol defined in net/rpc package. +// Its methods (ServerCodec and ClientCodec) return values that implement RpcCodecBuffered. 
+var GoRpc goRpc + +func (x goRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec { + return &goRpcCodec{newRPCCodec(conn, h)} +} + +func (x goRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec { + return &goRpcCodec{newRPCCodec(conn, h)} +} + +var _ RpcCodecBuffered = (*rpcCodec)(nil) // ensure *rpcCodec implements RpcCodecBuffered diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/simple.go b/vendor/github.com/hashicorp/go-msgpack/codec/simple.go new file mode 100644 index 00000000000..9e4d148a2a1 --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/simple.go @@ -0,0 +1,461 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. + +package codec + +import "math" + +const ( + _ uint8 = iota + simpleVdNil = 1 + simpleVdFalse = 2 + simpleVdTrue = 3 + simpleVdFloat32 = 4 + simpleVdFloat64 = 5 + + // each lasts for 4 (ie n, n+1, n+2, n+3) + simpleVdPosInt = 8 + simpleVdNegInt = 12 + + // containers: each lasts for 4 (ie n, n+1, n+2, ... 
n+7) + simpleVdString = 216 + simpleVdByteArray = 224 + simpleVdArray = 232 + simpleVdMap = 240 + simpleVdExt = 248 +) + +type simpleEncDriver struct { + h *SimpleHandle + w encWriter + //b [8]byte +} + +func (e *simpleEncDriver) isBuiltinType(rt uintptr) bool { + return false +} + +func (e *simpleEncDriver) encodeBuiltin(rt uintptr, v interface{}) { +} + +func (e *simpleEncDriver) encodeNil() { + e.w.writen1(simpleVdNil) +} + +func (e *simpleEncDriver) encodeBool(b bool) { + if b { + e.w.writen1(simpleVdTrue) + } else { + e.w.writen1(simpleVdFalse) + } +} + +func (e *simpleEncDriver) encodeFloat32(f float32) { + e.w.writen1(simpleVdFloat32) + e.w.writeUint32(math.Float32bits(f)) +} + +func (e *simpleEncDriver) encodeFloat64(f float64) { + e.w.writen1(simpleVdFloat64) + e.w.writeUint64(math.Float64bits(f)) +} + +func (e *simpleEncDriver) encodeInt(v int64) { + if v < 0 { + e.encUint(uint64(-v), simpleVdNegInt) + } else { + e.encUint(uint64(v), simpleVdPosInt) + } +} + +func (e *simpleEncDriver) encodeUint(v uint64) { + e.encUint(v, simpleVdPosInt) +} + +func (e *simpleEncDriver) encUint(v uint64, bd uint8) { + switch { + case v <= math.MaxUint8: + e.w.writen2(bd, uint8(v)) + case v <= math.MaxUint16: + e.w.writen1(bd + 1) + e.w.writeUint16(uint16(v)) + case v <= math.MaxUint32: + e.w.writen1(bd + 2) + e.w.writeUint32(uint32(v)) + case v <= math.MaxUint64: + e.w.writen1(bd + 3) + e.w.writeUint64(v) + } +} + +func (e *simpleEncDriver) encLen(bd byte, length int) { + switch { + case length == 0: + e.w.writen1(bd) + case length <= math.MaxUint8: + e.w.writen1(bd + 1) + e.w.writen1(uint8(length)) + case length <= math.MaxUint16: + e.w.writen1(bd + 2) + e.w.writeUint16(uint16(length)) + case int64(length) <= math.MaxUint32: + e.w.writen1(bd + 3) + e.w.writeUint32(uint32(length)) + default: + e.w.writen1(bd + 4) + e.w.writeUint64(uint64(length)) + } +} + +func (e *simpleEncDriver) encodeExtPreamble(xtag byte, length int) { + e.encLen(simpleVdExt, length) + 
e.w.writen1(xtag) +} + +func (e *simpleEncDriver) encodeArrayPreamble(length int) { + e.encLen(simpleVdArray, length) +} + +func (e *simpleEncDriver) encodeMapPreamble(length int) { + e.encLen(simpleVdMap, length) +} + +func (e *simpleEncDriver) encodeString(c charEncoding, v string) { + e.encLen(simpleVdString, len(v)) + e.w.writestr(v) +} + +func (e *simpleEncDriver) encodeSymbol(v string) { + e.encodeString(c_UTF8, v) +} + +func (e *simpleEncDriver) encodeStringBytes(c charEncoding, v []byte) { + e.encLen(simpleVdByteArray, len(v)) + e.w.writeb(v) +} + +//------------------------------------ + +type simpleDecDriver struct { + h *SimpleHandle + r decReader + bdRead bool + bdType valueType + bd byte + //b [8]byte +} + +func (d *simpleDecDriver) initReadNext() { + if d.bdRead { + return + } + d.bd = d.r.readn1() + d.bdRead = true + d.bdType = valueTypeUnset +} + +func (d *simpleDecDriver) currentEncodedType() valueType { + if d.bdType == valueTypeUnset { + switch d.bd { + case simpleVdNil: + d.bdType = valueTypeNil + case simpleVdTrue, simpleVdFalse: + d.bdType = valueTypeBool + case simpleVdPosInt, simpleVdPosInt + 1, simpleVdPosInt + 2, simpleVdPosInt + 3: + d.bdType = valueTypeUint + case simpleVdNegInt, simpleVdNegInt + 1, simpleVdNegInt + 2, simpleVdNegInt + 3: + d.bdType = valueTypeInt + case simpleVdFloat32, simpleVdFloat64: + d.bdType = valueTypeFloat + case simpleVdString, simpleVdString + 1, simpleVdString + 2, simpleVdString + 3, simpleVdString + 4: + d.bdType = valueTypeString + case simpleVdByteArray, simpleVdByteArray + 1, simpleVdByteArray + 2, simpleVdByteArray + 3, simpleVdByteArray + 4: + d.bdType = valueTypeBytes + case simpleVdExt, simpleVdExt + 1, simpleVdExt + 2, simpleVdExt + 3, simpleVdExt + 4: + d.bdType = valueTypeExt + case simpleVdArray, simpleVdArray + 1, simpleVdArray + 2, simpleVdArray + 3, simpleVdArray + 4: + d.bdType = valueTypeArray + case simpleVdMap, simpleVdMap + 1, simpleVdMap + 2, simpleVdMap + 3, simpleVdMap + 4: + d.bdType 
= valueTypeMap + default: + decErr("currentEncodedType: Unrecognized d.vd: 0x%x", d.bd) + } + } + return d.bdType +} + +func (d *simpleDecDriver) tryDecodeAsNil() bool { + if d.bd == simpleVdNil { + d.bdRead = false + return true + } + return false +} + +func (d *simpleDecDriver) isBuiltinType(rt uintptr) bool { + return false +} + +func (d *simpleDecDriver) decodeBuiltin(rt uintptr, v interface{}) { +} + +func (d *simpleDecDriver) decIntAny() (ui uint64, i int64, neg bool) { + switch d.bd { + case simpleVdPosInt: + ui = uint64(d.r.readn1()) + i = int64(ui) + case simpleVdPosInt + 1: + ui = uint64(d.r.readUint16()) + i = int64(ui) + case simpleVdPosInt + 2: + ui = uint64(d.r.readUint32()) + i = int64(ui) + case simpleVdPosInt + 3: + ui = uint64(d.r.readUint64()) + i = int64(ui) + case simpleVdNegInt: + ui = uint64(d.r.readn1()) + i = -(int64(ui)) + neg = true + case simpleVdNegInt + 1: + ui = uint64(d.r.readUint16()) + i = -(int64(ui)) + neg = true + case simpleVdNegInt + 2: + ui = uint64(d.r.readUint32()) + i = -(int64(ui)) + neg = true + case simpleVdNegInt + 3: + ui = uint64(d.r.readUint64()) + i = -(int64(ui)) + neg = true + default: + decErr("decIntAny: Integer only valid from pos/neg integer1..8. Invalid descriptor: %v", d.bd) + } + // don't do this check, because callers may only want the unsigned value. 
+ // if ui > math.MaxInt64 { + // decErr("decIntAny: Integer out of range for signed int64: %v", ui) + // } + return +} + +func (d *simpleDecDriver) decodeInt(bitsize uint8) (i int64) { + _, i, _ = d.decIntAny() + checkOverflow(0, i, bitsize) + d.bdRead = false + return +} + +func (d *simpleDecDriver) decodeUint(bitsize uint8) (ui uint64) { + ui, i, neg := d.decIntAny() + if neg { + decErr("Assigning negative signed value: %v, to unsigned type", i) + } + checkOverflow(ui, 0, bitsize) + d.bdRead = false + return +} + +func (d *simpleDecDriver) decodeFloat(chkOverflow32 bool) (f float64) { + switch d.bd { + case simpleVdFloat32: + f = float64(math.Float32frombits(d.r.readUint32())) + case simpleVdFloat64: + f = math.Float64frombits(d.r.readUint64()) + default: + if d.bd >= simpleVdPosInt && d.bd <= simpleVdNegInt+3 { + _, i, _ := d.decIntAny() + f = float64(i) + } else { + decErr("Float only valid from float32/64: Invalid descriptor: %v", d.bd) + } + } + checkOverflowFloat32(f, chkOverflow32) + d.bdRead = false + return +} + +// bool can be decoded from bool only (single byte). +func (d *simpleDecDriver) decodeBool() (b bool) { + switch d.bd { + case simpleVdTrue: + b = true + case simpleVdFalse: + default: + decErr("Invalid single-byte value for bool: %s: %x", msgBadDesc, d.bd) + } + d.bdRead = false + return +} + +func (d *simpleDecDriver) readMapLen() (length int) { + d.bdRead = false + return d.decLen() +} + +func (d *simpleDecDriver) readArrayLen() (length int) { + d.bdRead = false + return d.decLen() +} + +func (d *simpleDecDriver) decLen() int { + switch d.bd % 8 { + case 0: + return 0 + case 1: + return int(d.r.readn1()) + case 2: + return int(d.r.readUint16()) + case 3: + ui := uint64(d.r.readUint32()) + checkOverflow(ui, 0, intBitsize) + return int(ui) + case 4: + ui := d.r.readUint64() + checkOverflow(ui, 0, intBitsize) + return int(ui) + } + decErr("decLen: Cannot read length: bd%8 must be in range 0..4. 
Got: %d", d.bd%8) + return -1 +} + +func (d *simpleDecDriver) decodeString() (s string) { + s = string(d.r.readn(d.decLen())) + d.bdRead = false + return +} + +func (d *simpleDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) { + if clen := d.decLen(); clen > 0 { + // if no contents in stream, don't update the passed byteslice + if len(bs) != clen { + if len(bs) > clen { + bs = bs[:clen] + } else { + bs = make([]byte, clen) + } + bsOut = bs + changed = true + } + d.r.readb(bs) + } + d.bdRead = false + return +} + +func (d *simpleDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) { + switch d.bd { + case simpleVdExt, simpleVdExt + 1, simpleVdExt + 2, simpleVdExt + 3, simpleVdExt + 4: + l := d.decLen() + xtag = d.r.readn1() + if verifyTag && xtag != tag { + decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag) + } + xbs = d.r.readn(l) + case simpleVdByteArray, simpleVdByteArray + 1, simpleVdByteArray + 2, simpleVdByteArray + 3, simpleVdByteArray + 4: + xbs, _ = d.decodeBytes(nil) + default: + decErr("Invalid d.vd for extensions (Expecting extensions or byte array). 
Got: 0x%x", d.bd) + } + d.bdRead = false + return +} + +func (d *simpleDecDriver) decodeNaked() (v interface{}, vt valueType, decodeFurther bool) { + d.initReadNext() + + switch d.bd { + case simpleVdNil: + vt = valueTypeNil + case simpleVdFalse: + vt = valueTypeBool + v = false + case simpleVdTrue: + vt = valueTypeBool + v = true + case simpleVdPosInt, simpleVdPosInt + 1, simpleVdPosInt + 2, simpleVdPosInt + 3: + vt = valueTypeUint + ui, _, _ := d.decIntAny() + v = ui + case simpleVdNegInt, simpleVdNegInt + 1, simpleVdNegInt + 2, simpleVdNegInt + 3: + vt = valueTypeInt + _, i, _ := d.decIntAny() + v = i + case simpleVdFloat32: + vt = valueTypeFloat + v = d.decodeFloat(true) + case simpleVdFloat64: + vt = valueTypeFloat + v = d.decodeFloat(false) + case simpleVdString, simpleVdString + 1, simpleVdString + 2, simpleVdString + 3, simpleVdString + 4: + vt = valueTypeString + v = d.decodeString() + case simpleVdByteArray, simpleVdByteArray + 1, simpleVdByteArray + 2, simpleVdByteArray + 3, simpleVdByteArray + 4: + vt = valueTypeBytes + v, _ = d.decodeBytes(nil) + case simpleVdExt, simpleVdExt + 1, simpleVdExt + 2, simpleVdExt + 3, simpleVdExt + 4: + vt = valueTypeExt + l := d.decLen() + var re RawExt + re.Tag = d.r.readn1() + re.Data = d.r.readn(l) + v = &re + vt = valueTypeExt + case simpleVdArray, simpleVdArray + 1, simpleVdArray + 2, simpleVdArray + 3, simpleVdArray + 4: + vt = valueTypeArray + decodeFurther = true + case simpleVdMap, simpleVdMap + 1, simpleVdMap + 2, simpleVdMap + 3, simpleVdMap + 4: + vt = valueTypeMap + decodeFurther = true + default: + decErr("decodeNaked: Unrecognized d.vd: 0x%x", d.bd) + } + + if !decodeFurther { + d.bdRead = false + } + return +} + +//------------------------------------ + +// SimpleHandle is a Handle for a very simple encoding format. +// +// simple is a simplistic codec similar to binc, but not as compact. 
+// - Encoding of a value is always preceeded by the descriptor byte (bd) +// - True, false, nil are encoded fully in 1 byte (the descriptor) +// - Integers (intXXX, uintXXX) are encoded in 1, 2, 4 or 8 bytes (plus a descriptor byte). +// There are positive (uintXXX and intXXX >= 0) and negative (intXXX < 0) integers. +// - Floats are encoded in 4 or 8 bytes (plus a descriptor byte) +// - Lenght of containers (strings, bytes, array, map, extensions) +// are encoded in 0, 1, 2, 4 or 8 bytes. +// Zero-length containers have no length encoded. +// For others, the number of bytes is given by pow(2, bd%3) +// - maps are encoded as [bd] [length] [[key][value]]... +// - arrays are encoded as [bd] [length] [value]... +// - extensions are encoded as [bd] [length] [tag] [byte]... +// - strings/bytearrays are encoded as [bd] [length] [byte]... +// +// The full spec will be published soon. +type SimpleHandle struct { + BasicHandle +} + +func (h *SimpleHandle) newEncDriver(w encWriter) encDriver { + return &simpleEncDriver{w: w, h: h} +} + +func (h *SimpleHandle) newDecDriver(r decReader) decDriver { + return &simpleDecDriver{r: r, h: h} +} + +func (_ *SimpleHandle) writeExt() bool { + return true +} + +func (h *SimpleHandle) getBasicHandle() *BasicHandle { + return &h.BasicHandle +} + +var _ decDriver = (*simpleDecDriver)(nil) +var _ encDriver = (*simpleEncDriver)(nil) diff --git a/vendor/github.com/hashicorp/go-msgpack/codec/time.go b/vendor/github.com/hashicorp/go-msgpack/codec/time.go new file mode 100644 index 00000000000..c86d65328d7 --- /dev/null +++ b/vendor/github.com/hashicorp/go-msgpack/codec/time.go @@ -0,0 +1,193 @@ +// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved. +// Use of this source code is governed by a BSD-style license found in the LICENSE file. 
+ +package codec + +import ( + "time" +) + +var ( + timeDigits = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'} +) + +// EncodeTime encodes a time.Time as a []byte, including +// information on the instant in time and UTC offset. +// +// Format Description +// +// A timestamp is composed of 3 components: +// +// - secs: signed integer representing seconds since unix epoch +// - nsces: unsigned integer representing fractional seconds as a +// nanosecond offset within secs, in the range 0 <= nsecs < 1e9 +// - tz: signed integer representing timezone offset in minutes east of UTC, +// and a dst (daylight savings time) flag +// +// When encoding a timestamp, the first byte is the descriptor, which +// defines which components are encoded and how many bytes are used to +// encode secs and nsecs components. *If secs/nsecs is 0 or tz is UTC, it +// is not encoded in the byte array explicitly*. +// +// Descriptor 8 bits are of the form `A B C DDD EE`: +// A: Is secs component encoded? 1 = true +// B: Is nsecs component encoded? 1 = true +// C: Is tz component encoded? 1 = true +// DDD: Number of extra bytes for secs (range 0-7). +// If A = 1, secs encoded in DDD+1 bytes. +// If A = 0, secs is not encoded, and is assumed to be 0. +// If A = 1, then we need at least 1 byte to encode secs. +// DDD says the number of extra bytes beyond that 1. +// E.g. if DDD=0, then secs is represented in 1 byte. +// if DDD=2, then secs is represented in 3 bytes. +// EE: Number of extra bytes for nsecs (range 0-3). +// If B = 1, nsecs encoded in EE+1 bytes (similar to secs/DDD above) +// +// Following the descriptor bytes, subsequent bytes are: +// +// secs component encoded in `DDD + 1` bytes (if A == 1) +// nsecs component encoded in `EE + 1` bytes (if B == 1) +// tz component encoded in 2 bytes (if C == 1) +// +// secs and nsecs components are integers encoded in a BigEndian +// 2-complement encoding format. +// +// tz component is encoded as 2 bytes (16 bits). 
Most significant bit 15 to +// Least significant bit 0 are described below: +// +// Timezone offset has a range of -12:00 to +14:00 (ie -720 to +840 minutes). +// Bit 15 = have\_dst: set to 1 if we set the dst flag. +// Bit 14 = dst\_on: set to 1 if dst is in effect at the time, or 0 if not. +// Bits 13..0 = timezone offset in minutes. It is a signed integer in Big Endian format. +// +func encodeTime(t time.Time) []byte { + //t := rv.Interface().(time.Time) + tsecs, tnsecs := t.Unix(), t.Nanosecond() + var ( + bd byte + btmp [8]byte + bs [16]byte + i int = 1 + ) + l := t.Location() + if l == time.UTC { + l = nil + } + if tsecs != 0 { + bd = bd | 0x80 + bigen.PutUint64(btmp[:], uint64(tsecs)) + f := pruneSignExt(btmp[:], tsecs >= 0) + bd = bd | (byte(7-f) << 2) + copy(bs[i:], btmp[f:]) + i = i + (8 - f) + } + if tnsecs != 0 { + bd = bd | 0x40 + bigen.PutUint32(btmp[:4], uint32(tnsecs)) + f := pruneSignExt(btmp[:4], true) + bd = bd | byte(3-f) + copy(bs[i:], btmp[f:4]) + i = i + (4 - f) + } + if l != nil { + bd = bd | 0x20 + // Note that Go Libs do not give access to dst flag. + _, zoneOffset := t.Zone() + //zoneName, zoneOffset := t.Zone() + zoneOffset /= 60 + z := uint16(zoneOffset) + bigen.PutUint16(btmp[:2], z) + // clear dst flags + bs[i] = btmp[0] & 0x3f + bs[i+1] = btmp[1] + i = i + 2 + } + bs[0] = bd + return bs[0:i] +} + +// DecodeTime decodes a []byte into a time.Time. 
+func decodeTime(bs []byte) (tt time.Time, err error) { + bd := bs[0] + var ( + tsec int64 + tnsec uint32 + tz uint16 + i byte = 1 + i2 byte + n byte + ) + if bd&(1<<7) != 0 { + var btmp [8]byte + n = ((bd >> 2) & 0x7) + 1 + i2 = i + n + copy(btmp[8-n:], bs[i:i2]) + //if first bit of bs[i] is set, then fill btmp[0..8-n] with 0xff (ie sign extend it) + if bs[i]&(1<<7) != 0 { + copy(btmp[0:8-n], bsAll0xff) + //for j,k := byte(0), 8-n; j < k; j++ { btmp[j] = 0xff } + } + i = i2 + tsec = int64(bigen.Uint64(btmp[:])) + } + if bd&(1<<6) != 0 { + var btmp [4]byte + n = (bd & 0x3) + 1 + i2 = i + n + copy(btmp[4-n:], bs[i:i2]) + i = i2 + tnsec = bigen.Uint32(btmp[:]) + } + if bd&(1<<5) == 0 { + tt = time.Unix(tsec, int64(tnsec)).UTC() + return + } + // In stdlib time.Parse, when a date is parsed without a zone name, it uses "" as zone name. + // However, we need name here, so it can be shown when time is printed. + // Zone name is in form: UTC-08:00. + // Note that Go Libs do not give access to dst flag, so we ignore dst bits + + i2 = i + 2 + tz = bigen.Uint16(bs[i:i2]) + i = i2 + // sign extend sign bit into top 2 MSB (which were dst bits): + if tz&(1<<13) == 0 { // positive + tz = tz & 0x3fff //clear 2 MSBs: dst bits + } else { // negative + tz = tz | 0xc000 //set 2 MSBs: dst bits + //tzname[3] = '-' (TODO: verify. this works here) + } + tzint := int16(tz) + if tzint == 0 { + tt = time.Unix(tsec, int64(tnsec)).UTC() + } else { + // For Go Time, do not use a descriptive timezone. + // It's unnecessary, and makes it harder to do a reflect.DeepEqual. + // The Offset already tells what the offset should be, if not on UTC and unknown zone name. 
+ // var zoneName = timeLocUTCName(tzint) + tt = time.Unix(tsec, int64(tnsec)).In(time.FixedZone("", int(tzint)*60)) + } + return +} + +func timeLocUTCName(tzint int16) string { + if tzint == 0 { + return "UTC" + } + var tzname = []byte("UTC+00:00") + //tzname := fmt.Sprintf("UTC%s%02d:%02d", tzsign, tz/60, tz%60) //perf issue using Sprintf. inline below. + //tzhr, tzmin := tz/60, tz%60 //faster if u convert to int first + var tzhr, tzmin int16 + if tzint < 0 { + tzname[3] = '-' // (TODO: verify. this works here) + tzhr, tzmin = -tzint/60, (-tzint)%60 + } else { + tzhr, tzmin = tzint/60, tzint%60 + } + tzname[4] = timeDigits[tzhr/10] + tzname[5] = timeDigits[tzhr%10] + tzname[7] = timeDigits[tzmin/10] + tzname[8] = timeDigits[tzmin%10] + return string(tzname) + //return time.FixedZone(string(tzname), int(tzint)*60) +} diff --git a/vendor/github.com/hashicorp/raft/LICENSE b/vendor/github.com/hashicorp/raft/LICENSE new file mode 100644 index 00000000000..c33dcc7c928 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/LICENSE @@ -0,0 +1,354 @@ +Mozilla Public License, version 2.0 + +1. Definitions + +1.1. “Contributor” + + means each individual or legal entity that creates, contributes to the + creation of, or owns Covered Software. + +1.2. “Contributor Version” + + means the combination of the Contributions of others (if any) used by a + Contributor and that particular Contributor’s Contribution. + +1.3. “Contribution” + + means Covered Software of a particular Contributor. + +1.4. “Covered Software” + + means Source Code Form to which the initial Contributor has attached the + notice in Exhibit A, the Executable Form of such Source Code Form, and + Modifications of such Source Code Form, in each case including portions + thereof. + +1.5. “Incompatible With Secondary Licenses” + means + + a. that the initial Contributor has attached the notice described in + Exhibit B to the Covered Software; or + + b. 
that the Covered Software was made available under the terms of version + 1.1 or earlier of the License, but not also under the terms of a + Secondary License. + +1.6. “Executable Form” + + means any form of the work other than Source Code Form. + +1.7. “Larger Work” + + means a work that combines Covered Software with other material, in a separate + file or files, that is not Covered Software. + +1.8. “License” + + means this document. + +1.9. “Licensable” + + means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently, any and all of the rights conveyed by + this License. + +1.10. “Modifications” + + means any of the following: + + a. any file in Source Code Form that results from an addition to, deletion + from, or modification of the contents of Covered Software; or + + b. any new file in Source Code Form that contains any Covered Software. + +1.11. “Patent Claims” of a Contributor + + means any patent claim(s), including without limitation, method, process, + and apparatus claims, in any patent Licensable by such Contributor that + would be infringed, but for the grant of the License, by the making, + using, selling, offering for sale, having made, import, or transfer of + either its Contributions or its Contributor Version. + +1.12. “Secondary License” + + means either the GNU General Public License, Version 2.0, the GNU Lesser + General Public License, Version 2.1, the GNU Affero General Public + License, Version 3.0, or any later versions of those licenses. + +1.13. “Source Code Form” + + means the form of the work preferred for making modifications. + +1.14. “You” (or “Your”) + + means an individual or a legal entity exercising rights under this + License. For legal entities, “You” includes any entity that controls, is + controlled by, or is under common control with You. 
For purposes of this + definition, “control” means (a) the power, direct or indirect, to cause + the direction or management of such entity, whether by contract or + otherwise, or (b) ownership of more than fifty percent (50%) of the + outstanding shares or beneficial ownership of such entity. + + +2. License Grants and Conditions + +2.1. Grants + + Each Contributor hereby grants You a world-wide, royalty-free, + non-exclusive license: + + a. under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or as + part of a Larger Work; and + + b. under Patent Claims of such Contributor to make, use, sell, offer for + sale, have made, import, and otherwise transfer either its Contributions + or its Contributor Version. + +2.2. Effective Date + + The licenses granted in Section 2.1 with respect to any Contribution become + effective for each Contribution on the date the Contributor first distributes + such Contribution. + +2.3. Limitations on Grant Scope + + The licenses granted in this Section 2 are the only rights granted under this + License. No additional rights or licenses will be implied from the distribution + or licensing of Covered Software under this License. Notwithstanding Section + 2.1(b) above, no patent license is granted by a Contributor: + + a. for any code that a Contributor has removed from Covered Software; or + + b. for infringements caused by: (i) Your and any other third party’s + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + + c. under Patent Claims infringed by Covered Software in the absence of its + Contributions. 
+ + This License does not grant any rights in the trademarks, service marks, or + logos of any Contributor (except as may be necessary to comply with the + notice requirements in Section 3.4). + +2.4. Subsequent Licenses + + No Contributor makes additional grants as a result of Your choice to + distribute the Covered Software under a subsequent version of this License + (see Section 10.2) or under the terms of a Secondary License (if permitted + under the terms of Section 3.3). + +2.5. Representation + + Each Contributor represents that the Contributor believes its Contributions + are its original creation(s) or it has sufficient rights to grant the + rights to its Contributions conveyed by this License. + +2.6. Fair Use + + This License is not intended to limit any rights You have under applicable + copyright doctrines of fair use, fair dealing, or other equivalents. + +2.7. Conditions + + Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in + Section 2.1. + + +3. Responsibilities + +3.1. Distribution of Source Form + + All distribution of Covered Software in Source Code Form, including any + Modifications that You create or to which You contribute, must be under the + terms of this License. You must inform recipients that the Source Code Form + of the Covered Software is governed by the terms of this License, and how + they can obtain a copy of this License. You may not attempt to alter or + restrict the recipients’ rights in the Source Code Form. + +3.2. Distribution of Executable Form + + If You distribute Covered Software in Executable Form then: + + a. such Covered Software must also be made available in Source Code Form, + as described in Section 3.1, and You must inform recipients of the + Executable Form how they can obtain a copy of such Source Code Form by + reasonable means in a timely manner, at a charge no more than the cost + of distribution to the recipient; and + + b. 
You may distribute such Executable Form under the terms of this License, + or sublicense it under different terms, provided that the license for + the Executable Form does not attempt to limit or alter the recipients’ + rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + + You may create and distribute a Larger Work under terms of Your choice, + provided that You also comply with the requirements of this License for the + Covered Software. If the Larger Work is a combination of Covered Software + with a work governed by one or more Secondary Licenses, and the Covered + Software is not Incompatible With Secondary Licenses, this License permits + You to additionally distribute such Covered Software under the terms of + such Secondary License(s), so that the recipient of the Larger Work may, at + their option, further distribute the Covered Software under the terms of + either this License or such Secondary License(s). + +3.4. Notices + + You may not remove or alter the substance of any license notices (including + copyright notices, patent notices, disclaimers of warranty, or limitations + of liability) contained within the Source Code Form of the Covered + Software, except that You may alter any license notices to the extent + required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + + You may choose to offer, and to charge a fee for, warranty, support, + indemnity or liability obligations to one or more recipients of Covered + Software. However, You may do so only on Your own behalf, and not on behalf + of any Contributor. You must make it absolutely clear that any such + warranty, support, indemnity, or liability obligation is offered by You + alone, and You hereby agree to indemnify every Contributor for any + liability incurred by such Contributor as a result of warranty, support, + indemnity or liability terms You offer. 
You may include additional + disclaimers of warranty and limitations of liability specific to any + jurisdiction. + +4. Inability to Comply Due to Statute or Regulation + + If it is impossible for You to comply with any of the terms of this License + with respect to some or all of the Covered Software due to statute, judicial + order, or regulation then You must: (a) comply with the terms of this License + to the maximum extent possible; and (b) describe the limitations and the code + they affect. Such description must be placed in a text file included with all + distributions of the Covered Software under this License. Except to the + extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to + understand it. + +5. Termination + +5.1. The rights granted under this License will terminate automatically if You + fail to comply with any of its terms. However, if You become compliant, + then the rights granted under this License from a particular Contributor + are reinstated (a) provisionally, unless and until such Contributor + explicitly and finally terminates Your grants, and (b) on an ongoing basis, + if such Contributor fails to notify You of the non-compliance by some + reasonable means prior to 60 days after You have come back into compliance. + Moreover, Your grants from a particular Contributor are reinstated on an + ongoing basis if such Contributor notifies You of the non-compliance by + some reasonable means, this is the first time You have received notice of + non-compliance with this License from such Contributor, and You become + compliant prior to 30 days after Your receipt of the notice. + +5.2. 
If You initiate litigation against any entity by asserting a patent + infringement claim (excluding declaratory judgment actions, counter-claims, + and cross-claims) alleging that a Contributor Version directly or + indirectly infringes any patent, then the rights granted to You by any and + all Contributors for the Covered Software under Section 2.1 of this License + shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user + license agreements (excluding distributors and resellers) which have been + validly granted by You or Your distributors under this License prior to + termination shall survive termination. + +6. Disclaimer of Warranty + + Covered Software is provided under this License on an “as is” basis, without + warranty of any kind, either expressed, implied, or statutory, including, + without limitation, warranties that the Covered Software is free of defects, + merchantable, fit for a particular purpose or non-infringing. The entire + risk as to the quality and performance of the Covered Software is with You. + Should any Covered Software prove defective in any respect, You (not any + Contributor) assume the cost of any necessary servicing, repair, or + correction. This disclaimer of warranty constitutes an essential part of this + License. No use of any Covered Software is authorized under this License + except under this disclaimer. + +7. 
Limitation of Liability + + Under no circumstances and under no legal theory, whether tort (including + negligence), contract, or otherwise, shall any Contributor, or anyone who + distributes Covered Software as permitted above, be liable to You for any + direct, indirect, special, incidental, or consequential damages of any + character including, without limitation, damages for lost profits, loss of + goodwill, work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses, even if such party shall have been + informed of the possibility of such damages. This limitation of liability + shall not apply to liability for death or personal injury resulting from such + party’s negligence to the extent applicable law prohibits such limitation. + Some jurisdictions do not allow the exclusion or limitation of incidental or + consequential damages, so this exclusion and limitation may not apply to You. + +8. Litigation + + Any litigation relating to this License may be brought only in the courts of + a jurisdiction where the defendant maintains its principal place of business + and such litigation shall be governed by laws of that jurisdiction, without + reference to its conflict-of-law provisions. Nothing in this Section shall + prevent a party’s ability to bring cross-claims or counter-claims. + +9. Miscellaneous + + This License represents the complete agreement concerning the subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. Any law or regulation which provides that the language of a + contract shall be construed against the drafter shall not be used to construe + this License against a Contributor. + + +10. Versions of the License + +10.1. New Versions + + Mozilla Foundation is the license steward. 
Except as provided in Section + 10.3, no one other than the license steward has the right to modify or + publish new versions of this License. Each version will be given a + distinguishing version number. + +10.2. Effect of New Versions + + You may distribute the Covered Software under the terms of the version of + the License under which You originally received the Covered Software, or + under the terms of any subsequent version published by the license + steward. + +10.3. Modified Versions + + If you create software not governed by this License, and you want to + create a new license for such software, you may create and use a modified + version of this License if you rename the license and remove any + references to the name of the license steward (except to note that such + modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses + If You choose to distribute Source Code Form that is Incompatible With + Secondary Licenses under the terms of this version of the License, the + notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice + + This Source Code Form is subject to the + terms of the Mozilla Public License, v. + 2.0. If a copy of the MPL was not + distributed with this file, You can + obtain one at + http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular file, then +You may include the notice in a location (such as a LICENSE file in a relevant +directory) where a recipient would be likely to look for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - “Incompatible With Secondary Licenses” Notice + + This Source Code Form is “Incompatible + With Secondary Licenses”, as defined by + the Mozilla Public License, v. 2.0. 
+ diff --git a/vendor/github.com/hashicorp/raft/api.go b/vendor/github.com/hashicorp/raft/api.go new file mode 100644 index 00000000000..73f057c9858 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/api.go @@ -0,0 +1,1008 @@ +package raft + +import ( + "errors" + "fmt" + "io" + "log" + "os" + "strconv" + "sync" + "time" + + "github.com/armon/go-metrics" +) + +var ( + // ErrLeader is returned when an operation can't be completed on a + // leader node. + ErrLeader = errors.New("node is the leader") + + // ErrNotLeader is returned when an operation can't be completed on a + // follower or candidate node. + ErrNotLeader = errors.New("node is not the leader") + + // ErrLeadershipLost is returned when a leader fails to commit a log entry + // because it's been deposed in the process. + ErrLeadershipLost = errors.New("leadership lost while committing log") + + // ErrAbortedByRestore is returned when a leader fails to commit a log + // entry because it's been superseded by a user snapshot restore. + ErrAbortedByRestore = errors.New("snapshot restored while committing log") + + // ErrRaftShutdown is returned when operations are requested against an + // inactive Raft. + ErrRaftShutdown = errors.New("raft is already shutdown") + + // ErrEnqueueTimeout is returned when a command fails due to a timeout. + ErrEnqueueTimeout = errors.New("timed out enqueuing operation") + + // ErrNothingNewToSnapshot is returned when trying to create a snapshot + // but there's nothing new commited to the FSM since we started. + ErrNothingNewToSnapshot = errors.New("nothing new to snapshot") + + // ErrUnsupportedProtocol is returned when an operation is attempted + // that's not supported by the current protocol version. + ErrUnsupportedProtocol = errors.New("operation not supported with current protocol version") + + // ErrCantBootstrap is returned when attempt is made to bootstrap a + // cluster that already has state present. 
+ ErrCantBootstrap = errors.New("bootstrap only works on new clusters") +) + +// Raft implements a Raft node. +type Raft struct { + raftState + + // protocolVersion is used to inter-operate with Raft servers running + // different versions of the library. See comments in config.go for more + // details. + protocolVersion ProtocolVersion + + // applyCh is used to async send logs to the main thread to + // be committed and applied to the FSM. + applyCh chan *logFuture + + // Configuration provided at Raft initialization + conf Config + + // FSM is the client state machine to apply commands to + fsm FSM + + // fsmMutateCh is used to send state-changing updates to the FSM. This + // receives pointers to commitTuple structures when applying logs or + // pointers to restoreFuture structures when restoring a snapshot. We + // need control over the order of these operations when doing user + // restores so that we finish applying any old log applies before we + // take a user snapshot on the leader, otherwise we might restore the + // snapshot and apply old logs to it that were in the pipe. + fsmMutateCh chan interface{} + + // fsmSnapshotCh is used to trigger a new snapshot being taken + fsmSnapshotCh chan *reqSnapshotFuture + + // lastContact is the last time we had contact from the + // leader node. This can be used to gauge staleness. + lastContact time.Time + lastContactLock sync.RWMutex + + // Leader is the current cluster leader + leader ServerAddress + leaderLock sync.RWMutex + + // leaderCh is used to notify of leadership changes + leaderCh chan bool + + // leaderState used only while state is leader + leaderState leaderState + + // Stores our local server ID, used to avoid sending RPCs to ourself + localID ServerID + + // Stores our local addr + localAddr ServerAddress + + // Used for our logging + logger *log.Logger + + // LogStore provides durable storage for logs + logs LogStore + + // Used to request the leader to make configuration changes. 
+ configurationChangeCh chan *configurationChangeFuture + + // Tracks the latest configuration and latest committed configuration from + // the log/snapshot. + configurations configurations + + // RPC chan comes from the transport layer + rpcCh <-chan RPC + + // Shutdown channel to exit, protected to prevent concurrent exits + shutdown bool + shutdownCh chan struct{} + shutdownLock sync.Mutex + + // snapshots is used to store and retrieve snapshots + snapshots SnapshotStore + + // userSnapshotCh is used for user-triggered snapshots + userSnapshotCh chan *userSnapshotFuture + + // userRestoreCh is used for user-triggered restores of external + // snapshots + userRestoreCh chan *userRestoreFuture + + // stable is a StableStore implementation for durable state + // It provides stable storage for many fields in raftState + stable StableStore + + // The transport layer we use + trans Transport + + // verifyCh is used to async send verify futures to the main thread + // to verify we are still the leader + verifyCh chan *verifyFuture + + // configurationsCh is used to get the configuration data safely from + // outside of the main thread. + configurationsCh chan *configurationsFuture + + // bootstrapCh is used to attempt an initial bootstrap from outside of + // the main thread. + bootstrapCh chan *bootstrapFuture + + // List of observers and the mutex that protects them. The observers list + // is indexed by an artificial ID which is used for deregistration. + observersLock sync.RWMutex + observers map[uint64]*Observer +} + +// BootstrapCluster initializes a server's storage with the given cluster +// configuration. This should only be called at the beginning of time for the +// cluster, and you absolutely must make sure that you call it with the same +// configuration on all the Voter servers. There is no need to bootstrap +// Nonvoter and Staging servers. 
+// +// One sane approach is to boostrap a single server with a configuration +// listing just itself as a Voter, then invoke AddVoter() on it to add other +// servers to the cluster. +func BootstrapCluster(conf *Config, logs LogStore, stable StableStore, + snaps SnapshotStore, trans Transport, configuration Configuration) error { + // Validate the Raft server config. + if err := ValidateConfig(conf); err != nil { + return err + } + + // Sanity check the Raft peer configuration. + if err := checkConfiguration(configuration); err != nil { + return err + } + + // Make sure the cluster is in a clean state. + hasState, err := HasExistingState(logs, stable, snaps) + if err != nil { + return fmt.Errorf("failed to check for existing state: %v", err) + } + if hasState { + return ErrCantBootstrap + } + + // Set current term to 1. + if err := stable.SetUint64(keyCurrentTerm, 1); err != nil { + return fmt.Errorf("failed to save current term: %v", err) + } + + // Append configuration entry to log. + entry := &Log{ + Index: 1, + Term: 1, + } + if conf.ProtocolVersion < 3 { + entry.Type = LogRemovePeerDeprecated + entry.Data = encodePeers(configuration, trans) + } else { + entry.Type = LogConfiguration + entry.Data = encodeConfiguration(configuration) + } + if err := logs.StoreLog(entry); err != nil { + return fmt.Errorf("failed to append configuration entry to log: %v", err) + } + + return nil +} + +// RecoverCluster is used to manually force a new configuration in order to +// recover from a loss of quorum where the current configuration cannot be +// restored, such as when several servers die at the same time. This works by +// reading all the current state for this server, creating a snapshot with the +// supplied configuration, and then truncating the Raft log. This is the only +// safe way to force a given configuration without actually altering the log to +// insert any new entries, which could cause conflicts with other servers with +// different state. +// +// WARNING! 
This operation implicitly commits all entries in the Raft log, so +// in general this is an extremely unsafe operation. If you've lost your other +// servers and are performing a manual recovery, then you've also lost the +// commit information, so this is likely the best you can do, but you should be +// aware that calling this can cause Raft log entries that were in the process +// of being replicated but not yet be committed to be committed. +// +// Note the FSM passed here is used for the snapshot operations and will be +// left in a state that should not be used by the application. Be sure to +// discard this FSM and any associated state and provide a fresh one when +// calling NewRaft later. +// +// A typical way to recover the cluster is to shut down all servers and then +// run RecoverCluster on every server using an identical configuration. When +// the cluster is then restarted, and election should occur and then Raft will +// resume normal operation. If it's desired to make a particular server the +// leader, this can be used to inject a new configuration with that server as +// the sole voter, and then join up other new clean-state peer servers using +// the usual APIs in order to bring the cluster back into a known state. +func RecoverCluster(conf *Config, fsm FSM, logs LogStore, stable StableStore, + snaps SnapshotStore, trans Transport, configuration Configuration) error { + // Validate the Raft server config. + if err := ValidateConfig(conf); err != nil { + return err + } + + // Sanity check the Raft peer configuration. + if err := checkConfiguration(configuration); err != nil { + return err + } + + // Refuse to recover if there's no existing state. This would be safe to + // do, but it is likely an indication of an operator error where they + // expect data to be there and it's not. By refusing, we force them + // to show intent to start a cluster fresh by explicitly doing a + // bootstrap, rather than quietly fire up a fresh cluster here. 
+ hasState, err := HasExistingState(logs, stable, snaps) + if err != nil { + return fmt.Errorf("failed to check for existing state: %v", err) + } + if !hasState { + return fmt.Errorf("refused to recover cluster with no initial state, this is probably an operator error") + } + + // Attempt to restore any snapshots we find, newest to oldest. + var snapshotIndex uint64 + var snapshotTerm uint64 + snapshots, err := snaps.List() + if err != nil { + return fmt.Errorf("failed to list snapshots: %v", err) + } + for _, snapshot := range snapshots { + _, source, err := snaps.Open(snapshot.ID) + if err != nil { + // Skip this one and try the next. We will detect if we + // couldn't open any snapshots. + continue + } + defer source.Close() + + if err := fsm.Restore(source); err != nil { + // Same here, skip and try the next one. + continue + } + + snapshotIndex = snapshot.Index + snapshotTerm = snapshot.Term + break + } + if len(snapshots) > 0 && (snapshotIndex == 0 || snapshotTerm == 0) { + return fmt.Errorf("failed to restore any of the available snapshots") + } + + // The snapshot information is the best known end point for the data + // until we play back the Raft log entries. + lastIndex := snapshotIndex + lastTerm := snapshotTerm + + // Apply any Raft log entries past the snapshot. + lastLogIndex, err := logs.LastIndex() + if err != nil { + return fmt.Errorf("failed to find last log: %v", err) + } + for index := snapshotIndex + 1; index <= lastLogIndex; index++ { + var entry Log + if err := logs.GetLog(index, &entry); err != nil { + return fmt.Errorf("failed to get log at index %d: %v", index, err) + } + if entry.Type == LogCommand { + _ = fsm.Apply(&entry) + } + lastIndex = entry.Index + lastTerm = entry.Term + } + + // Create a new snapshot, placing the configuration in as if it was + // committed at index 1. 
+ snapshot, err := fsm.Snapshot() + if err != nil { + return fmt.Errorf("failed to snapshot FSM: %v", err) + } + version := getSnapshotVersion(conf.ProtocolVersion) + sink, err := snaps.Create(version, lastIndex, lastTerm, configuration, 1, trans) + if err != nil { + return fmt.Errorf("failed to create snapshot: %v", err) + } + if err := snapshot.Persist(sink); err != nil { + return fmt.Errorf("failed to persist snapshot: %v", err) + } + if err := sink.Close(); err != nil { + return fmt.Errorf("failed to finalize snapshot: %v", err) + } + + // Compact the log so that we don't get bad interference from any + // configuration change log entries that might be there. + firstLogIndex, err := logs.FirstIndex() + if err != nil { + return fmt.Errorf("failed to get first log index: %v", err) + } + if err := logs.DeleteRange(firstLogIndex, lastLogIndex); err != nil { + return fmt.Errorf("log compaction failed: %v", err) + } + + return nil +} + +// HasExistingState returns true if the server has any existing state (logs, +// knowledge of a current term, or any snapshots). +func HasExistingState(logs LogStore, stable StableStore, snaps SnapshotStore) (bool, error) { + // Make sure we don't have a current term. + currentTerm, err := stable.GetUint64(keyCurrentTerm) + if err == nil { + if currentTerm > 0 { + return true, nil + } + } else { + if err.Error() != "not found" { + return false, fmt.Errorf("failed to read current term: %v", err) + } + } + + // Make sure we have an empty log. + lastIndex, err := logs.LastIndex() + if err != nil { + return false, fmt.Errorf("failed to get last log index: %v", err) + } + if lastIndex > 0 { + return true, nil + } + + // Make sure we have no snapshots + snapshots, err := snaps.List() + if err != nil { + return false, fmt.Errorf("failed to list snapshots: %v", err) + } + if len(snapshots) > 0 { + return true, nil + } + + return false, nil +} + +// NewRaft is used to construct a new Raft node. 
It takes a configuration, as well +// as implementations of various interfaces that are required. If we have any +// old state, such as snapshots, logs, peers, etc, all those will be restored +// when creating the Raft node. +func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps SnapshotStore, trans Transport) (*Raft, error) { + // Validate the configuration. + if err := ValidateConfig(conf); err != nil { + return nil, err + } + + // Ensure we have a LogOutput. + var logger *log.Logger + if conf.Logger != nil { + logger = conf.Logger + } else { + if conf.LogOutput == nil { + conf.LogOutput = os.Stderr + } + logger = log.New(conf.LogOutput, "", log.LstdFlags) + } + + // Try to restore the current term. + currentTerm, err := stable.GetUint64(keyCurrentTerm) + if err != nil && err.Error() != "not found" { + return nil, fmt.Errorf("failed to load current term: %v", err) + } + + // Read the index of the last log entry. + lastIndex, err := logs.LastIndex() + if err != nil { + return nil, fmt.Errorf("failed to find last log: %v", err) + } + + // Get the last log entry. + var lastLog Log + if lastIndex > 0 { + if err = logs.GetLog(lastIndex, &lastLog); err != nil { + return nil, fmt.Errorf("failed to get last log at index %d: %v", lastIndex, err) + } + } + + // Make sure we have a valid server address and ID. + protocolVersion := conf.ProtocolVersion + localAddr := ServerAddress(trans.LocalAddr()) + localID := conf.LocalID + + // TODO (slackpad) - When we deprecate protocol version 2, remove this + // along with the AddPeer() and RemovePeer() APIs. + if protocolVersion < 3 && string(localID) != string(localAddr) { + return nil, fmt.Errorf("when running with ProtocolVersion < 3, LocalID must be set to the network address") + } + + // Create Raft struct. 
+ r := &Raft{ + protocolVersion: protocolVersion, + applyCh: make(chan *logFuture), + conf: *conf, + fsm: fsm, + fsmMutateCh: make(chan interface{}, 128), + fsmSnapshotCh: make(chan *reqSnapshotFuture), + leaderCh: make(chan bool), + localID: localID, + localAddr: localAddr, + logger: logger, + logs: logs, + configurationChangeCh: make(chan *configurationChangeFuture), + configurations: configurations{}, + rpcCh: trans.Consumer(), + snapshots: snaps, + userSnapshotCh: make(chan *userSnapshotFuture), + userRestoreCh: make(chan *userRestoreFuture), + shutdownCh: make(chan struct{}), + stable: stable, + trans: trans, + verifyCh: make(chan *verifyFuture, 64), + configurationsCh: make(chan *configurationsFuture, 8), + bootstrapCh: make(chan *bootstrapFuture), + observers: make(map[uint64]*Observer), + } + + // Initialize as a follower. + r.setState(Follower) + + // Start as leader if specified. This should only be used + // for testing purposes. + if conf.StartAsLeader { + r.setState(Leader) + r.setLeader(r.localAddr) + } + + // Restore the current term and the last log. + r.setCurrentTerm(currentTerm) + r.setLastLog(lastLog.Index, lastLog.Term) + + // Attempt to restore a snapshot if there are any. + if err := r.restoreSnapshot(); err != nil { + return nil, err + } + + // Scan through the log for any configuration change entries. + snapshotIndex, _ := r.getLastSnapshot() + for index := snapshotIndex + 1; index <= lastLog.Index; index++ { + var entry Log + if err := r.logs.GetLog(index, &entry); err != nil { + r.logger.Printf("[ERR] raft: Failed to get log at %d: %v", index, err) + panic(err) + } + r.processConfigurationLogEntry(&entry) + } + + r.logger.Printf("[INFO] raft: Initial configuration (index=%d): %+v", + r.configurations.latestIndex, r.configurations.latest.Servers) + + // Setup a heartbeat fast-path to avoid head-of-line + // blocking where possible. It MUST be safe for this + // to be called concurrently with a blocking RPC. 
+ trans.SetHeartbeatHandler(r.processHeartbeat) + + // Start the background work. + r.goFunc(r.run) + r.goFunc(r.runFSM) + r.goFunc(r.runSnapshots) + return r, nil +} + +// restoreSnapshot attempts to restore the latest snapshots, and fails if none +// of them can be restored. This is called at initialization time, and is +// completely unsafe to call at any other time. +func (r *Raft) restoreSnapshot() error { + snapshots, err := r.snapshots.List() + if err != nil { + r.logger.Printf("[ERR] raft: Failed to list snapshots: %v", err) + return err + } + + // Try to load in order of newest to oldest + for _, snapshot := range snapshots { + _, source, err := r.snapshots.Open(snapshot.ID) + if err != nil { + r.logger.Printf("[ERR] raft: Failed to open snapshot %v: %v", snapshot.ID, err) + continue + } + defer source.Close() + + if err := r.fsm.Restore(source); err != nil { + r.logger.Printf("[ERR] raft: Failed to restore snapshot %v: %v", snapshot.ID, err) + continue + } + + // Log success + r.logger.Printf("[INFO] raft: Restored from snapshot %v", snapshot.ID) + + // Update the lastApplied so we don't replay old logs + r.setLastApplied(snapshot.Index) + + // Update the last stable snapshot info + r.setLastSnapshot(snapshot.Index, snapshot.Term) + + // Update the configuration + if snapshot.Version > 0 { + r.configurations.committed = snapshot.Configuration + r.configurations.committedIndex = snapshot.ConfigurationIndex + r.configurations.latest = snapshot.Configuration + r.configurations.latestIndex = snapshot.ConfigurationIndex + } else { + configuration := decodePeers(snapshot.Peers, r.trans) + r.configurations.committed = configuration + r.configurations.committedIndex = snapshot.Index + r.configurations.latest = configuration + r.configurations.latestIndex = snapshot.Index + } + + // Success! 
+ return nil + } + + // If we had snapshots and failed to load them, its an error + if len(snapshots) > 0 { + return fmt.Errorf("failed to load any existing snapshots") + } + return nil +} + +// BootstrapCluster is equivalent to non-member BootstrapCluster but can be +// called on an un-bootstrapped Raft instance after it has been created. This +// should only be called at the beginning of time for the cluster, and you +// absolutely must make sure that you call it with the same configuration on all +// the Voter servers. There is no need to bootstrap Nonvoter and Staging +// servers. +func (r *Raft) BootstrapCluster(configuration Configuration) Future { + bootstrapReq := &bootstrapFuture{} + bootstrapReq.init() + bootstrapReq.configuration = configuration + select { + case <-r.shutdownCh: + return errorFuture{ErrRaftShutdown} + case r.bootstrapCh <- bootstrapReq: + return bootstrapReq + } +} + +// Leader is used to return the current leader of the cluster. +// It may return empty string if there is no current leader +// or the leader is unknown. +func (r *Raft) Leader() ServerAddress { + r.leaderLock.RLock() + leader := r.leader + r.leaderLock.RUnlock() + return leader +} + +// Apply is used to apply a command to the FSM in a highly consistent +// manner. This returns a future that can be used to wait on the application. +// An optional timeout can be provided to limit the amount of time we wait +// for the command to be started. This must be run on the leader or it +// will fail. 
+func (r *Raft) Apply(cmd []byte, timeout time.Duration) ApplyFuture { + metrics.IncrCounter([]string{"raft", "apply"}, 1) + var timer <-chan time.Time + if timeout > 0 { + timer = time.After(timeout) + } + + // Create a log future, no index or term yet + logFuture := &logFuture{ + log: Log{ + Type: LogCommand, + Data: cmd, + }, + } + logFuture.init() + + select { + case <-timer: + return errorFuture{ErrEnqueueTimeout} + case <-r.shutdownCh: + return errorFuture{ErrRaftShutdown} + case r.applyCh <- logFuture: + return logFuture + } +} + +// Barrier is used to issue a command that blocks until all preceeding +// operations have been applied to the FSM. It can be used to ensure the +// FSM reflects all queued writes. An optional timeout can be provided to +// limit the amount of time we wait for the command to be started. This +// must be run on the leader or it will fail. +func (r *Raft) Barrier(timeout time.Duration) Future { + metrics.IncrCounter([]string{"raft", "barrier"}, 1) + var timer <-chan time.Time + if timeout > 0 { + timer = time.After(timeout) + } + + // Create a log future, no index or term yet + logFuture := &logFuture{ + log: Log{ + Type: LogBarrier, + }, + } + logFuture.init() + + select { + case <-timer: + return errorFuture{ErrEnqueueTimeout} + case <-r.shutdownCh: + return errorFuture{ErrRaftShutdown} + case r.applyCh <- logFuture: + return logFuture + } +} + +// VerifyLeader is used to ensure the current node is still +// the leader. This can be done to prevent stale reads when a +// new leader has potentially been elected. +func (r *Raft) VerifyLeader() Future { + metrics.IncrCounter([]string{"raft", "verify_leader"}, 1) + verifyFuture := &verifyFuture{} + verifyFuture.init() + select { + case <-r.shutdownCh: + return errorFuture{ErrRaftShutdown} + case r.verifyCh <- verifyFuture: + return verifyFuture + } +} + +// GetConfiguration returns the latest configuration and its associated index +// currently in use. This may not yet be committed. 
This must not be called on +// the main thread (which can access the information directly). +func (r *Raft) GetConfiguration() ConfigurationFuture { + configReq := &configurationsFuture{} + configReq.init() + select { + case <-r.shutdownCh: + configReq.respond(ErrRaftShutdown) + return configReq + case r.configurationsCh <- configReq: + return configReq + } +} + +// AddPeer (deprecated) is used to add a new peer into the cluster. This must be +// run on the leader or it will fail. Use AddVoter/AddNonvoter instead. +func (r *Raft) AddPeer(peer ServerAddress) Future { + if r.protocolVersion > 2 { + return errorFuture{ErrUnsupportedProtocol} + } + + return r.requestConfigChange(configurationChangeRequest{ + command: AddStaging, + serverID: ServerID(peer), + serverAddress: peer, + prevIndex: 0, + }, 0) +} + +// RemovePeer (deprecated) is used to remove a peer from the cluster. If the +// current leader is being removed, it will cause a new election +// to occur. This must be run on the leader or it will fail. +// Use RemoveServer instead. +func (r *Raft) RemovePeer(peer ServerAddress) Future { + if r.protocolVersion > 2 { + return errorFuture{ErrUnsupportedProtocol} + } + + return r.requestConfigChange(configurationChangeRequest{ + command: RemoveServer, + serverID: ServerID(peer), + prevIndex: 0, + }, 0) +} + +// AddVoter will add the given server to the cluster as a staging server. If the +// server is already in the cluster as a voter, this does nothing. This must be +// run on the leader or it will fail. The leader will promote the staging server +// to a voter once that server is ready. If nonzero, prevIndex is the index of +// the only configuration upon which this change may be applied; if another +// configuration entry has been added in the meantime, this request will fail. +// If nonzero, timeout is how long this server should wait before the +// configuration change log entry is appended. 
+func (r *Raft) AddVoter(id ServerID, address ServerAddress, prevIndex uint64, timeout time.Duration) IndexFuture { + if r.protocolVersion < 2 { + return errorFuture{ErrUnsupportedProtocol} + } + + return r.requestConfigChange(configurationChangeRequest{ + command: AddStaging, + serverID: id, + serverAddress: address, + prevIndex: prevIndex, + }, timeout) +} + +// AddNonvoter will add the given server to the cluster but won't assign it a +// vote. The server will receive log entries, but it won't participate in +// elections or log entry commitment. If the server is already in the cluster as +// a staging server or voter, this does nothing. This must be run on the leader +// or it will fail. For prevIndex and timeout, see AddVoter. +func (r *Raft) AddNonvoter(id ServerID, address ServerAddress, prevIndex uint64, timeout time.Duration) IndexFuture { + if r.protocolVersion < 3 { + return errorFuture{ErrUnsupportedProtocol} + } + + return r.requestConfigChange(configurationChangeRequest{ + command: AddNonvoter, + serverID: id, + serverAddress: address, + prevIndex: prevIndex, + }, timeout) +} + +// RemoveServer will remove the given server from the cluster. If the current +// leader is being removed, it will cause a new election to occur. This must be +// run on the leader or it will fail. For prevIndex and timeout, see AddVoter. +func (r *Raft) RemoveServer(id ServerID, prevIndex uint64, timeout time.Duration) IndexFuture { + if r.protocolVersion < 2 { + return errorFuture{ErrUnsupportedProtocol} + } + + return r.requestConfigChange(configurationChangeRequest{ + command: RemoveServer, + serverID: id, + prevIndex: prevIndex, + }, timeout) +} + +// DemoteVoter will take away a server's vote, if it has one. If present, the +// server will continue to receive log entries, but it won't participate in +// elections or log entry commitment. If the server is not in the cluster, this +// does nothing. This must be run on the leader or it will fail. 
For prevIndex +// and timeout, see AddVoter. +func (r *Raft) DemoteVoter(id ServerID, prevIndex uint64, timeout time.Duration) IndexFuture { + if r.protocolVersion < 3 { + return errorFuture{ErrUnsupportedProtocol} + } + + return r.requestConfigChange(configurationChangeRequest{ + command: DemoteVoter, + serverID: id, + prevIndex: prevIndex, + }, timeout) +} + +// Shutdown is used to stop the Raft background routines. +// This is not a graceful operation. Provides a future that +// can be used to block until all background routines have exited. +func (r *Raft) Shutdown() Future { + r.shutdownLock.Lock() + defer r.shutdownLock.Unlock() + + if !r.shutdown { + close(r.shutdownCh) + r.shutdown = true + r.setState(Shutdown) + return &shutdownFuture{r} + } + + // avoid closing transport twice + return &shutdownFuture{nil} +} + +// Snapshot is used to manually force Raft to take a snapshot. Returns a future +// that can be used to block until complete, and that contains a function that +// can be used to open the snapshot. +func (r *Raft) Snapshot() SnapshotFuture { + future := &userSnapshotFuture{} + future.init() + select { + case r.userSnapshotCh <- future: + return future + case <-r.shutdownCh: + future.respond(ErrRaftShutdown) + return future + } +} + +// Restore is used to manually force Raft to consume an external snapshot, such +// as if restoring from a backup. We will use the current Raft configuration, +// not the one from the snapshot, so that we can restore into a new cluster. We +// will also use the higher of the index of the snapshot, or the current index, +// and then add 1 to that, so we force a new state with a hole in the Raft log, +// so that the snapshot will be sent to followers and used for any new joiners. +// This can only be run on the leader, and blocks until the restore is complete +// or an error occurs. +// +// WARNING! 
This operation has the leader take on the state of the snapshot and +// then sets itself up so that it replicates that to its followers though the +// install snapshot process. This involves a potentially dangerous period where +// the leader commits ahead of its followers, so should only be used for disaster +// recovery into a fresh cluster, and should not be used in normal operations. +func (r *Raft) Restore(meta *SnapshotMeta, reader io.Reader, timeout time.Duration) error { + metrics.IncrCounter([]string{"raft", "restore"}, 1) + var timer <-chan time.Time + if timeout > 0 { + timer = time.After(timeout) + } + + // Perform the restore. + restore := &userRestoreFuture{ + meta: meta, + reader: reader, + } + restore.init() + select { + case <-timer: + return ErrEnqueueTimeout + case <-r.shutdownCh: + return ErrRaftShutdown + case r.userRestoreCh <- restore: + // If the restore is ingested then wait for it to complete. + if err := restore.Error(); err != nil { + return err + } + } + + // Apply a no-op log entry. Waiting for this allows us to wait until the + // followers have gotten the restore and replicated at least this new + // entry, which shows that we've also faulted and installed the + // snapshot with the contents of the restore. + noop := &logFuture{ + log: Log{ + Type: LogNoop, + }, + } + noop.init() + select { + case <-timer: + return ErrEnqueueTimeout + case <-r.shutdownCh: + return ErrRaftShutdown + case r.applyCh <- noop: + return noop.Error() + } +} + +// State is used to return the current raft state. +func (r *Raft) State() RaftState { + return r.getState() +} + +// LeaderCh is used to get a channel which delivers signals on +// acquiring or losing leadership. It sends true if we become +// the leader, and false if we lose it. The channel is not buffered, +// and does not block on writes. +func (r *Raft) LeaderCh() <-chan bool { + return r.leaderCh +} + +// String returns a string representation of this Raft node. 
+func (r *Raft) String() string { + return fmt.Sprintf("Node at %s [%v]", r.localAddr, r.getState()) +} + +// LastContact returns the time of last contact by a leader. +// This only makes sense if we are currently a follower. +func (r *Raft) LastContact() time.Time { + r.lastContactLock.RLock() + last := r.lastContact + r.lastContactLock.RUnlock() + return last +} + +// Stats is used to return a map of various internal stats. This +// should only be used for informative purposes or debugging. +// +// Keys are: "state", "term", "last_log_index", "last_log_term", +// "commit_index", "applied_index", "fsm_pending", +// "last_snapshot_index", "last_snapshot_term", +// "latest_configuration", "last_contact", and "num_peers". +// +// The value of "state" is a numerical value representing a +// RaftState const. +// +// The value of "latest_configuration" is a string which contains +// the id of each server, its suffrage status, and its address. +// +// The value of "last_contact" is either "never" if there +// has been no contact with a leader, "0" if the node is in the +// leader state, or the time since last contact with a leader +// formatted as a string. +// +// The value of "num_peers" is the number of other voting servers in the +// cluster, not including this node. If this node isn't part of the +// configuration then this will be "0". +// +// All other values are uint64s, formatted as strings. 
+func (r *Raft) Stats() map[string]string { + toString := func(v uint64) string { + return strconv.FormatUint(v, 10) + } + lastLogIndex, lastLogTerm := r.getLastLog() + lastSnapIndex, lastSnapTerm := r.getLastSnapshot() + s := map[string]string{ + "state": r.getState().String(), + "term": toString(r.getCurrentTerm()), + "last_log_index": toString(lastLogIndex), + "last_log_term": toString(lastLogTerm), + "commit_index": toString(r.getCommitIndex()), + "applied_index": toString(r.getLastApplied()), + "fsm_pending": toString(uint64(len(r.fsmMutateCh))), + "last_snapshot_index": toString(lastSnapIndex), + "last_snapshot_term": toString(lastSnapTerm), + "protocol_version": toString(uint64(r.protocolVersion)), + "protocol_version_min": toString(uint64(ProtocolVersionMin)), + "protocol_version_max": toString(uint64(ProtocolVersionMax)), + "snapshot_version_min": toString(uint64(SnapshotVersionMin)), + "snapshot_version_max": toString(uint64(SnapshotVersionMax)), + } + + future := r.GetConfiguration() + if err := future.Error(); err != nil { + r.logger.Printf("[WARN] raft: could not get configuration for Stats: %v", err) + } else { + configuration := future.Configuration() + s["latest_configuration_index"] = toString(future.Index()) + s["latest_configuration"] = fmt.Sprintf("%+v", configuration.Servers) + + // This is a legacy metric that we've seen people use in the wild. 
+ hasUs := false + numPeers := 0 + for _, server := range configuration.Servers { + if server.Suffrage == Voter { + if server.ID == r.localID { + hasUs = true + } else { + numPeers++ + } + } + } + if !hasUs { + numPeers = 0 + } + s["num_peers"] = toString(uint64(numPeers)) + } + + last := r.LastContact() + if r.getState() == Leader { + s["last_contact"] = "0" + } else if last.IsZero() { + s["last_contact"] = "never" + } else { + s["last_contact"] = fmt.Sprintf("%v", time.Now().Sub(last)) + } + return s +} + +// LastIndex returns the last index in stable storage, +// either from the last log or from the last snapshot. +func (r *Raft) LastIndex() uint64 { + return r.getLastIndex() +} + +// AppliedIndex returns the last index applied to the FSM. This is generally +// lagging behind the last index, especially for indexes that are persisted but +// have not yet been considered committed by the leader. NOTE - this reflects +// the last index that was sent to the application's FSM over the apply channel +// but DOES NOT mean that the application's FSM has yet consumed it and applied +// it to its internal state. Thus, the application's state may lag behind this +// index. +func (r *Raft) AppliedIndex() uint64 { + return r.getLastApplied() +} diff --git a/vendor/github.com/hashicorp/raft/commands.go b/vendor/github.com/hashicorp/raft/commands.go new file mode 100644 index 00000000000..5d89e7bcdb1 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/commands.go @@ -0,0 +1,151 @@ +package raft + +// RPCHeader is a common sub-structure used to pass along protocol version and +// other information about the cluster. For older Raft implementations before +// versioning was added this will default to a zero-valued structure when read +// by newer Raft versions. +type RPCHeader struct { + // ProtocolVersion is the version of the protocol the sender is + // speaking. + ProtocolVersion ProtocolVersion +} + +// WithRPCHeader is an interface that exposes the RPC header. 
+type WithRPCHeader interface { + GetRPCHeader() RPCHeader +} + +// AppendEntriesRequest is the command used to append entries to the +// replicated log. +type AppendEntriesRequest struct { + RPCHeader + + // Provide the current term and leader + Term uint64 + Leader []byte + + // Provide the previous entries for integrity checking + PrevLogEntry uint64 + PrevLogTerm uint64 + + // New entries to commit + Entries []*Log + + // Commit index on the leader + LeaderCommitIndex uint64 +} + +// See WithRPCHeader. +func (r *AppendEntriesRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// AppendEntriesResponse is the response returned from an +// AppendEntriesRequest. +type AppendEntriesResponse struct { + RPCHeader + + // Newer term if leader is out of date + Term uint64 + + // Last Log is a hint to help accelerate rebuilding slow nodes + LastLog uint64 + + // We may not succeed if we have a conflicting entry + Success bool + + // There are scenarios where this request didn't succeed + // but there's no need to wait/back-off the next attempt. + NoRetryBackoff bool +} + +// See WithRPCHeader. +func (r *AppendEntriesResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// RequestVoteRequest is the command used by a candidate to ask a Raft peer +// for a vote in an election. +type RequestVoteRequest struct { + RPCHeader + + // Provide the term and our id + Term uint64 + Candidate []byte + + // Used to ensure safety + LastLogIndex uint64 + LastLogTerm uint64 +} + +// See WithRPCHeader. +func (r *RequestVoteRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// RequestVoteResponse is the response returned from a RequestVoteRequest. +type RequestVoteResponse struct { + RPCHeader + + // Newer term if leader is out of date. + Term uint64 + + // Peers is deprecated, but required by servers that only understand + // protocol version 0. This is not populated in protocol version 2 + // and later. + Peers []byte + + // Is the vote granted. 
+ Granted bool +} + +// See WithRPCHeader. +func (r *RequestVoteResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// InstallSnapshotRequest is the command sent to a Raft peer to bootstrap its +// log (and state machine) from a snapshot on another peer. +type InstallSnapshotRequest struct { + RPCHeader + SnapshotVersion SnapshotVersion + + Term uint64 + Leader []byte + + // These are the last index/term included in the snapshot + LastLogIndex uint64 + LastLogTerm uint64 + + // Peer Set in the snapshot. This is deprecated in favor of Configuration + // but remains here in case we receive an InstallSnapshot from a leader + // that's running old code. + Peers []byte + + // Cluster membership. + Configuration []byte + // Log index where 'Configuration' entry was originally written. + ConfigurationIndex uint64 + + // Size of the snapshot + Size int64 +} + +// See WithRPCHeader. +func (r *InstallSnapshotRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// InstallSnapshotResponse is the response returned from an +// InstallSnapshotRequest. +type InstallSnapshotResponse struct { + RPCHeader + + Term uint64 + Success bool +} + +// See WithRPCHeader. +func (r *InstallSnapshotResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} diff --git a/vendor/github.com/hashicorp/raft/commitment.go b/vendor/github.com/hashicorp/raft/commitment.go new file mode 100644 index 00000000000..b5ba2634ef2 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/commitment.go @@ -0,0 +1,101 @@ +package raft + +import ( + "sort" + "sync" +) + +// Commitment is used to advance the leader's commit index. The leader and +// replication goroutines report in newly written entries with Match(), and +// this notifies on commitCh when the commit index has advanced. 
+type commitment struct { + // protectes matchIndexes and commitIndex + sync.Mutex + // notified when commitIndex increases + commitCh chan struct{} + // voter ID to log index: the server stores up through this log entry + matchIndexes map[ServerID]uint64 + // a quorum stores up through this log entry. monotonically increases. + commitIndex uint64 + // the first index of this leader's term: this needs to be replicated to a + // majority of the cluster before this leader may mark anything committed + // (per Raft's commitment rule) + startIndex uint64 +} + +// newCommitment returns an commitment struct that notifies the provided +// channel when log entries have been committed. A new commitment struct is +// created each time this server becomes leader for a particular term. +// 'configuration' is the servers in the cluster. +// 'startIndex' is the first index created in this term (see +// its description above). +func newCommitment(commitCh chan struct{}, configuration Configuration, startIndex uint64) *commitment { + matchIndexes := make(map[ServerID]uint64) + for _, server := range configuration.Servers { + if server.Suffrage == Voter { + matchIndexes[server.ID] = 0 + } + } + return &commitment{ + commitCh: commitCh, + matchIndexes: matchIndexes, + commitIndex: 0, + startIndex: startIndex, + } +} + +// Called when a new cluster membership configuration is created: it will be +// used to determine commitment from now on. 'configuration' is the servers in +// the cluster. 
+func (c *commitment) setConfiguration(configuration Configuration) { + c.Lock() + defer c.Unlock() + oldMatchIndexes := c.matchIndexes + c.matchIndexes = make(map[ServerID]uint64) + for _, server := range configuration.Servers { + if server.Suffrage == Voter { + c.matchIndexes[server.ID] = oldMatchIndexes[server.ID] // defaults to 0 + } + } + c.recalculate() +} + +// Called by leader after commitCh is notified +func (c *commitment) getCommitIndex() uint64 { + c.Lock() + defer c.Unlock() + return c.commitIndex +} + +// Match is called once a server completes writing entries to disk: either the +// leader has written the new entry or a follower has replied to an +// AppendEntries RPC. The given server's disk agrees with this server's log up +// through the given index. +func (c *commitment) match(server ServerID, matchIndex uint64) { + c.Lock() + defer c.Unlock() + if prev, hasVote := c.matchIndexes[server]; hasVote && matchIndex > prev { + c.matchIndexes[server] = matchIndex + c.recalculate() + } +} + +// Internal helper to calculate new commitIndex from matchIndexes. +// Must be called with lock held. 
+func (c *commitment) recalculate() { + if len(c.matchIndexes) == 0 { + return + } + + matched := make([]uint64, 0, len(c.matchIndexes)) + for _, idx := range c.matchIndexes { + matched = append(matched, idx) + } + sort.Sort(uint64Slice(matched)) + quorumMatchIndex := matched[(len(matched)-1)/2] + + if quorumMatchIndex > c.commitIndex && quorumMatchIndex >= c.startIndex { + c.commitIndex = quorumMatchIndex + asyncNotifyCh(c.commitCh) + } +} diff --git a/vendor/github.com/hashicorp/raft/config.go b/vendor/github.com/hashicorp/raft/config.go new file mode 100644 index 00000000000..c1ce03ac22b --- /dev/null +++ b/vendor/github.com/hashicorp/raft/config.go @@ -0,0 +1,258 @@ +package raft + +import ( + "fmt" + "io" + "log" + "time" +) + +// These are the versions of the protocol (which includes RPC messages as +// well as Raft-specific log entries) that this server can _understand_. Use +// the ProtocolVersion member of the Config object to control the version of +// the protocol to use when _speaking_ to other servers. Note that depending on +// the protocol version being spoken, some otherwise understood RPC messages +// may be refused. See dispositionRPC for details of this logic. +// +// There are notes about the upgrade path in the description of the versions +// below. If you are starting a fresh cluster then there's no reason not to +// jump right to the latest protocol version. If you need to interoperate with +// older, version 0 Raft servers you'll need to drive the cluster through the +// different versions in order. +// +// The version details are complicated, but here's a summary of what's required +// to get from a version 0 cluster to version 3: +// +// 1. In version N of your app that starts using the new Raft library with +// versioning, set ProtocolVersion to 1. +// 2. Make version N+1 of your app require version N as a prerequisite (all +// servers must be upgraded). For version N+1 of your app set ProtocolVersion +// to 2. +// 3. 
Similarly, make version N+2 of your app require version N+1 as a +// prerequisite. For version N+2 of your app, set ProtocolVersion to 3. +// +// During this upgrade, older cluster members will still have Server IDs equal +// to their network addresses. To upgrade an older member and give it an ID, it +// needs to leave the cluster and re-enter: +// +// 1. Remove the server from the cluster with RemoveServer, using its network +// address as its ServerID. +// 2. Update the server's config to a better ID (restarting the server). +// 3. Add the server back to the cluster with AddVoter, using its new ID. +// +// You can do this during the rolling upgrade from N+1 to N+2 of your app, or +// as a rolling change at any time after the upgrade. +// +// Version History +// +// 0: Original Raft library before versioning was added. Servers running this +// version of the Raft library use AddPeerDeprecated/RemovePeerDeprecated +// for all configuration changes, and have no support for LogConfiguration. +// 1: First versioned protocol, used to interoperate with old servers, and begin +// the migration path to newer versions of the protocol. Under this version +// all configuration changes are propagated using the now-deprecated +// RemovePeerDeprecated Raft log entry. This means that server IDs are always +// set to be the same as the server addresses (since the old log entry type +// cannot transmit an ID), and only AddPeer/RemovePeer APIs are supported. +// Servers running this version of the protocol can understand the new +// LogConfiguration Raft log entry but will never generate one so they can +// remain compatible with version 0 Raft servers in the cluster. +// 2: Transitional protocol used when migrating an existing cluster to the new +// server ID system. Server IDs are still set to be the same as server +// addresses, but all configuration changes are propagated using the new +// LogConfiguration Raft log entry type, which can carry full ID information. 
+// This version supports the old AddPeer/RemovePeer APIs as well as the new +// ID-based AddVoter/RemoveServer APIs which should be used when adding +// version 3 servers to the cluster later. This version sheds all +// interoperability with version 0 servers, but can interoperate with newer +// Raft servers running with protocol version 1 since they can understand the +// new LogConfiguration Raft log entry, and this version can still understand +// their RemovePeerDeprecated Raft log entries. We need this protocol version +// as an intermediate step between 1 and 3 so that servers will propagate the +// ID information that will come from newly-added (or -rolled) servers using +// protocol version 3, but since they are still using their address-based IDs +// from the previous step they will still be able to track commitments and +// their own voting status properly. If we skipped this step, servers would +// be started with their new IDs, but they wouldn't see themselves in the old +// address-based configuration, so none of the servers would think they had a +// vote. +// 3: Protocol adding full support for server IDs and new ID-based server APIs +// (AddVoter, AddNonvoter, etc.), old AddPeer/RemovePeer APIs are no longer +// supported. Version 2 servers should be swapped out by removing them from +// the cluster one-by-one and re-adding them with updated configuration for +// this protocol version, along with their server ID. The remove/add cycle +// is required to populate their server ID. Note that removing must be done +// by ID, which will be the old server's address. +type ProtocolVersion int + +const ( + ProtocolVersionMin ProtocolVersion = 0 + ProtocolVersionMax = 3 +) + +// These are versions of snapshots that this server can _understand_. Currently, +// it is always assumed that this server generates the latest version, though +// this may be changed in the future to include a configurable version. 
+// +// Version History +// +// 0: Original Raft library before versioning was added. The peers portion of +// these snapshots is encoded in the legacy format which requires decodePeers +// to parse. This version of snapshots should only be produced by the +// unversioned Raft library. +// 1: New format which adds support for a full configuration structure and its +// associated log index, with support for server IDs and non-voting server +// modes. To ease upgrades, this also includes the legacy peers structure but +// that will never be used by servers that understand version 1 snapshots. +// Since the original Raft library didn't enforce any versioning, we must +// include the legacy peers structure for this version, but we can deprecate +// it in the next snapshot version. +type SnapshotVersion int + +const ( + SnapshotVersionMin SnapshotVersion = 0 + SnapshotVersionMax = 1 +) + +// Config provides any necessary configuration for the Raft server. +type Config struct { + // ProtocolVersion allows a Raft server to inter-operate with older + // Raft servers running an older version of the code. This is used to + // version the wire protocol as well as Raft-specific log entries that + // the server uses when _speaking_ to other servers. There is currently + // no auto-negotiation of versions so all servers must be manually + // configured with compatible versions. See ProtocolVersionMin and + // ProtocolVersionMax for the versions of the protocol that this server + // can _understand_. + ProtocolVersion ProtocolVersion + + // HeartbeatTimeout specifies the time in follower state without + // a leader before we attempt an election. + HeartbeatTimeout time.Duration + + // ElectionTimeout specifies the time in candidate state without + // a leader before we attempt an election. + ElectionTimeout time.Duration + + // CommitTimeout controls the time without an Apply() operation + // before we heartbeat to ensure a timely commit. 
Due to random + // staggering, may be delayed as much as 2x this value. + CommitTimeout time.Duration + + // MaxAppendEntries controls the maximum number of append entries + // to send at once. We want to strike a balance between efficiency + // and avoiding waste if the follower is going to reject because of + // an inconsistent log. + MaxAppendEntries int + + // If we are a member of a cluster, and RemovePeer is invoked for the + // local node, then we forget all peers and transition into the follower state. + // If ShutdownOnRemove is is set, we additional shutdown Raft. Otherwise, + // we can become a leader of a cluster containing only this node. + ShutdownOnRemove bool + + // TrailingLogs controls how many logs we leave after a snapshot. This is + // used so that we can quickly replay logs on a follower instead of being + // forced to send an entire snapshot. + TrailingLogs uint64 + + // SnapshotInterval controls how often we check if we should perform a snapshot. + // We randomly stagger between this value and 2x this value to avoid the entire + // cluster from performing a snapshot at once. + SnapshotInterval time.Duration + + // SnapshotThreshold controls how many outstanding logs there must be before + // we perform a snapshot. This is to prevent excessive snapshots when we can + // just replay a small set of logs. + SnapshotThreshold uint64 + + // LeaderLeaseTimeout is used to control how long the "lease" lasts + // for being the leader without being able to contact a quorum + // of nodes. If we reach this interval without contact, we will + // step down as leader. + LeaderLeaseTimeout time.Duration + + // StartAsLeader forces Raft to start in the leader state. This should + // never be used except for testing purposes, as it can cause a split-brain. + StartAsLeader bool + + // The unique ID for this server across all time. When running with + // ProtocolVersion < 3, you must set this to be the same as the network + // address of your transport. 
+ LocalID ServerID + + // NotifyCh is used to provide a channel that will be notified of leadership + // changes. Raft will block writing to this channel, so it should either be + // buffered or aggressively consumed. + NotifyCh chan<- bool + + // LogOutput is used as a sink for logs, unless Logger is specified. + // Defaults to os.Stderr. + LogOutput io.Writer + + // Logger is a user-provided logger. If nil, a logger writing to LogOutput + // is used. + Logger *log.Logger +} + +// DefaultConfig returns a Config with usable defaults. +func DefaultConfig() *Config { + return &Config{ + ProtocolVersion: ProtocolVersionMax, + HeartbeatTimeout: 1000 * time.Millisecond, + ElectionTimeout: 1000 * time.Millisecond, + CommitTimeout: 50 * time.Millisecond, + MaxAppendEntries: 64, + ShutdownOnRemove: true, + TrailingLogs: 10240, + SnapshotInterval: 120 * time.Second, + SnapshotThreshold: 8192, + LeaderLeaseTimeout: 500 * time.Millisecond, + } +} + +// ValidateConfig is used to validate a sane configuration +func ValidateConfig(config *Config) error { + // We don't actually support running as 0 in the library any more, but + // we do understand it. 
+ protocolMin := ProtocolVersionMin + if protocolMin == 0 { + protocolMin = 1 + } + if config.ProtocolVersion < protocolMin || + config.ProtocolVersion > ProtocolVersionMax { + return fmt.Errorf("Protocol version %d must be >= %d and <= %d", + config.ProtocolVersion, protocolMin, ProtocolVersionMax) + } + if len(config.LocalID) == 0 { + return fmt.Errorf("LocalID cannot be empty") + } + if config.HeartbeatTimeout < 5*time.Millisecond { + return fmt.Errorf("Heartbeat timeout is too low") + } + if config.ElectionTimeout < 5*time.Millisecond { + return fmt.Errorf("Election timeout is too low") + } + if config.CommitTimeout < time.Millisecond { + return fmt.Errorf("Commit timeout is too low") + } + if config.MaxAppendEntries <= 0 { + return fmt.Errorf("MaxAppendEntries must be positive") + } + if config.MaxAppendEntries > 1024 { + return fmt.Errorf("MaxAppendEntries is too large") + } + if config.SnapshotInterval < 5*time.Millisecond { + return fmt.Errorf("Snapshot interval is too low") + } + if config.LeaderLeaseTimeout < 5*time.Millisecond { + return fmt.Errorf("Leader lease timeout is too low") + } + if config.LeaderLeaseTimeout > config.HeartbeatTimeout { + return fmt.Errorf("Leader lease timeout cannot be larger than heartbeat timeout") + } + if config.ElectionTimeout < config.HeartbeatTimeout { + return fmt.Errorf("Election timeout must be equal or greater than Heartbeat Timeout") + } + return nil +} diff --git a/vendor/github.com/hashicorp/raft/configuration.go b/vendor/github.com/hashicorp/raft/configuration.go new file mode 100644 index 00000000000..8afc38bd93e --- /dev/null +++ b/vendor/github.com/hashicorp/raft/configuration.go @@ -0,0 +1,343 @@ +package raft + +import "fmt" + +// ServerSuffrage determines whether a Server in a Configuration gets a vote. +type ServerSuffrage int + +// Note: Don't renumber these, since the numbers are written into the log. 
+const ( + // Voter is a server whose vote is counted in elections and whose match index + // is used in advancing the leader's commit index. + Voter ServerSuffrage = iota + // Nonvoter is a server that receives log entries but is not considered for + // elections or commitment purposes. + Nonvoter + // Staging is a server that acts like a nonvoter with one exception: once a + // staging server receives enough log entries to be sufficiently caught up to + // the leader's log, the leader will invoke a membership change to change + // the Staging server to a Voter. + Staging +) + +func (s ServerSuffrage) String() string { + switch s { + case Voter: + return "Voter" + case Nonvoter: + return "Nonvoter" + case Staging: + return "Staging" + } + return "ServerSuffrage" +} + +// ServerID is a unique string identifying a server for all time. +type ServerID string + +// ServerAddress is a network address for a server that a transport can contact. +type ServerAddress string + +// Server tracks the information about a single server in a configuration. +type Server struct { + // Suffrage determines whether the server gets a vote. + Suffrage ServerSuffrage + // ID is a unique string identifying this server for all time. + ID ServerID + // Address is its network address that a transport can contact. + Address ServerAddress +} + +// Configuration tracks which servers are in the cluster, and whether they have +// votes. This should include the local server, if it's a member of the cluster. +// The servers are listed no particular order, but each should only appear once. +// These entries are appended to the log during membership changes. +type Configuration struct { + Servers []Server +} + +// Clone makes a deep copy of a Configuration. +func (c *Configuration) Clone() (copy Configuration) { + copy.Servers = append(copy.Servers, c.Servers...) + return +} + +// ConfigurationChangeCommand is the different ways to change the cluster +// configuration. 
+type ConfigurationChangeCommand uint8 + +const ( + // AddStaging makes a server Staging unless its Voter. + AddStaging ConfigurationChangeCommand = iota + // AddNonvoter makes a server Nonvoter unless its Staging or Voter. + AddNonvoter + // DemoteVoter makes a server Nonvoter unless its absent. + DemoteVoter + // RemoveServer removes a server entirely from the cluster membership. + RemoveServer + // Promote is created automatically by a leader; it turns a Staging server + // into a Voter. + Promote +) + +func (c ConfigurationChangeCommand) String() string { + switch c { + case AddStaging: + return "AddStaging" + case AddNonvoter: + return "AddNonvoter" + case DemoteVoter: + return "DemoteVoter" + case RemoveServer: + return "RemoveServer" + case Promote: + return "Promote" + } + return "ConfigurationChangeCommand" +} + +// configurationChangeRequest describes a change that a leader would like to +// make to its current configuration. It's used only within a single server +// (never serialized into the log), as part of `configurationChangeFuture`. +type configurationChangeRequest struct { + command ConfigurationChangeCommand + serverID ServerID + serverAddress ServerAddress // only present for AddStaging, AddNonvoter + // prevIndex, if nonzero, is the index of the only configuration upon which + // this change may be applied; if another configuration entry has been + // added in the meantime, this request will fail. + prevIndex uint64 +} + +// configurations is state tracked on every server about its Configurations. +// Note that, per Diego's dissertation, there can be at most one uncommitted +// configuration at a time (the next configuration may not be created until the +// prior one has been committed). +// +// One downside to storing just two configurations is that if you try to take a +// snahpsot when your state machine hasn't yet applied the committedIndex, we +// have no record of the configuration that would logically fit into that +// snapshot. 
We disallow snapshots in that case now. An alternative approach, +// which LogCabin uses, is to track every configuration change in the +// log. +type configurations struct { + // committed is the latest configuration in the log/snapshot that has been + // committed (the one with the largest index). + committed Configuration + // committedIndex is the log index where 'committed' was written. + committedIndex uint64 + // latest is the latest configuration in the log/snapshot (may be committed + // or uncommitted) + latest Configuration + // latestIndex is the log index where 'latest' was written. + latestIndex uint64 +} + +// Clone makes a deep copy of a configurations object. +func (c *configurations) Clone() (copy configurations) { + copy.committed = c.committed.Clone() + copy.committedIndex = c.committedIndex + copy.latest = c.latest.Clone() + copy.latestIndex = c.latestIndex + return +} + +// hasVote returns true if the server identified by 'id' is a Voter in the +// provided Configuration. +func hasVote(configuration Configuration, id ServerID) bool { + for _, server := range configuration.Servers { + if server.ID == id { + return server.Suffrage == Voter + } + } + return false +} + +// checkConfiguration tests a cluster membership configuration for common +// errors. 
+func checkConfiguration(configuration Configuration) error { + idSet := make(map[ServerID]bool) + addressSet := make(map[ServerAddress]bool) + var voters int + for _, server := range configuration.Servers { + if server.ID == "" { + return fmt.Errorf("Empty ID in configuration: %v", configuration) + } + if server.Address == "" { + return fmt.Errorf("Empty address in configuration: %v", server) + } + if idSet[server.ID] { + return fmt.Errorf("Found duplicate ID in configuration: %v", server.ID) + } + idSet[server.ID] = true + if addressSet[server.Address] { + return fmt.Errorf("Found duplicate address in configuration: %v", server.Address) + } + addressSet[server.Address] = true + if server.Suffrage == Voter { + voters++ + } + } + if voters == 0 { + return fmt.Errorf("Need at least one voter in configuration: %v", configuration) + } + return nil +} + +// nextConfiguration generates a new Configuration from the current one and a +// configuration change request. It's split from appendConfigurationEntry so +// that it can be unit tested easily. +func nextConfiguration(current Configuration, currentIndex uint64, change configurationChangeRequest) (Configuration, error) { + if change.prevIndex > 0 && change.prevIndex != currentIndex { + return Configuration{}, fmt.Errorf("Configuration changed since %v (latest is %v)", change.prevIndex, currentIndex) + } + + configuration := current.Clone() + switch change.command { + case AddStaging: + // TODO: barf on new address? + newServer := Server{ + // TODO: This should add the server as Staging, to be automatically + // promoted to Voter later. However, the promoton to Voter is not yet + // implemented, and doing so is not trivial with the way the leader loop + // coordinates with the replication goroutines today. So, for now, the + // server will have a vote right away, and the Promote case below is + // unused. 
+ Suffrage: Voter, + ID: change.serverID, + Address: change.serverAddress, + } + found := false + for i, server := range configuration.Servers { + if server.ID == change.serverID { + if server.Suffrage == Voter { + configuration.Servers[i].Address = change.serverAddress + } else { + configuration.Servers[i] = newServer + } + found = true + break + } + } + if !found { + configuration.Servers = append(configuration.Servers, newServer) + } + case AddNonvoter: + newServer := Server{ + Suffrage: Nonvoter, + ID: change.serverID, + Address: change.serverAddress, + } + found := false + for i, server := range configuration.Servers { + if server.ID == change.serverID { + if server.Suffrage != Nonvoter { + configuration.Servers[i].Address = change.serverAddress + } else { + configuration.Servers[i] = newServer + } + found = true + break + } + } + if !found { + configuration.Servers = append(configuration.Servers, newServer) + } + case DemoteVoter: + for i, server := range configuration.Servers { + if server.ID == change.serverID { + configuration.Servers[i].Suffrage = Nonvoter + break + } + } + case RemoveServer: + for i, server := range configuration.Servers { + if server.ID == change.serverID { + configuration.Servers = append(configuration.Servers[:i], configuration.Servers[i+1:]...) + break + } + } + case Promote: + for i, server := range configuration.Servers { + if server.ID == change.serverID && server.Suffrage == Staging { + configuration.Servers[i].Suffrage = Voter + break + } + } + } + + // Make sure we didn't do something bad like remove the last voter + if err := checkConfiguration(configuration); err != nil { + return Configuration{}, err + } + + return configuration, nil +} + +// encodePeers is used to serialize a Configuration into the old peers format. +// This is here for backwards compatibility when operating with a mix of old +// servers and should be removed once we deprecate support for protocol version 1. 
+func encodePeers(configuration Configuration, trans Transport) []byte { + // Gather up all the voters, other suffrage types are not supported by + // this data format. + var encPeers [][]byte + for _, server := range configuration.Servers { + if server.Suffrage == Voter { + encPeers = append(encPeers, trans.EncodePeer(server.ID, server.Address)) + } + } + + // Encode the entire array. + buf, err := encodeMsgPack(encPeers) + if err != nil { + panic(fmt.Errorf("failed to encode peers: %v", err)) + } + + return buf.Bytes() +} + +// decodePeers is used to deserialize an old list of peers into a Configuration. +// This is here for backwards compatibility with old log entries and snapshots; +// it should be removed eventually. +func decodePeers(buf []byte, trans Transport) Configuration { + // Decode the buffer first. + var encPeers [][]byte + if err := decodeMsgPack(buf, &encPeers); err != nil { + panic(fmt.Errorf("failed to decode peers: %v", err)) + } + + // Deserialize each peer. + var servers []Server + for _, enc := range encPeers { + p := trans.DecodePeer(enc) + servers = append(servers, Server{ + Suffrage: Voter, + ID: ServerID(p), + Address: ServerAddress(p), + }) + } + + return Configuration{ + Servers: servers, + } +} + +// encodeConfiguration serializes a Configuration using MsgPack, or panics on +// errors. +func encodeConfiguration(configuration Configuration) []byte { + buf, err := encodeMsgPack(configuration) + if err != nil { + panic(fmt.Errorf("failed to encode configuration: %v", err)) + } + return buf.Bytes() +} + +// decodeConfiguration deserializes a Configuration using MsgPack, or panics on +// errors. 
+func decodeConfiguration(buf []byte) Configuration { + var configuration Configuration + if err := decodeMsgPack(buf, &configuration); err != nil { + panic(fmt.Errorf("failed to decode configuration: %v", err)) + } + return configuration +} diff --git a/vendor/github.com/hashicorp/raft/discard_snapshot.go b/vendor/github.com/hashicorp/raft/discard_snapshot.go new file mode 100644 index 00000000000..5e93a9fe01f --- /dev/null +++ b/vendor/github.com/hashicorp/raft/discard_snapshot.go @@ -0,0 +1,49 @@ +package raft + +import ( + "fmt" + "io" +) + +// DiscardSnapshotStore is used to successfully snapshot while +// always discarding the snapshot. This is useful for when the +// log should be truncated but no snapshot should be retained. +// This should never be used for production use, and is only +// suitable for testing. +type DiscardSnapshotStore struct{} + +type DiscardSnapshotSink struct{} + +// NewDiscardSnapshotStore is used to create a new DiscardSnapshotStore. +func NewDiscardSnapshotStore() *DiscardSnapshotStore { + return &DiscardSnapshotStore{} +} + +func (d *DiscardSnapshotStore) Create(version SnapshotVersion, index, term uint64, + configuration Configuration, configurationIndex uint64, trans Transport) (SnapshotSink, error) { + return &DiscardSnapshotSink{}, nil +} + +func (d *DiscardSnapshotStore) List() ([]*SnapshotMeta, error) { + return nil, nil +} + +func (d *DiscardSnapshotStore) Open(id string) (*SnapshotMeta, io.ReadCloser, error) { + return nil, nil, fmt.Errorf("open is not supported") +} + +func (d *DiscardSnapshotSink) Write(b []byte) (int, error) { + return len(b), nil +} + +func (d *DiscardSnapshotSink) Close() error { + return nil +} + +func (d *DiscardSnapshotSink) ID() string { + return "discard" +} + +func (d *DiscardSnapshotSink) Cancel() error { + return nil +} diff --git a/vendor/github.com/hashicorp/raft/file_snapshot.go b/vendor/github.com/hashicorp/raft/file_snapshot.go new file mode 100644 index 00000000000..ffc9414542f --- 
/dev/null +++ b/vendor/github.com/hashicorp/raft/file_snapshot.go @@ -0,0 +1,528 @@ +package raft + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "hash" + "hash/crc64" + "io" + "io/ioutil" + "log" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "time" +) + +const ( + testPath = "permTest" + snapPath = "snapshots" + metaFilePath = "meta.json" + stateFilePath = "state.bin" + tmpSuffix = ".tmp" +) + +// FileSnapshotStore implements the SnapshotStore interface and allows +// snapshots to be made on the local disk. +type FileSnapshotStore struct { + path string + retain int + logger *log.Logger +} + +type snapMetaSlice []*fileSnapshotMeta + +// FileSnapshotSink implements SnapshotSink with a file. +type FileSnapshotSink struct { + store *FileSnapshotStore + logger *log.Logger + dir string + parentDir string + meta fileSnapshotMeta + + stateFile *os.File + stateHash hash.Hash64 + buffered *bufio.Writer + + closed bool +} + +// fileSnapshotMeta is stored on disk. We also put a CRC +// on disk so that we can verify the snapshot. +type fileSnapshotMeta struct { + SnapshotMeta + CRC []byte +} + +// bufferedFile is returned when we open a snapshot. This way +// reads are buffered and the file still gets closed. +type bufferedFile struct { + bh *bufio.Reader + fh *os.File +} + +func (b *bufferedFile) Read(p []byte) (n int, err error) { + return b.bh.Read(p) +} + +func (b *bufferedFile) Close() error { + return b.fh.Close() +} + +// NewFileSnapshotStoreWithLogger creates a new FileSnapshotStore based +// on a base directory. The `retain` parameter controls how many +// snapshots are retained. Must be at least 1. 
+func NewFileSnapshotStoreWithLogger(base string, retain int, logger *log.Logger) (*FileSnapshotStore, error) { + if retain < 1 { + return nil, fmt.Errorf("must retain at least one snapshot") + } + if logger == nil { + logger = log.New(os.Stderr, "", log.LstdFlags) + } + + // Ensure our path exists + path := filepath.Join(base, snapPath) + if err := os.MkdirAll(path, 0755); err != nil && !os.IsExist(err) { + return nil, fmt.Errorf("snapshot path not accessible: %v", err) + } + + // Setup the store + store := &FileSnapshotStore{ + path: path, + retain: retain, + logger: logger, + } + + // Do a permissions test + if err := store.testPermissions(); err != nil { + return nil, fmt.Errorf("permissions test failed: %v", err) + } + return store, nil +} + +// NewFileSnapshotStore creates a new FileSnapshotStore based +// on a base directory. The `retain` parameter controls how many +// snapshots are retained. Must be at least 1. +func NewFileSnapshotStore(base string, retain int, logOutput io.Writer) (*FileSnapshotStore, error) { + if logOutput == nil { + logOutput = os.Stderr + } + return NewFileSnapshotStoreWithLogger(base, retain, log.New(logOutput, "", log.LstdFlags)) +} + +// testPermissions tries to touch a file in our path to see if it works. +func (f *FileSnapshotStore) testPermissions() error { + path := filepath.Join(f.path, testPath) + fh, err := os.Create(path) + if err != nil { + return err + } + + if err = fh.Close(); err != nil { + return err + } + + if err = os.Remove(path); err != nil { + return err + } + return nil +} + +// snapshotName generates a name for the snapshot. 
+func snapshotName(term, index uint64) string { + now := time.Now() + msec := now.UnixNano() / int64(time.Millisecond) + return fmt.Sprintf("%d-%d-%d", term, index, msec) +} + +// Create is used to start a new snapshot +func (f *FileSnapshotStore) Create(version SnapshotVersion, index, term uint64, + configuration Configuration, configurationIndex uint64, trans Transport) (SnapshotSink, error) { + // We only support version 1 snapshots at this time. + if version != 1 { + return nil, fmt.Errorf("unsupported snapshot version %d", version) + } + + // Create a new path + name := snapshotName(term, index) + path := filepath.Join(f.path, name+tmpSuffix) + f.logger.Printf("[INFO] snapshot: Creating new snapshot at %s", path) + + // Make the directory + if err := os.MkdirAll(path, 0755); err != nil { + f.logger.Printf("[ERR] snapshot: Failed to make snapshot directory: %v", err) + return nil, err + } + + // Create the sink + sink := &FileSnapshotSink{ + store: f, + logger: f.logger, + dir: path, + parentDir: f.path, + meta: fileSnapshotMeta{ + SnapshotMeta: SnapshotMeta{ + Version: version, + ID: name, + Index: index, + Term: term, + Peers: encodePeers(configuration, trans), + Configuration: configuration, + ConfigurationIndex: configurationIndex, + }, + CRC: nil, + }, + } + + // Write out the meta data + if err := sink.writeMeta(); err != nil { + f.logger.Printf("[ERR] snapshot: Failed to write metadata: %v", err) + return nil, err + } + + // Open the state file + statePath := filepath.Join(path, stateFilePath) + fh, err := os.Create(statePath) + if err != nil { + f.logger.Printf("[ERR] snapshot: Failed to create state file: %v", err) + return nil, err + } + sink.stateFile = fh + + // Create a CRC64 hash + sink.stateHash = crc64.New(crc64.MakeTable(crc64.ECMA)) + + // Wrap both the hash and file in a MultiWriter with buffering + multi := io.MultiWriter(sink.stateFile, sink.stateHash) + sink.buffered = bufio.NewWriter(multi) + + // Done + return sink, nil +} + +// List 
returns available snapshots in the store. +func (f *FileSnapshotStore) List() ([]*SnapshotMeta, error) { + // Get the eligible snapshots + snapshots, err := f.getSnapshots() + if err != nil { + f.logger.Printf("[ERR] snapshot: Failed to get snapshots: %v", err) + return nil, err + } + + var snapMeta []*SnapshotMeta + for _, meta := range snapshots { + snapMeta = append(snapMeta, &meta.SnapshotMeta) + if len(snapMeta) == f.retain { + break + } + } + return snapMeta, nil +} + +// getSnapshots returns all the known snapshots. +func (f *FileSnapshotStore) getSnapshots() ([]*fileSnapshotMeta, error) { + // Get the eligible snapshots + snapshots, err := ioutil.ReadDir(f.path) + if err != nil { + f.logger.Printf("[ERR] snapshot: Failed to scan snapshot dir: %v", err) + return nil, err + } + + // Populate the metadata + var snapMeta []*fileSnapshotMeta + for _, snap := range snapshots { + // Ignore any files + if !snap.IsDir() { + continue + } + + // Ignore any temporary snapshots + dirName := snap.Name() + if strings.HasSuffix(dirName, tmpSuffix) { + f.logger.Printf("[WARN] snapshot: Found temporary snapshot: %v", dirName) + continue + } + + // Try to read the meta data + meta, err := f.readMeta(dirName) + if err != nil { + f.logger.Printf("[WARN] snapshot: Failed to read metadata for %v: %v", dirName, err) + continue + } + + // Make sure we can understand this version. 
+ if meta.Version < SnapshotVersionMin || meta.Version > SnapshotVersionMax { + f.logger.Printf("[WARN] snapshot: Snapshot version for %v not supported: %d", dirName, meta.Version) + continue + } + + // Append, but only return up to the retain count + snapMeta = append(snapMeta, meta) + } + + // Sort the snapshot, reverse so we get new -> old + sort.Sort(sort.Reverse(snapMetaSlice(snapMeta))) + + return snapMeta, nil +} + +// readMeta is used to read the meta data for a given named backup +func (f *FileSnapshotStore) readMeta(name string) (*fileSnapshotMeta, error) { + // Open the meta file + metaPath := filepath.Join(f.path, name, metaFilePath) + fh, err := os.Open(metaPath) + if err != nil { + return nil, err + } + defer fh.Close() + + // Buffer the file IO + buffered := bufio.NewReader(fh) + + // Read in the JSON + meta := &fileSnapshotMeta{} + dec := json.NewDecoder(buffered) + if err := dec.Decode(meta); err != nil { + return nil, err + } + return meta, nil +} + +// Open takes a snapshot ID and returns a ReadCloser for that snapshot. 
+func (f *FileSnapshotStore) Open(id string) (*SnapshotMeta, io.ReadCloser, error) { + // Get the metadata + meta, err := f.readMeta(id) + if err != nil { + f.logger.Printf("[ERR] snapshot: Failed to get meta data to open snapshot: %v", err) + return nil, nil, err + } + + // Open the state file + statePath := filepath.Join(f.path, id, stateFilePath) + fh, err := os.Open(statePath) + if err != nil { + f.logger.Printf("[ERR] snapshot: Failed to open state file: %v", err) + return nil, nil, err + } + + // Create a CRC64 hash + stateHash := crc64.New(crc64.MakeTable(crc64.ECMA)) + + // Compute the hash + _, err = io.Copy(stateHash, fh) + if err != nil { + f.logger.Printf("[ERR] snapshot: Failed to read state file: %v", err) + fh.Close() + return nil, nil, err + } + + // Verify the hash + computed := stateHash.Sum(nil) + if bytes.Compare(meta.CRC, computed) != 0 { + f.logger.Printf("[ERR] snapshot: CRC checksum failed (stored: %v computed: %v)", + meta.CRC, computed) + fh.Close() + return nil, nil, fmt.Errorf("CRC mismatch") + } + + // Seek to the start + if _, err := fh.Seek(0, 0); err != nil { + f.logger.Printf("[ERR] snapshot: State file seek failed: %v", err) + fh.Close() + return nil, nil, err + } + + // Return a buffered file + buffered := &bufferedFile{ + bh: bufio.NewReader(fh), + fh: fh, + } + + return &meta.SnapshotMeta, buffered, nil +} + +// ReapSnapshots reaps any snapshots beyond the retain count. 
+func (f *FileSnapshotStore) ReapSnapshots() error { + snapshots, err := f.getSnapshots() + if err != nil { + f.logger.Printf("[ERR] snapshot: Failed to get snapshots: %v", err) + return err + } + + for i := f.retain; i < len(snapshots); i++ { + path := filepath.Join(f.path, snapshots[i].ID) + f.logger.Printf("[INFO] snapshot: reaping snapshot %v", path) + if err := os.RemoveAll(path); err != nil { + f.logger.Printf("[ERR] snapshot: Failed to reap snapshot %v: %v", path, err) + return err + } + } + return nil +} + +// ID returns the ID of the snapshot, can be used with Open() +// after the snapshot is finalized. +func (s *FileSnapshotSink) ID() string { + return s.meta.ID +} + +// Write is used to append to the state file. We write to the +// buffered IO object to reduce the amount of context switches. +func (s *FileSnapshotSink) Write(b []byte) (int, error) { + return s.buffered.Write(b) +} + +// Close is used to indicate a successful end. +func (s *FileSnapshotSink) Close() error { + // Make sure close is idempotent + if s.closed { + return nil + } + s.closed = true + + // Close the open handles + if err := s.finalize(); err != nil { + s.logger.Printf("[ERR] snapshot: Failed to finalize snapshot: %v", err) + if delErr := os.RemoveAll(s.dir); delErr != nil { + s.logger.Printf("[ERR] snapshot: Failed to delete temporary snapshot directory at path %v: %v", s.dir, delErr) + return delErr + } + return err + } + + // Write out the meta data + if err := s.writeMeta(); err != nil { + s.logger.Printf("[ERR] snapshot: Failed to write metadata: %v", err) + return err + } + + // Move the directory into place + newPath := strings.TrimSuffix(s.dir, tmpSuffix) + if err := os.Rename(s.dir, newPath); err != nil { + s.logger.Printf("[ERR] snapshot: Failed to move snapshot into place: %v", err) + return err + } + + if runtime.GOOS != "windows" { //skipping fsync for directory entry edits on Windows, only needed for *nix style file systems + parentFH, err := os.Open(s.parentDir) + 
defer parentFH.Close() + if err != nil { + s.logger.Printf("[ERR] snapshot: Failed to open snapshot parent directory %v, error: %v", s.parentDir, err) + return err + } + + if err = parentFH.Sync(); err != nil { + s.logger.Printf("[ERR] snapshot: Failed syncing parent directory %v, error: %v", s.parentDir, err) + return err + } + } + + // Reap any old snapshots + if err := s.store.ReapSnapshots(); err != nil { + return err + } + + return nil +} + +// Cancel is used to indicate an unsuccessful end. +func (s *FileSnapshotSink) Cancel() error { + // Make sure close is idempotent + if s.closed { + return nil + } + s.closed = true + + // Close the open handles + if err := s.finalize(); err != nil { + s.logger.Printf("[ERR] snapshot: Failed to finalize snapshot: %v", err) + return err + } + + // Attempt to remove all artifacts + return os.RemoveAll(s.dir) +} + +// finalize is used to close all of our resources. +func (s *FileSnapshotSink) finalize() error { + // Flush any remaining data + if err := s.buffered.Flush(); err != nil { + return err + } + + // Sync to force fsync to disk + if err := s.stateFile.Sync(); err != nil { + return err + } + + // Get the file size + stat, statErr := s.stateFile.Stat() + + // Close the file + if err := s.stateFile.Close(); err != nil { + return err + } + + // Set the file size, check after we close + if statErr != nil { + return statErr + } + s.meta.Size = stat.Size() + + // Set the CRC + s.meta.CRC = s.stateHash.Sum(nil) + return nil +} + +// writeMeta is used to write out the metadata we have. 
+func (s *FileSnapshotSink) writeMeta() error { + // Open the meta file + metaPath := filepath.Join(s.dir, metaFilePath) + fh, err := os.Create(metaPath) + if err != nil { + return err + } + defer fh.Close() + + // Buffer the file IO + buffered := bufio.NewWriter(fh) + + // Write out as JSON + enc := json.NewEncoder(buffered) + if err := enc.Encode(&s.meta); err != nil { + return err + } + + if err = buffered.Flush(); err != nil { + return err + } + + if err = fh.Sync(); err != nil { + return err + } + + return nil +} + +// Implement the sort interface for []*fileSnapshotMeta. +func (s snapMetaSlice) Len() int { + return len(s) +} + +func (s snapMetaSlice) Less(i, j int) bool { + if s[i].Term != s[j].Term { + return s[i].Term < s[j].Term + } + if s[i].Index != s[j].Index { + return s[i].Index < s[j].Index + } + return s[i].ID < s[j].ID +} + +func (s snapMetaSlice) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} diff --git a/vendor/github.com/hashicorp/raft/fsm.go b/vendor/github.com/hashicorp/raft/fsm.go new file mode 100644 index 00000000000..c89986c0fad --- /dev/null +++ b/vendor/github.com/hashicorp/raft/fsm.go @@ -0,0 +1,136 @@ +package raft + +import ( + "fmt" + "io" + "time" + + "github.com/armon/go-metrics" +) + +// FSM provides an interface that can be implemented by +// clients to make use of the replicated log. +type FSM interface { + // Apply log is invoked once a log entry is committed. + // It returns a value which will be made available in the + // ApplyFuture returned by Raft.Apply method if that + // method was called on the same Raft node as the FSM. + Apply(*Log) interface{} + + // Snapshot is used to support log compaction. This call should + // return an FSMSnapshot which can be used to save a point-in-time + // snapshot of the FSM. Apply and Snapshot are not called in multiple + // threads, but Apply will be called concurrently with Persist. 
This means + // the FSM should be implemented in a fashion that allows for concurrent + // updates while a snapshot is happening. + Snapshot() (FSMSnapshot, error) + + // Restore is used to restore an FSM from a snapshot. It is not called + // concurrently with any other command. The FSM must discard all previous + // state. + Restore(io.ReadCloser) error +} + +// FSMSnapshot is returned by an FSM in response to a Snapshot +// It must be safe to invoke FSMSnapshot methods with concurrent +// calls to Apply. +type FSMSnapshot interface { + // Persist should dump all necessary state to the WriteCloser 'sink', + // and call sink.Close() when finished or call sink.Cancel() on error. + Persist(sink SnapshotSink) error + + // Release is invoked when we are finished with the snapshot. + Release() +} + +// runFSM is a long running goroutine responsible for applying logs +// to the FSM. This is done async of other logs since we don't want +// the FSM to block our internal operations. +func (r *Raft) runFSM() { + var lastIndex, lastTerm uint64 + + commit := func(req *commitTuple) { + // Apply the log if a command + var resp interface{} + if req.log.Type == LogCommand { + start := time.Now() + resp = r.fsm.Apply(req.log) + metrics.MeasureSince([]string{"raft", "fsm", "apply"}, start) + } + + // Update the indexes + lastIndex = req.log.Index + lastTerm = req.log.Term + + // Invoke the future if given + if req.future != nil { + req.future.response = resp + req.future.respond(nil) + } + } + + restore := func(req *restoreFuture) { + // Open the snapshot + meta, source, err := r.snapshots.Open(req.ID) + if err != nil { + req.respond(fmt.Errorf("failed to open snapshot %v: %v", req.ID, err)) + return + } + + // Attempt to restore + start := time.Now() + if err := r.fsm.Restore(source); err != nil { + req.respond(fmt.Errorf("failed to restore snapshot %v: %v", req.ID, err)) + source.Close() + return + } + source.Close() + metrics.MeasureSince([]string{"raft", "fsm", "restore"}, 
start) + + // Update the last index and term + lastIndex = meta.Index + lastTerm = meta.Term + req.respond(nil) + } + + snapshot := func(req *reqSnapshotFuture) { + // Is there something to snapshot? + if lastIndex == 0 { + req.respond(ErrNothingNewToSnapshot) + return + } + + // Start a snapshot + start := time.Now() + snap, err := r.fsm.Snapshot() + metrics.MeasureSince([]string{"raft", "fsm", "snapshot"}, start) + + // Respond to the request + req.index = lastIndex + req.term = lastTerm + req.snapshot = snap + req.respond(err) + } + + for { + select { + case ptr := <-r.fsmMutateCh: + switch req := ptr.(type) { + case *commitTuple: + commit(req) + + case *restoreFuture: + restore(req) + + default: + panic(fmt.Errorf("bad type passed to fsmMutateCh: %#v", ptr)) + } + + case req := <-r.fsmSnapshotCh: + snapshot(req) + + case <-r.shutdownCh: + return + } + } +} diff --git a/vendor/github.com/hashicorp/raft/future.go b/vendor/github.com/hashicorp/raft/future.go new file mode 100644 index 00000000000..fac59a5cc47 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/future.go @@ -0,0 +1,289 @@ +package raft + +import ( + "fmt" + "io" + "sync" + "time" +) + +// Future is used to represent an action that may occur in the future. +type Future interface { + // Error blocks until the future arrives and then + // returns the error status of the future. + // This may be called any number of times - all + // calls will return the same value. + // Note that it is not OK to call this method + // twice concurrently on the same Future instance. + Error() error +} + +// IndexFuture is used for future actions that can result in a raft log entry +// being created. +type IndexFuture interface { + Future + + // Index holds the index of the newly applied log entry. + // This must not be called until after the Error method has returned. + Index() uint64 +} + +// ApplyFuture is used for Apply and can return the FSM response. 
+type ApplyFuture interface { + IndexFuture + + // Response returns the FSM response as returned + // by the FSM.Apply method. This must not be called + // until after the Error method has returned. + Response() interface{} +} + +// ConfigurationFuture is used for GetConfiguration and can return the +// latest configuration in use by Raft. +type ConfigurationFuture interface { + IndexFuture + + // Configuration contains the latest configuration. This must + // not be called until after the Error method has returned. + Configuration() Configuration +} + +// SnapshotFuture is used for waiting on a user-triggered snapshot to complete. +type SnapshotFuture interface { + Future + + // Open is a function you can call to access the underlying snapshot and + // its metadata. This must not be called until after the Error method + // has returned. + Open() (*SnapshotMeta, io.ReadCloser, error) +} + +// errorFuture is used to return a static error. +type errorFuture struct { + err error +} + +func (e errorFuture) Error() error { + return e.err +} + +func (e errorFuture) Response() interface{} { + return nil +} + +func (e errorFuture) Index() uint64 { + return 0 +} + +// deferError can be embedded to allow a future +// to provide an error in the future. +type deferError struct { + err error + errCh chan error + responded bool +} + +func (d *deferError) init() { + d.errCh = make(chan error, 1) +} + +func (d *deferError) Error() error { + if d.err != nil { + // Note that when we've received a nil error, this + // won't trigger, but the channel is closed after + // send so we'll still return nil below. 
+ return d.err + } + if d.errCh == nil { + panic("waiting for response on nil channel") + } + d.err = <-d.errCh + return d.err +} + +func (d *deferError) respond(err error) { + if d.errCh == nil { + return + } + if d.responded { + return + } + d.errCh <- err + close(d.errCh) + d.responded = true +} + +// There are several types of requests that cause a configuration entry to +// be appended to the log. These are encoded here for leaderLoop() to process. +// This is internal to a single server. +type configurationChangeFuture struct { + logFuture + req configurationChangeRequest +} + +// bootstrapFuture is used to attempt a live bootstrap of the cluster. See the +// Raft object's BootstrapCluster member function for more details. +type bootstrapFuture struct { + deferError + + // configuration is the proposed bootstrap configuration to apply. + configuration Configuration +} + +// logFuture is used to apply a log entry and waits until +// the log is considered committed. +type logFuture struct { + deferError + log Log + response interface{} + dispatch time.Time +} + +func (l *logFuture) Response() interface{} { + return l.response +} + +func (l *logFuture) Index() uint64 { + return l.log.Index +} + +type shutdownFuture struct { + raft *Raft +} + +func (s *shutdownFuture) Error() error { + if s.raft == nil { + return nil + } + s.raft.waitShutdown() + if closeable, ok := s.raft.trans.(WithClose); ok { + closeable.Close() + } + return nil +} + +// userSnapshotFuture is used for waiting on a user-triggered snapshot to +// complete. +type userSnapshotFuture struct { + deferError + + // opener is a function used to open the snapshot. This is filled in + // once the future returns with no error. + opener func() (*SnapshotMeta, io.ReadCloser, error) +} + +// Open is a function you can call to access the underlying snapshot and its +// metadata. 
+func (u *userSnapshotFuture) Open() (*SnapshotMeta, io.ReadCloser, error) { + if u.opener == nil { + return nil, nil, fmt.Errorf("no snapshot available") + } else { + // Invalidate the opener so it can't get called multiple times, + // which isn't generally safe. + defer func() { + u.opener = nil + }() + return u.opener() + } +} + +// userRestoreFuture is used for waiting on a user-triggered restore of an +// external snapshot to complete. +type userRestoreFuture struct { + deferError + + // meta is the metadata that belongs with the snapshot. + meta *SnapshotMeta + + // reader is the interface to read the snapshot contents from. + reader io.Reader +} + +// reqSnapshotFuture is used for requesting a snapshot start. +// It is only used internally. +type reqSnapshotFuture struct { + deferError + + // snapshot details provided by the FSM runner before responding + index uint64 + term uint64 + snapshot FSMSnapshot +} + +// restoreFuture is used for requesting an FSM to perform a +// snapshot restore. Used internally only. +type restoreFuture struct { + deferError + ID string +} + +// verifyFuture is used to verify the current node is still +// the leader. This is to prevent a stale read. +type verifyFuture struct { + deferError + notifyCh chan *verifyFuture + quorumSize int + votes int + voteLock sync.Mutex +} + +// configurationsFuture is used to retrieve the current configurations. This is +// used to allow safe access to this information outside of the main thread. +type configurationsFuture struct { + deferError + configurations configurations +} + +// Configuration returns the latest configuration in use by Raft. +func (c *configurationsFuture) Configuration() Configuration { + return c.configurations.latest +} + +// Index returns the index of the latest configuration in use by Raft. +func (c *configurationsFuture) Index() uint64 { + return c.configurations.latestIndex +} + +// vote is used to respond to a verifyFuture. 
+// This may block when responding on the notifyCh. +func (v *verifyFuture) vote(leader bool) { + v.voteLock.Lock() + defer v.voteLock.Unlock() + + // Guard against having notified already + if v.notifyCh == nil { + return + } + + if leader { + v.votes++ + if v.votes >= v.quorumSize { + v.notifyCh <- v + v.notifyCh = nil + } + } else { + v.notifyCh <- v + v.notifyCh = nil + } +} + +// appendFuture is used for waiting on a pipelined append +// entries RPC. +type appendFuture struct { + deferError + start time.Time + args *AppendEntriesRequest + resp *AppendEntriesResponse +} + +func (a *appendFuture) Start() time.Time { + return a.start +} + +func (a *appendFuture) Request() *AppendEntriesRequest { + return a.args +} + +func (a *appendFuture) Response() *AppendEntriesResponse { + return a.resp +} diff --git a/vendor/github.com/hashicorp/raft/inmem_snapshot.go b/vendor/github.com/hashicorp/raft/inmem_snapshot.go new file mode 100644 index 00000000000..3aa92b3e9a2 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/inmem_snapshot.go @@ -0,0 +1,106 @@ +package raft + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "sync" +) + +// InmemSnapshotStore implements the SnapshotStore interface and +// retains only the most recent snapshot +type InmemSnapshotStore struct { + latest *InmemSnapshotSink + hasSnapshot bool + sync.RWMutex +} + +// InmemSnapshotSink implements SnapshotSink in memory +type InmemSnapshotSink struct { + meta SnapshotMeta + contents *bytes.Buffer +} + +// NewInmemSnapshotStore creates a blank new InmemSnapshotStore +func NewInmemSnapshotStore() *InmemSnapshotStore { + return &InmemSnapshotStore{ + latest: &InmemSnapshotSink{ + contents: &bytes.Buffer{}, + }, + } +} + +// Create replaces the stored snapshot with a new one using the given args +func (m *InmemSnapshotStore) Create(version SnapshotVersion, index, term uint64, + configuration Configuration, configurationIndex uint64, trans Transport) (SnapshotSink, error) { + // We only support version 1 
snapshots at this time. + if version != 1 { + return nil, fmt.Errorf("unsupported snapshot version %d", version) + } + + name := snapshotName(term, index) + + m.Lock() + defer m.Unlock() + + sink := &InmemSnapshotSink{ + meta: SnapshotMeta{ + Version: version, + ID: name, + Index: index, + Term: term, + Peers: encodePeers(configuration, trans), + Configuration: configuration, + ConfigurationIndex: configurationIndex, + }, + contents: &bytes.Buffer{}, + } + m.hasSnapshot = true + m.latest = sink + + return sink, nil +} + +// List returns the latest snapshot taken +func (m *InmemSnapshotStore) List() ([]*SnapshotMeta, error) { + m.RLock() + defer m.RUnlock() + + if !m.hasSnapshot { + return []*SnapshotMeta{}, nil + } + return []*SnapshotMeta{&m.latest.meta}, nil +} + +// Open wraps an io.ReadCloser around the snapshot contents +func (m *InmemSnapshotStore) Open(id string) (*SnapshotMeta, io.ReadCloser, error) { + m.RLock() + defer m.RUnlock() + + if m.latest.meta.ID != id { + return nil, nil, fmt.Errorf("[ERR] snapshot: failed to open snapshot id: %s", id) + } + + return &m.latest.meta, ioutil.NopCloser(m.latest.contents), nil +} + +// Write appends the given bytes to the snapshot contents +func (s *InmemSnapshotSink) Write(p []byte) (n int, err error) { + written, err := io.Copy(s.contents, bytes.NewReader(p)) + s.meta.Size += written + return int(written), err +} + +// Close updates the Size and is otherwise a no-op +func (s *InmemSnapshotSink) Close() error { + return nil +} + +func (s *InmemSnapshotSink) ID() string { + return s.meta.ID +} + +func (s *InmemSnapshotSink) Cancel() error { + return nil +} diff --git a/vendor/github.com/hashicorp/raft/inmem_store.go b/vendor/github.com/hashicorp/raft/inmem_store.go new file mode 100644 index 00000000000..e5d579e1b31 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/inmem_store.go @@ -0,0 +1,125 @@ +package raft + +import ( + "sync" +) + +// InmemStore implements the LogStore and StableStore interface. 
+// It should NOT EVER be used for production. It is used only for +// unit tests. Use the MDBStore implementation instead. +type InmemStore struct { + l sync.RWMutex + lowIndex uint64 + highIndex uint64 + logs map[uint64]*Log + kv map[string][]byte + kvInt map[string]uint64 +} + +// NewInmemStore returns a new in-memory backend. Do not ever +// use for production. Only for testing. +func NewInmemStore() *InmemStore { + i := &InmemStore{ + logs: make(map[uint64]*Log), + kv: make(map[string][]byte), + kvInt: make(map[string]uint64), + } + return i +} + +// FirstIndex implements the LogStore interface. +func (i *InmemStore) FirstIndex() (uint64, error) { + i.l.RLock() + defer i.l.RUnlock() + return i.lowIndex, nil +} + +// LastIndex implements the LogStore interface. +func (i *InmemStore) LastIndex() (uint64, error) { + i.l.RLock() + defer i.l.RUnlock() + return i.highIndex, nil +} + +// GetLog implements the LogStore interface. +func (i *InmemStore) GetLog(index uint64, log *Log) error { + i.l.RLock() + defer i.l.RUnlock() + l, ok := i.logs[index] + if !ok { + return ErrLogNotFound + } + *log = *l + return nil +} + +// StoreLog implements the LogStore interface. +func (i *InmemStore) StoreLog(log *Log) error { + return i.StoreLogs([]*Log{log}) +} + +// StoreLogs implements the LogStore interface. +func (i *InmemStore) StoreLogs(logs []*Log) error { + i.l.Lock() + defer i.l.Unlock() + for _, l := range logs { + i.logs[l.Index] = l + if i.lowIndex == 0 { + i.lowIndex = l.Index + } + if l.Index > i.highIndex { + i.highIndex = l.Index + } + } + return nil +} + +// DeleteRange implements the LogStore interface. 
+func (i *InmemStore) DeleteRange(min, max uint64) error { + i.l.Lock() + defer i.l.Unlock() + for j := min; j <= max; j++ { + delete(i.logs, j) + } + if min <= i.lowIndex { + i.lowIndex = max + 1 + } + if max >= i.highIndex { + i.highIndex = min - 1 + } + if i.lowIndex > i.highIndex { + i.lowIndex = 0 + i.highIndex = 0 + } + return nil +} + +// Set implements the StableStore interface. +func (i *InmemStore) Set(key []byte, val []byte) error { + i.l.Lock() + defer i.l.Unlock() + i.kv[string(key)] = val + return nil +} + +// Get implements the StableStore interface. +func (i *InmemStore) Get(key []byte) ([]byte, error) { + i.l.RLock() + defer i.l.RUnlock() + return i.kv[string(key)], nil +} + +// SetUint64 implements the StableStore interface. +func (i *InmemStore) SetUint64(key []byte, val uint64) error { + i.l.Lock() + defer i.l.Unlock() + i.kvInt[string(key)] = val + return nil +} + +// GetUint64 implements the StableStore interface. +func (i *InmemStore) GetUint64(key []byte) (uint64, error) { + i.l.RLock() + defer i.l.RUnlock() + return i.kvInt[string(key)], nil +} diff --git a/vendor/github.com/hashicorp/raft/inmem_transport.go b/vendor/github.com/hashicorp/raft/inmem_transport.go new file mode 100644 index 00000000000..ce37f63aa84 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/inmem_transport.go @@ -0,0 +1,322 @@ +package raft + +import ( + "fmt" + "io" + "sync" + "time" +) + +// NewInmemAddr returns a new in-memory addr with +// a randomly generate UUID as the ID. +func NewInmemAddr() ServerAddress { + return ServerAddress(generateUUID()) +} + +// inmemPipeline is used to pipeline requests for the in-mem transport. 
+type inmemPipeline struct { + trans *InmemTransport + peer *InmemTransport + peerAddr ServerAddress + + doneCh chan AppendFuture + inprogressCh chan *inmemPipelineInflight + + shutdown bool + shutdownCh chan struct{} + shutdownLock sync.Mutex +} + +type inmemPipelineInflight struct { + future *appendFuture + respCh <-chan RPCResponse +} + +// InmemTransport Implements the Transport interface, to allow Raft to be +// tested in-memory without going over a network. +type InmemTransport struct { + sync.RWMutex + consumerCh chan RPC + localAddr ServerAddress + peers map[ServerAddress]*InmemTransport + pipelines []*inmemPipeline + timeout time.Duration +} + +// NewInmemTransport is used to initialize a new transport +// and generates a random local address if none is specified +func NewInmemTransport(addr ServerAddress) (ServerAddress, *InmemTransport) { + if string(addr) == "" { + addr = NewInmemAddr() + } + trans := &InmemTransport{ + consumerCh: make(chan RPC, 16), + localAddr: addr, + peers: make(map[ServerAddress]*InmemTransport), + timeout: 50 * time.Millisecond, + } + return addr, trans +} + +// SetHeartbeatHandler is used to set optional fast-path for +// heartbeats, not supported for this transport. +func (i *InmemTransport) SetHeartbeatHandler(cb func(RPC)) { +} + +// Consumer implements the Transport interface. +func (i *InmemTransport) Consumer() <-chan RPC { + return i.consumerCh +} + +// LocalAddr implements the Transport interface. +func (i *InmemTransport) LocalAddr() ServerAddress { + return i.localAddr +} + +// AppendEntriesPipeline returns an interface that can be used to pipeline +// AppendEntries requests. 
+func (i *InmemTransport) AppendEntriesPipeline(id ServerID, target ServerAddress) (AppendPipeline, error) { + i.RLock() + peer, ok := i.peers[target] + i.RUnlock() + if !ok { + return nil, fmt.Errorf("failed to connect to peer: %v", target) + } + pipeline := newInmemPipeline(i, peer, target) + i.Lock() + i.pipelines = append(i.pipelines, pipeline) + i.Unlock() + return pipeline, nil +} + +// AppendEntries implements the Transport interface. +func (i *InmemTransport) AppendEntries(id ServerID, target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error { + rpcResp, err := i.makeRPC(target, args, nil, i.timeout) + if err != nil { + return err + } + + // Copy the result back + out := rpcResp.Response.(*AppendEntriesResponse) + *resp = *out + return nil +} + +// RequestVote implements the Transport interface. +func (i *InmemTransport) RequestVote(id ServerID, target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error { + rpcResp, err := i.makeRPC(target, args, nil, i.timeout) + if err != nil { + return err + } + + // Copy the result back + out := rpcResp.Response.(*RequestVoteResponse) + *resp = *out + return nil +} + +// InstallSnapshot implements the Transport interface. 
+func (i *InmemTransport) InstallSnapshot(id ServerID, target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error { + rpcResp, err := i.makeRPC(target, args, data, 10*i.timeout) + if err != nil { + return err + } + + // Copy the result back + out := rpcResp.Response.(*InstallSnapshotResponse) + *resp = *out + return nil +} + +func (i *InmemTransport) makeRPC(target ServerAddress, args interface{}, r io.Reader, timeout time.Duration) (rpcResp RPCResponse, err error) { + i.RLock() + peer, ok := i.peers[target] + i.RUnlock() + + if !ok { + err = fmt.Errorf("failed to connect to peer: %v", target) + return + } + + // Send the RPC over + respCh := make(chan RPCResponse) + peer.consumerCh <- RPC{ + Command: args, + Reader: r, + RespChan: respCh, + } + + // Wait for a response + select { + case rpcResp = <-respCh: + if rpcResp.Error != nil { + err = rpcResp.Error + } + case <-time.After(timeout): + err = fmt.Errorf("command timed out") + } + return +} + +// EncodePeer implements the Transport interface. +func (i *InmemTransport) EncodePeer(id ServerID, p ServerAddress) []byte { + return []byte(p) +} + +// DecodePeer implements the Transport interface. +func (i *InmemTransport) DecodePeer(buf []byte) ServerAddress { + return ServerAddress(buf) +} + +// Connect is used to connect this transport to another transport for +// a given peer name. This allows for local routing. +func (i *InmemTransport) Connect(peer ServerAddress, t Transport) { + trans := t.(*InmemTransport) + i.Lock() + defer i.Unlock() + i.peers[peer] = trans +} + +// Disconnect is used to remove the ability to route to a given peer. 
+func (i *InmemTransport) Disconnect(peer ServerAddress) { + i.Lock() + defer i.Unlock() + delete(i.peers, peer) + + // Disconnect any pipelines + n := len(i.pipelines) + for idx := 0; idx < n; idx++ { + if i.pipelines[idx].peerAddr == peer { + i.pipelines[idx].Close() + i.pipelines[idx], i.pipelines[n-1] = i.pipelines[n-1], nil + idx-- + n-- + } + } + i.pipelines = i.pipelines[:n] +} + +// DisconnectAll is used to remove all routes to peers. +func (i *InmemTransport) DisconnectAll() { + i.Lock() + defer i.Unlock() + i.peers = make(map[ServerAddress]*InmemTransport) + + // Handle pipelines + for _, pipeline := range i.pipelines { + pipeline.Close() + } + i.pipelines = nil +} + +// Close is used to permanently disable the transport +func (i *InmemTransport) Close() error { + i.DisconnectAll() + return nil +} + +func newInmemPipeline(trans *InmemTransport, peer *InmemTransport, addr ServerAddress) *inmemPipeline { + i := &inmemPipeline{ + trans: trans, + peer: peer, + peerAddr: addr, + doneCh: make(chan AppendFuture, 16), + inprogressCh: make(chan *inmemPipelineInflight, 16), + shutdownCh: make(chan struct{}), + } + go i.decodeResponses() + return i +} + +func (i *inmemPipeline) decodeResponses() { + timeout := i.trans.timeout + for { + select { + case inp := <-i.inprogressCh: + var timeoutCh <-chan time.Time + if timeout > 0 { + timeoutCh = time.After(timeout) + } + + select { + case rpcResp := <-inp.respCh: + // Copy the result back + *inp.future.resp = *rpcResp.Response.(*AppendEntriesResponse) + inp.future.respond(rpcResp.Error) + + select { + case i.doneCh <- inp.future: + case <-i.shutdownCh: + return + } + + case <-timeoutCh: + inp.future.respond(fmt.Errorf("command timed out")) + select { + case i.doneCh <- inp.future: + case <-i.shutdownCh: + return + } + + case <-i.shutdownCh: + return + } + case <-i.shutdownCh: + return + } + } +} + +func (i *inmemPipeline) AppendEntries(args *AppendEntriesRequest, resp *AppendEntriesResponse) (AppendFuture, error) { + // 
Create a new future + future := &appendFuture{ + start: time.Now(), + args: args, + resp: resp, + } + future.init() + + // Handle a timeout + var timeout <-chan time.Time + if i.trans.timeout > 0 { + timeout = time.After(i.trans.timeout) + } + + // Send the RPC over + respCh := make(chan RPCResponse, 1) + rpc := RPC{ + Command: args, + RespChan: respCh, + } + select { + case i.peer.consumerCh <- rpc: + case <-timeout: + return nil, fmt.Errorf("command enqueue timeout") + case <-i.shutdownCh: + return nil, ErrPipelineShutdown + } + + // Send to be decoded + select { + case i.inprogressCh <- &inmemPipelineInflight{future, respCh}: + return future, nil + case <-i.shutdownCh: + return nil, ErrPipelineShutdown + } +} + +func (i *inmemPipeline) Consumer() <-chan AppendFuture { + return i.doneCh +} + +func (i *inmemPipeline) Close() error { + i.shutdownLock.Lock() + defer i.shutdownLock.Unlock() + if i.shutdown { + return nil + } + + i.shutdown = true + close(i.shutdownCh) + return nil +} diff --git a/vendor/github.com/hashicorp/raft/log.go b/vendor/github.com/hashicorp/raft/log.go new file mode 100644 index 00000000000..4ade38ecc12 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/log.go @@ -0,0 +1,72 @@ +package raft + +// LogType describes various types of log entries. +type LogType uint8 + +const ( + // LogCommand is applied to a user FSM. + LogCommand LogType = iota + + // LogNoop is used to assert leadership. + LogNoop + + // LogAddPeer is used to add a new peer. This should only be used with + // older protocol versions designed to be compatible with unversioned + // Raft servers. See comments in config.go for details. + LogAddPeerDeprecated + + // LogRemovePeer is used to remove an existing peer. This should only be + // used with older protocol versions designed to be compatible with + // unversioned Raft servers. See comments in config.go for details. 
+ LogRemovePeerDeprecated + + // LogBarrier is used to ensure all preceding operations have been + // applied to the FSM. It is similar to LogNoop, but instead of returning + // once committed, it only returns once the FSM manager acks it. Otherwise + // it is possible there are operations committed but not yet applied to + // the FSM. + LogBarrier + + // LogConfiguration establishes a membership change configuration. It is + // created when a server is added, removed, promoted, etc. Only used + // when protocol version 1 or greater is in use. + LogConfiguration +) + +// Log entries are replicated to all members of the Raft cluster +// and form the heart of the replicated state machine. +type Log struct { + // Index holds the index of the log entry. + Index uint64 + + // Term holds the election term of the log entry. + Term uint64 + + // Type holds the type of the log entry. + Type LogType + + // Data holds the log entry's type-specific data. + Data []byte +} + +// LogStore is used to provide an interface for storing +// and retrieving logs in a durable fashion. +type LogStore interface { + // FirstIndex returns the first index written. 0 for no entries. + FirstIndex() (uint64, error) + + // LastIndex returns the last index written. 0 for no entries. + LastIndex() (uint64, error) + + // GetLog gets a log entry at a given index. + GetLog(index uint64, log *Log) error + + // StoreLog stores a log entry. + StoreLog(log *Log) error + + // StoreLogs stores multiple log entries. + StoreLogs(logs []*Log) error + + // DeleteRange deletes a range of log entries. The range is inclusive. 
+ DeleteRange(min, max uint64) error +} diff --git a/vendor/github.com/hashicorp/raft/log_cache.go b/vendor/github.com/hashicorp/raft/log_cache.go new file mode 100644 index 00000000000..952e98c2282 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/log_cache.go @@ -0,0 +1,79 @@ +package raft + +import ( + "fmt" + "sync" +) + +// LogCache wraps any LogStore implementation to provide an +// in-memory ring buffer. This is used to cache access to +// the recently written entries. For implementations that do not +// cache themselves, this can provide a substantial boost by +// avoiding disk I/O on recent entries. +type LogCache struct { + store LogStore + + cache []*Log + l sync.RWMutex +} + +// NewLogCache is used to create a new LogCache with the +// given capacity and backend store. +func NewLogCache(capacity int, store LogStore) (*LogCache, error) { + if capacity <= 0 { + return nil, fmt.Errorf("capacity must be positive") + } + c := &LogCache{ + store: store, + cache: make([]*Log, capacity), + } + return c, nil +} + +func (c *LogCache) GetLog(idx uint64, log *Log) error { + // Check the buffer for an entry + c.l.RLock() + cached := c.cache[idx%uint64(len(c.cache))] + c.l.RUnlock() + + // Check if entry is valid + if cached != nil && cached.Index == idx { + *log = *cached + return nil + } + + // Forward request on cache miss + return c.store.GetLog(idx, log) +} + +func (c *LogCache) StoreLog(log *Log) error { + return c.StoreLogs([]*Log{log}) +} + +func (c *LogCache) StoreLogs(logs []*Log) error { + // Insert the logs into the ring buffer + c.l.Lock() + for _, l := range logs { + c.cache[l.Index%uint64(len(c.cache))] = l + } + c.l.Unlock() + + return c.store.StoreLogs(logs) +} + +func (c *LogCache) FirstIndex() (uint64, error) { + return c.store.FirstIndex() +} + +func (c *LogCache) LastIndex() (uint64, error) { + return c.store.LastIndex() +} + +func (c *LogCache) DeleteRange(min, max uint64) error { + // Invalidate the cache on deletes + c.l.Lock() + c.cache = 
make([]*Log, len(c.cache)) + c.l.Unlock() + + return c.store.DeleteRange(min, max) +} diff --git a/vendor/github.com/hashicorp/raft/net_transport.go b/vendor/github.com/hashicorp/raft/net_transport.go new file mode 100644 index 00000000000..9555a0eaeb2 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/net_transport.go @@ -0,0 +1,676 @@ +package raft + +import ( + "bufio" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "time" + + "github.com/hashicorp/go-msgpack/codec" +) + +const ( + rpcAppendEntries uint8 = iota + rpcRequestVote + rpcInstallSnapshot + + // DefaultTimeoutScale is the default TimeoutScale in a NetworkTransport. + DefaultTimeoutScale = 256 * 1024 // 256KB + + // rpcMaxPipeline controls the maximum number of outstanding + // AppendEntries RPC calls. + rpcMaxPipeline = 128 +) + +var ( + // ErrTransportShutdown is returned when operations on a transport are + // invoked after it's been terminated. + ErrTransportShutdown = errors.New("transport shutdown") + + // ErrPipelineShutdown is returned when the pipeline is closed. + ErrPipelineShutdown = errors.New("append pipeline closed") +) + +/* + +NetworkTransport provides a network based transport that can be +used to communicate with Raft on remote machines. It requires +an underlying stream layer to provide a stream abstraction, which can +be simple TCP, TLS, etc. + +This transport is very simple and lightweight. Each RPC request is +framed by sending a byte that indicates the message type, followed +by the MsgPack encoded request. + +The response is an error string followed by the response object, +both are encoded using MsgPack. + +InstallSnapshot is special, in that after the RPC request we stream +the entire state. That socket is not re-used as the connection state +is not known if there is an error. 
+ +*/ +type NetworkTransport struct { + connPool map[ServerAddress][]*netConn + connPoolLock sync.Mutex + + consumeCh chan RPC + + heartbeatFn func(RPC) + heartbeatFnLock sync.Mutex + + logger *log.Logger + + maxPool int + + serverAddressProvider ServerAddressProvider + + shutdown bool + shutdownCh chan struct{} + shutdownLock sync.Mutex + + stream StreamLayer + + timeout time.Duration + TimeoutScale int +} + +// NetworkTransportConfig encapsulates configuration for the network transport layer. +type NetworkTransportConfig struct { + // ServerAddressProvider is used to override the target address when establishing a connection to invoke an RPC + ServerAddressProvider ServerAddressProvider + + Logger *log.Logger + + // Dialer + Stream StreamLayer + + // MaxPool controls how many connections we will pool + MaxPool int + + // Timeout is used to apply I/O deadlines. For InstallSnapshot, we multiply + // the timeout by (SnapshotSize / TimeoutScale). + Timeout time.Duration +} + +type ServerAddressProvider interface { + ServerAddr(id ServerID) (ServerAddress, error) +} + +// StreamLayer is used with the NetworkTransport to provide +// the low level stream abstraction. 
+type StreamLayer interface { + net.Listener + + // Dial is used to create a new outgoing connection + Dial(address ServerAddress, timeout time.Duration) (net.Conn, error) +} + +type netConn struct { + target ServerAddress + conn net.Conn + r *bufio.Reader + w *bufio.Writer + dec *codec.Decoder + enc *codec.Encoder +} + +func (n *netConn) Release() error { + return n.conn.Close() +} + +type netPipeline struct { + conn *netConn + trans *NetworkTransport + + doneCh chan AppendFuture + inprogressCh chan *appendFuture + + shutdown bool + shutdownCh chan struct{} + shutdownLock sync.Mutex +} + +// NewNetworkTransportWithConfig creates a new network transport with the given config struct +func NewNetworkTransportWithConfig( + config *NetworkTransportConfig, +) *NetworkTransport { + if config.Logger == nil { + config.Logger = log.New(os.Stderr, "", log.LstdFlags) + } + trans := &NetworkTransport{ + connPool: make(map[ServerAddress][]*netConn), + consumeCh: make(chan RPC), + logger: config.Logger, + maxPool: config.MaxPool, + shutdownCh: make(chan struct{}), + stream: config.Stream, + timeout: config.Timeout, + TimeoutScale: DefaultTimeoutScale, + serverAddressProvider: config.ServerAddressProvider, + } + go trans.listen() + return trans +} + +// NewNetworkTransport creates a new network transport with the given dialer +// and listener. The maxPool controls how many connections we will pool. The +// timeout is used to apply I/O deadlines. For InstallSnapshot, we multiply +// the timeout by (SnapshotSize / TimeoutScale). 
+func NewNetworkTransport( + stream StreamLayer, + maxPool int, + timeout time.Duration, + logOutput io.Writer, +) *NetworkTransport { + if logOutput == nil { + logOutput = os.Stderr + } + logger := log.New(logOutput, "", log.LstdFlags) + config := &NetworkTransportConfig{Stream: stream, MaxPool: maxPool, Timeout: timeout, Logger: logger} + return NewNetworkTransportWithConfig(config) +} + +// NewNetworkTransportWithLogger creates a new network transport with the given logger, dialer +// and listener. The maxPool controls how many connections we will pool. The +// timeout is used to apply I/O deadlines. For InstallSnapshot, we multiply +// the timeout by (SnapshotSize / TimeoutScale). +func NewNetworkTransportWithLogger( + stream StreamLayer, + maxPool int, + timeout time.Duration, + logger *log.Logger, +) *NetworkTransport { + config := &NetworkTransportConfig{Stream: stream, MaxPool: maxPool, Timeout: timeout, Logger: logger} + return NewNetworkTransportWithConfig(config) +} + +// SetHeartbeatHandler is used to setup a heartbeat handler +// as a fast-pass. This is to avoid head-of-line blocking from +// disk IO. +func (n *NetworkTransport) SetHeartbeatHandler(cb func(rpc RPC)) { + n.heartbeatFnLock.Lock() + defer n.heartbeatFnLock.Unlock() + n.heartbeatFn = cb +} + +// Close is used to stop the network transport. +func (n *NetworkTransport) Close() error { + n.shutdownLock.Lock() + defer n.shutdownLock.Unlock() + + if !n.shutdown { + close(n.shutdownCh) + n.stream.Close() + n.shutdown = true + } + return nil +} + +// Consumer implements the Transport interface. +func (n *NetworkTransport) Consumer() <-chan RPC { + return n.consumeCh +} + +// LocalAddr implements the Transport interface. +func (n *NetworkTransport) LocalAddr() ServerAddress { + return ServerAddress(n.stream.Addr().String()) +} + +// IsShutdown is used to check if the transport is shutdown. 
+func (n *NetworkTransport) IsShutdown() bool { + select { + case <-n.shutdownCh: + return true + default: + return false + } +} + +// getExistingConn is used to grab a pooled connection. +func (n *NetworkTransport) getPooledConn(target ServerAddress) *netConn { + n.connPoolLock.Lock() + defer n.connPoolLock.Unlock() + + conns, ok := n.connPool[target] + if !ok || len(conns) == 0 { + return nil + } + + var conn *netConn + num := len(conns) + conn, conns[num-1] = conns[num-1], nil + n.connPool[target] = conns[:num-1] + return conn +} + +// getConnFromAddressProvider returns a connection from the server address provider if available, or defaults to a connection using the target server address +func (n *NetworkTransport) getConnFromAddressProvider(id ServerID, target ServerAddress) (*netConn, error) { + address := n.getProviderAddressOrFallback(id, target) + return n.getConn(address) +} + +func (n *NetworkTransport) getProviderAddressOrFallback(id ServerID, target ServerAddress) ServerAddress { + if n.serverAddressProvider != nil { + serverAddressOverride, err := n.serverAddressProvider.ServerAddr(id) + if err != nil { + n.logger.Printf("[WARN] Unable to get address for server id %v, using fallback address %v: %v", id, target, err) + } else { + return serverAddressOverride + } + } + return target +} + +// getConn is used to get a connection from the pool. 
+func (n *NetworkTransport) getConn(target ServerAddress) (*netConn, error) { + // Check for a pooled conn + if conn := n.getPooledConn(target); conn != nil { + return conn, nil + } + + // Dial a new connection + conn, err := n.stream.Dial(target, n.timeout) + if err != nil { + return nil, err + } + + // Wrap the conn + netConn := &netConn{ + target: target, + conn: conn, + r: bufio.NewReader(conn), + w: bufio.NewWriter(conn), + } + + // Setup encoder/decoders + netConn.dec = codec.NewDecoder(netConn.r, &codec.MsgpackHandle{}) + netConn.enc = codec.NewEncoder(netConn.w, &codec.MsgpackHandle{}) + + // Done + return netConn, nil +} + +// returnConn returns a connection back to the pool. +func (n *NetworkTransport) returnConn(conn *netConn) { + n.connPoolLock.Lock() + defer n.connPoolLock.Unlock() + + key := conn.target + conns, _ := n.connPool[key] + + if !n.IsShutdown() && len(conns) < n.maxPool { + n.connPool[key] = append(conns, conn) + } else { + conn.Release() + } +} + +// AppendEntriesPipeline returns an interface that can be used to pipeline +// AppendEntries requests. +func (n *NetworkTransport) AppendEntriesPipeline(id ServerID, target ServerAddress) (AppendPipeline, error) { + // Get a connection + conn, err := n.getConnFromAddressProvider(id, target) + if err != nil { + return nil, err + } + + // Create the pipeline + return newNetPipeline(n, conn), nil +} + +// AppendEntries implements the Transport interface. +func (n *NetworkTransport) AppendEntries(id ServerID, target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error { + return n.genericRPC(id, target, rpcAppendEntries, args, resp) +} + +// RequestVote implements the Transport interface. +func (n *NetworkTransport) RequestVote(id ServerID, target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error { + return n.genericRPC(id, target, rpcRequestVote, args, resp) +} + +// genericRPC handles a simple request/response RPC. 
+func (n *NetworkTransport) genericRPC(id ServerID, target ServerAddress, rpcType uint8, args interface{}, resp interface{}) error { + // Get a conn + conn, err := n.getConnFromAddressProvider(id, target) + if err != nil { + return err + } + + // Set a deadline + if n.timeout > 0 { + conn.conn.SetDeadline(time.Now().Add(n.timeout)) + } + + // Send the RPC + if err = sendRPC(conn, rpcType, args); err != nil { + return err + } + + // Decode the response + canReturn, err := decodeResponse(conn, resp) + if canReturn { + n.returnConn(conn) + } + return err +} + +// InstallSnapshot implements the Transport interface. +func (n *NetworkTransport) InstallSnapshot(id ServerID, target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error { + // Get a conn, always close for InstallSnapshot + conn, err := n.getConnFromAddressProvider(id, target) + if err != nil { + return err + } + defer conn.Release() + + // Set a deadline, scaled by request size + if n.timeout > 0 { + timeout := n.timeout * time.Duration(args.Size/int64(n.TimeoutScale)) + if timeout < n.timeout { + timeout = n.timeout + } + conn.conn.SetDeadline(time.Now().Add(timeout)) + } + + // Send the RPC + if err = sendRPC(conn, rpcInstallSnapshot, args); err != nil { + return err + } + + // Stream the state + if _, err = io.Copy(conn.w, data); err != nil { + return err + } + + // Flush + if err = conn.w.Flush(); err != nil { + return err + } + + // Decode the response, do not return conn + _, err = decodeResponse(conn, resp) + return err +} + +// EncodePeer implements the Transport interface. +func (n *NetworkTransport) EncodePeer(id ServerID, p ServerAddress) []byte { + address := n.getProviderAddressOrFallback(id, p) + return []byte(address) +} + +// DecodePeer implements the Transport interface. +func (n *NetworkTransport) DecodePeer(buf []byte) ServerAddress { + return ServerAddress(buf) +} + +// listen is used to handling incoming connections. 
+func (n *NetworkTransport) listen() { + for { + // Accept incoming connections + conn, err := n.stream.Accept() + if err != nil { + if n.IsShutdown() { + return + } + n.logger.Printf("[ERR] raft-net: Failed to accept connection: %v", err) + continue + } + n.logger.Printf("[DEBUG] raft-net: %v accepted connection from: %v", n.LocalAddr(), conn.RemoteAddr()) + + // Handle the connection in dedicated routine + go n.handleConn(conn) + } +} + +// handleConn is used to handle an inbound connection for its lifespan. +func (n *NetworkTransport) handleConn(conn net.Conn) { + defer conn.Close() + r := bufio.NewReader(conn) + w := bufio.NewWriter(conn) + dec := codec.NewDecoder(r, &codec.MsgpackHandle{}) + enc := codec.NewEncoder(w, &codec.MsgpackHandle{}) + + for { + if err := n.handleCommand(r, dec, enc); err != nil { + if err != io.EOF { + n.logger.Printf("[ERR] raft-net: Failed to decode incoming command: %v", err) + } + return + } + if err := w.Flush(); err != nil { + n.logger.Printf("[ERR] raft-net: Failed to flush response: %v", err) + return + } + } +} + +// handleCommand is used to decode and dispatch a single command. 
+func (n *NetworkTransport) handleCommand(r *bufio.Reader, dec *codec.Decoder, enc *codec.Encoder) error { + // Get the rpc type + rpcType, err := r.ReadByte() + if err != nil { + return err + } + + // Create the RPC object + respCh := make(chan RPCResponse, 1) + rpc := RPC{ + RespChan: respCh, + } + + // Decode the command + isHeartbeat := false + switch rpcType { + case rpcAppendEntries: + var req AppendEntriesRequest + if err := dec.Decode(&req); err != nil { + return err + } + rpc.Command = &req + + // Check if this is a heartbeat + if req.Term != 0 && req.Leader != nil && + req.PrevLogEntry == 0 && req.PrevLogTerm == 0 && + len(req.Entries) == 0 && req.LeaderCommitIndex == 0 { + isHeartbeat = true + } + + case rpcRequestVote: + var req RequestVoteRequest + if err := dec.Decode(&req); err != nil { + return err + } + rpc.Command = &req + + case rpcInstallSnapshot: + var req InstallSnapshotRequest + if err := dec.Decode(&req); err != nil { + return err + } + rpc.Command = &req + rpc.Reader = io.LimitReader(r, req.Size) + + default: + return fmt.Errorf("unknown rpc type %d", rpcType) + } + + // Check for heartbeat fast-path + if isHeartbeat { + n.heartbeatFnLock.Lock() + fn := n.heartbeatFn + n.heartbeatFnLock.Unlock() + if fn != nil { + fn(rpc) + goto RESP + } + } + + // Dispatch the RPC + select { + case n.consumeCh <- rpc: + case <-n.shutdownCh: + return ErrTransportShutdown + } + + // Wait for response +RESP: + select { + case resp := <-respCh: + // Send the error first + respErr := "" + if resp.Error != nil { + respErr = resp.Error.Error() + } + if err := enc.Encode(respErr); err != nil { + return err + } + + // Send the response + if err := enc.Encode(resp.Response); err != nil { + return err + } + case <-n.shutdownCh: + return ErrTransportShutdown + } + return nil +} + +// decodeResponse is used to decode an RPC response and reports whether +// the connection can be reused. 
+func decodeResponse(conn *netConn, resp interface{}) (bool, error) { + // Decode the error if any + var rpcError string + if err := conn.dec.Decode(&rpcError); err != nil { + conn.Release() + return false, err + } + + // Decode the response + if err := conn.dec.Decode(resp); err != nil { + conn.Release() + return false, err + } + + // Format an error if any + if rpcError != "" { + return true, fmt.Errorf(rpcError) + } + return true, nil +} + +// sendRPC is used to encode and send the RPC. +func sendRPC(conn *netConn, rpcType uint8, args interface{}) error { + // Write the request type + if err := conn.w.WriteByte(rpcType); err != nil { + conn.Release() + return err + } + + // Send the request + if err := conn.enc.Encode(args); err != nil { + conn.Release() + return err + } + + // Flush + if err := conn.w.Flush(); err != nil { + conn.Release() + return err + } + return nil +} + +// newNetPipeline is used to construct a netPipeline from a given +// transport and connection. +func newNetPipeline(trans *NetworkTransport, conn *netConn) *netPipeline { + n := &netPipeline{ + conn: conn, + trans: trans, + doneCh: make(chan AppendFuture, rpcMaxPipeline), + inprogressCh: make(chan *appendFuture, rpcMaxPipeline), + shutdownCh: make(chan struct{}), + } + go n.decodeResponses() + return n +} + +// decodeResponses is a long running routine that decodes the responses +// sent on the connection. +func (n *netPipeline) decodeResponses() { + timeout := n.trans.timeout + for { + select { + case future := <-n.inprogressCh: + if timeout > 0 { + n.conn.conn.SetReadDeadline(time.Now().Add(timeout)) + } + + _, err := decodeResponse(n.conn, future.resp) + future.respond(err) + select { + case n.doneCh <- future: + case <-n.shutdownCh: + return + } + case <-n.shutdownCh: + return + } + } +} + +// AppendEntries is used to pipeline a new append entries request. 
+func (n *netPipeline) AppendEntries(args *AppendEntriesRequest, resp *AppendEntriesResponse) (AppendFuture, error) { + // Create a new future + future := &appendFuture{ + start: time.Now(), + args: args, + resp: resp, + } + future.init() + + // Add a send timeout + if timeout := n.trans.timeout; timeout > 0 { + n.conn.conn.SetWriteDeadline(time.Now().Add(timeout)) + } + + // Send the RPC + if err := sendRPC(n.conn, rpcAppendEntries, future.args); err != nil { + return nil, err + } + + // Hand-off for decoding, this can also cause back-pressure + // to prevent too many inflight requests + select { + case n.inprogressCh <- future: + return future, nil + case <-n.shutdownCh: + return nil, ErrPipelineShutdown + } +} + +// Consumer returns a channel that can be used to consume complete futures. +func (n *netPipeline) Consumer() <-chan AppendFuture { + return n.doneCh +} + +// Closed is used to shutdown the pipeline connection. +func (n *netPipeline) Close() error { + n.shutdownLock.Lock() + defer n.shutdownLock.Unlock() + if n.shutdown { + return nil + } + + // Release the connection + n.conn.Release() + + n.shutdown = true + close(n.shutdownCh) + return nil +} diff --git a/vendor/github.com/hashicorp/raft/observer.go b/vendor/github.com/hashicorp/raft/observer.go new file mode 100644 index 00000000000..76c4d555dfa --- /dev/null +++ b/vendor/github.com/hashicorp/raft/observer.go @@ -0,0 +1,117 @@ +package raft + +import ( + "sync/atomic" +) + +// Observation is sent along the given channel to observers when an event occurs. +type Observation struct { + // Raft holds the Raft instance generating the observation. + Raft *Raft + // Data holds observation-specific data. Possible types are + // *RequestVoteRequest and RaftState. + Data interface{} +} + +// nextObserverId is used to provide a unique ID for each observer to aid in +// deregistration. +var nextObserverID uint64 + +// FilterFn is a function that can be registered in order to filter observations. 
+// The function reports whether the observation should be included - if +// it returns false, the observation will be filtered out. +type FilterFn func(o *Observation) bool + +// Observer describes what to do with a given observation. +type Observer struct { + // numObserved and numDropped are performance counters for this observer. + // 64 bit types must be 64 bit aligned to use with atomic operations on + // 32 bit platforms, so keep them at the top of the struct. + numObserved uint64 + numDropped uint64 + + // channel receives observations. + channel chan Observation + + // blocking, if true, will cause Raft to block when sending an observation + // to this observer. This should generally be set to false. + blocking bool + + // filter will be called to determine if an observation should be sent to + // the channel. + filter FilterFn + + // id is the ID of this observer in the Raft map. + id uint64 +} + +// NewObserver creates a new observer that can be registered +// to make observations on a Raft instance. Observations +// will be sent on the given channel if they satisfy the +// given filter. +// +// If blocking is true, the observer will block when it can't +// send on the channel, otherwise it may discard events. +func NewObserver(channel chan Observation, blocking bool, filter FilterFn) *Observer { + return &Observer{ + channel: channel, + blocking: blocking, + filter: filter, + id: atomic.AddUint64(&nextObserverID, 1), + } +} + +// GetNumObserved returns the number of observations. +func (or *Observer) GetNumObserved() uint64 { + return atomic.LoadUint64(&or.numObserved) +} + +// GetNumDropped returns the number of dropped observations due to blocking. +func (or *Observer) GetNumDropped() uint64 { + return atomic.LoadUint64(&or.numDropped) +} + +// RegisterObserver registers a new observer. 
+func (r *Raft) RegisterObserver(or *Observer) { + r.observersLock.Lock() + defer r.observersLock.Unlock() + r.observers[or.id] = or +} + +// DeregisterObserver deregisters an observer. +func (r *Raft) DeregisterObserver(or *Observer) { + r.observersLock.Lock() + defer r.observersLock.Unlock() + delete(r.observers, or.id) +} + +// observe sends an observation to every observer. +func (r *Raft) observe(o interface{}) { + // In general observers should not block. But in any case this isn't + // disastrous as we only hold a read lock, which merely prevents + // registration / deregistration of observers. + r.observersLock.RLock() + defer r.observersLock.RUnlock() + for _, or := range r.observers { + // It's wasteful to do this in the loop, but for the common case + // where there are no observers we won't create any objects. + ob := Observation{Raft: r, Data: o} + if or.filter != nil && !or.filter(&ob) { + continue + } + if or.channel == nil { + continue + } + if or.blocking { + or.channel <- ob + atomic.AddUint64(&or.numObserved, 1) + } else { + select { + case or.channel <- ob: + atomic.AddUint64(&or.numObserved, 1) + default: + atomic.AddUint64(&or.numDropped, 1) + } + } + } +} diff --git a/vendor/github.com/hashicorp/raft/peersjson.go b/vendor/github.com/hashicorp/raft/peersjson.go new file mode 100644 index 00000000000..38ca2a8b845 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/peersjson.go @@ -0,0 +1,98 @@ +package raft + +import ( + "bytes" + "encoding/json" + "io/ioutil" +) + +// ReadPeersJSON consumes a legacy peers.json file in the format of the old JSON +// peer store and creates a new-style configuration structure. This can be used +// to migrate this data or perform manual recovery when running protocol versions +// that can interoperate with older, unversioned Raft servers. This should not be +// used once server IDs are in use, because the old peers.json file didn't have +// support for these, nor non-voter suffrage types. 
+func ReadPeersJSON(path string) (Configuration, error) { + // Read in the file. + buf, err := ioutil.ReadFile(path) + if err != nil { + return Configuration{}, err + } + + // Parse it as JSON. + var peers []string + dec := json.NewDecoder(bytes.NewReader(buf)) + if err := dec.Decode(&peers); err != nil { + return Configuration{}, err + } + + // Map it into the new-style configuration structure. We can only specify + // voter roles here, and the ID has to be the same as the address. + var configuration Configuration + for _, peer := range peers { + server := Server{ + Suffrage: Voter, + ID: ServerID(peer), + Address: ServerAddress(peer), + } + configuration.Servers = append(configuration.Servers, server) + } + + // We should only ingest valid configurations. + if err := checkConfiguration(configuration); err != nil { + return Configuration{}, err + } + return configuration, nil +} + +// configEntry is used when decoding a new-style peers.json. +type configEntry struct { + // ID is the ID of the server (a UUID, usually). + ID ServerID `json:"id"` + + // Address is the host:port of the server. + Address ServerAddress `json:"address"` + + // NonVoter controls the suffrage. We choose this sense so people + // can leave this out and get a Voter by default. + NonVoter bool `json:"non_voter"` +} + +// ReadConfigJSON reads a new-style peers.json and returns a configuration +// structure. This can be used to perform manual recovery when running protocol +// versions that use server IDs. +func ReadConfigJSON(path string) (Configuration, error) { + // Read in the file. + buf, err := ioutil.ReadFile(path) + if err != nil { + return Configuration{}, err + } + + // Parse it as JSON. + var peers []configEntry + dec := json.NewDecoder(bytes.NewReader(buf)) + if err := dec.Decode(&peers); err != nil { + return Configuration{}, err + } + + // Map it into the new-style configuration structure. 
+ var configuration Configuration + for _, peer := range peers { + suffrage := Voter + if peer.NonVoter { + suffrage = Nonvoter + } + server := Server{ + Suffrage: suffrage, + ID: peer.ID, + Address: peer.Address, + } + configuration.Servers = append(configuration.Servers, server) + } + + // We should only ingest valid configurations. + if err := checkConfiguration(configuration); err != nil { + return Configuration{}, err + } + return configuration, nil +} diff --git a/vendor/github.com/hashicorp/raft/raft.go b/vendor/github.com/hashicorp/raft/raft.go new file mode 100644 index 00000000000..50ae6e916c6 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/raft.go @@ -0,0 +1,1459 @@ +package raft + +import ( + "bytes" + "container/list" + "fmt" + "io" + "io/ioutil" + "time" + + "github.com/armon/go-metrics" +) + +const ( + minCheckInterval = 10 * time.Millisecond +) + +var ( + keyCurrentTerm = []byte("CurrentTerm") + keyLastVoteTerm = []byte("LastVoteTerm") + keyLastVoteCand = []byte("LastVoteCand") +) + +// getRPCHeader returns an initialized RPCHeader struct for the given +// Raft instance. This structure is sent along with RPC requests and +// responses. +func (r *Raft) getRPCHeader() RPCHeader { + return RPCHeader{ + ProtocolVersion: r.conf.ProtocolVersion, + } +} + +// checkRPCHeader houses logic about whether this instance of Raft can process +// the given RPC message. +func (r *Raft) checkRPCHeader(rpc RPC) error { + // Get the header off the RPC message. + wh, ok := rpc.Command.(WithRPCHeader) + if !ok { + return fmt.Errorf("RPC does not have a header") + } + header := wh.GetRPCHeader() + + // First check is to just make sure the code can understand the + // protocol at all. + if header.ProtocolVersion < ProtocolVersionMin || + header.ProtocolVersion > ProtocolVersionMax { + return ErrUnsupportedProtocol + } + + // Second check is whether we should support this message, given the + // current protocol we are configured to run. 
This will drop support + // for protocol version 0 starting at protocol version 2, which is + // currently what we want, and in general support one version back. We + // may need to revisit this policy depending on how future protocol + // changes evolve. + if header.ProtocolVersion < r.conf.ProtocolVersion-1 { + return ErrUnsupportedProtocol + } + + return nil +} + +// getSnapshotVersion returns the snapshot version that should be used when +// creating snapshots, given the protocol version in use. +func getSnapshotVersion(protocolVersion ProtocolVersion) SnapshotVersion { + // Right now we only have two versions and they are backwards compatible + // so we don't need to look at the protocol version. + return 1 +} + +// commitTuple is used to send an index that was committed, +// with an optional associated future that should be invoked. +type commitTuple struct { + log *Log + future *logFuture +} + +// leaderState is state that is used while we are a leader. +type leaderState struct { + commitCh chan struct{} + commitment *commitment + inflight *list.List // list of logFuture in log index order + replState map[ServerID]*followerReplication + notify map[*verifyFuture]struct{} + stepDown chan struct{} +} + +// setLeader is used to modify the current leader of the cluster +func (r *Raft) setLeader(leader ServerAddress) { + r.leaderLock.Lock() + r.leader = leader + r.leaderLock.Unlock() +} + +// requestConfigChange is a helper for the above functions that make +// configuration change requests. 'req' describes the change. For timeout, +// see AddVoter. 
+func (r *Raft) requestConfigChange(req configurationChangeRequest, timeout time.Duration) IndexFuture { + var timer <-chan time.Time + if timeout > 0 { + timer = time.After(timeout) + } + future := &configurationChangeFuture{ + req: req, + } + future.init() + select { + case <-timer: + return errorFuture{ErrEnqueueTimeout} + case r.configurationChangeCh <- future: + return future + case <-r.shutdownCh: + return errorFuture{ErrRaftShutdown} + } +} + +// run is a long running goroutine that runs the Raft FSM. +func (r *Raft) run() { + for { + // Check if we are doing a shutdown + select { + case <-r.shutdownCh: + // Clear the leader to prevent forwarding + r.setLeader("") + return + default: + } + + // Enter into a sub-FSM + switch r.getState() { + case Follower: + r.runFollower() + case Candidate: + r.runCandidate() + case Leader: + r.runLeader() + } + } +} + +// runFollower runs the FSM for a follower. +func (r *Raft) runFollower() { + didWarn := false + r.logger.Printf("[INFO] raft: %v entering Follower state (Leader: %q)", r, r.Leader()) + metrics.IncrCounter([]string{"raft", "state", "follower"}, 1) + heartbeatTimer := randomTimeout(r.conf.HeartbeatTimeout) + for { + select { + case rpc := <-r.rpcCh: + r.processRPC(rpc) + + case c := <-r.configurationChangeCh: + // Reject any operations since we are not the leader + c.respond(ErrNotLeader) + + case a := <-r.applyCh: + // Reject any operations since we are not the leader + a.respond(ErrNotLeader) + + case v := <-r.verifyCh: + // Reject any operations since we are not the leader + v.respond(ErrNotLeader) + + case r := <-r.userRestoreCh: + // Reject any restores since we are not the leader + r.respond(ErrNotLeader) + + case c := <-r.configurationsCh: + c.configurations = r.configurations.Clone() + c.respond(nil) + + case b := <-r.bootstrapCh: + b.respond(r.liveBootstrap(b.configuration)) + + case <-heartbeatTimer: + // Restart the heartbeat timer + heartbeatTimer = randomTimeout(r.conf.HeartbeatTimeout) + + // 
Check if we have had a successful contact + lastContact := r.LastContact() + if time.Now().Sub(lastContact) < r.conf.HeartbeatTimeout { + continue + } + + // Heartbeat failed! Transition to the candidate state + lastLeader := r.Leader() + r.setLeader("") + + if r.configurations.latestIndex == 0 { + if !didWarn { + r.logger.Printf("[WARN] raft: no known peers, aborting election") + didWarn = true + } + } else if r.configurations.latestIndex == r.configurations.committedIndex && + !hasVote(r.configurations.latest, r.localID) { + if !didWarn { + r.logger.Printf("[WARN] raft: not part of stable configuration, aborting election") + didWarn = true + } + } else { + r.logger.Printf(`[WARN] raft: Heartbeat timeout from %q reached, starting election`, lastLeader) + metrics.IncrCounter([]string{"raft", "transition", "heartbeat_timeout"}, 1) + r.setState(Candidate) + return + } + + case <-r.shutdownCh: + return + } + } +} + +// liveBootstrap attempts to seed an initial configuration for the cluster. See +// the Raft object's member BootstrapCluster for more details. This must only be +// called on the main thread, and only makes sense in the follower state. +func (r *Raft) liveBootstrap(configuration Configuration) error { + // Use the pre-init API to make the static updates. + err := BootstrapCluster(&r.conf, r.logs, r.stable, r.snapshots, + r.trans, configuration) + if err != nil { + return err + } + + // Make the configuration live. + var entry Log + if err := r.logs.GetLog(1, &entry); err != nil { + panic(err) + } + r.setCurrentTerm(1) + r.setLastLog(entry.Index, entry.Term) + r.processConfigurationLogEntry(&entry) + return nil +} + +// runCandidate runs the FSM for a candidate. 
+func (r *Raft) runCandidate() { + r.logger.Printf("[INFO] raft: %v entering Candidate state in term %v", + r, r.getCurrentTerm()+1) + metrics.IncrCounter([]string{"raft", "state", "candidate"}, 1) + + // Start vote for us, and set a timeout + voteCh := r.electSelf() + electionTimer := randomTimeout(r.conf.ElectionTimeout) + + // Tally the votes, need a simple majority + grantedVotes := 0 + votesNeeded := r.quorumSize() + r.logger.Printf("[DEBUG] raft: Votes needed: %d", votesNeeded) + + for r.getState() == Candidate { + select { + case rpc := <-r.rpcCh: + r.processRPC(rpc) + + case vote := <-voteCh: + // Check if the term is greater than ours, bail + if vote.Term > r.getCurrentTerm() { + r.logger.Printf("[DEBUG] raft: Newer term discovered, fallback to follower") + r.setState(Follower) + r.setCurrentTerm(vote.Term) + return + } + + // Check if the vote is granted + if vote.Granted { + grantedVotes++ + r.logger.Printf("[DEBUG] raft: Vote granted from %s in term %v. Tally: %d", + vote.voterID, vote.Term, grantedVotes) + } + + // Check if we've become the leader + if grantedVotes >= votesNeeded { + r.logger.Printf("[INFO] raft: Election won. Tally: %d", grantedVotes) + r.setState(Leader) + r.setLeader(r.localAddr) + return + } + + case c := <-r.configurationChangeCh: + // Reject any operations since we are not the leader + c.respond(ErrNotLeader) + + case a := <-r.applyCh: + // Reject any operations since we are not the leader + a.respond(ErrNotLeader) + + case v := <-r.verifyCh: + // Reject any operations since we are not the leader + v.respond(ErrNotLeader) + + case r := <-r.userRestoreCh: + // Reject any restores since we are not the leader + r.respond(ErrNotLeader) + + case c := <-r.configurationsCh: + c.configurations = r.configurations.Clone() + c.respond(nil) + + case b := <-r.bootstrapCh: + b.respond(ErrCantBootstrap) + + case <-electionTimer: + // Election failed! Restart the election. 
We simply return, + // which will kick us back into runCandidate + r.logger.Printf("[WARN] raft: Election timeout reached, restarting election") + return + + case <-r.shutdownCh: + return + } + } +} + +// runLeader runs the FSM for a leader. Do the setup here and drop into +// the leaderLoop for the hot loop. +func (r *Raft) runLeader() { + r.logger.Printf("[INFO] raft: %v entering Leader state", r) + metrics.IncrCounter([]string{"raft", "state", "leader"}, 1) + + // Notify that we are the leader + asyncNotifyBool(r.leaderCh, true) + + // Push to the notify channel if given + if notify := r.conf.NotifyCh; notify != nil { + select { + case notify <- true: + case <-r.shutdownCh: + } + } + + // Setup leader state + r.leaderState.commitCh = make(chan struct{}, 1) + r.leaderState.commitment = newCommitment(r.leaderState.commitCh, + r.configurations.latest, + r.getLastIndex()+1 /* first index that may be committed in this term */) + r.leaderState.inflight = list.New() + r.leaderState.replState = make(map[ServerID]*followerReplication) + r.leaderState.notify = make(map[*verifyFuture]struct{}) + r.leaderState.stepDown = make(chan struct{}, 1) + + // Cleanup state on step down + defer func() { + // Since we were the leader previously, we update our + // last contact time when we step down, so that we are not + // reporting a last contact time from before we were the + // leader. Otherwise, to a client it would seem our data + // is extremely stale. 
+ r.setLastContact() + + // Stop replication + for _, p := range r.leaderState.replState { + close(p.stopCh) + } + + // Respond to all inflight operations + for e := r.leaderState.inflight.Front(); e != nil; e = e.Next() { + e.Value.(*logFuture).respond(ErrLeadershipLost) + } + + // Respond to any pending verify requests + for future := range r.leaderState.notify { + future.respond(ErrLeadershipLost) + } + + // Clear all the state + r.leaderState.commitCh = nil + r.leaderState.commitment = nil + r.leaderState.inflight = nil + r.leaderState.replState = nil + r.leaderState.notify = nil + r.leaderState.stepDown = nil + + // If we are stepping down for some reason, no known leader. + // We may have stepped down due to an RPC call, which would + // provide the leader, so we cannot always blank this out. + r.leaderLock.Lock() + if r.leader == r.localAddr { + r.leader = "" + } + r.leaderLock.Unlock() + + // Notify that we are not the leader + asyncNotifyBool(r.leaderCh, false) + + // Push to the notify channel if given + if notify := r.conf.NotifyCh; notify != nil { + select { + case notify <- false: + case <-r.shutdownCh: + // On shutdown, make a best effort but do not block + select { + case notify <- false: + default: + } + } + } + }() + + // Start a replication routine for each peer + r.startStopReplication() + + // Dispatch a no-op log entry first. This gets this leader up to the latest + // possible commit index, even in the absence of client commands. This used + // to append a configuration entry instead of a noop. However, that permits + // an unbounded number of uncommitted configurations in the log. We now + // maintain that there exists at most one uncommitted configuration entry in + // any log, so we have to do proper no-ops here. 
+ noop := &logFuture{ + log: Log{ + Type: LogNoop, + }, + } + r.dispatchLogs([]*logFuture{noop}) + + // Sit in the leader loop until we step down + r.leaderLoop() +} + +// startStopReplication will set up state and start asynchronous replication to +// new peers, and stop replication to removed peers. Before removing a peer, +// it'll instruct the replication routines to try to replicate to the current +// index. This must only be called from the main thread. +func (r *Raft) startStopReplication() { + inConfig := make(map[ServerID]bool, len(r.configurations.latest.Servers)) + lastIdx := r.getLastIndex() + + // Start replication goroutines that need starting + for _, server := range r.configurations.latest.Servers { + if server.ID == r.localID { + continue + } + inConfig[server.ID] = true + if _, ok := r.leaderState.replState[server.ID]; !ok { + r.logger.Printf("[INFO] raft: Added peer %v, starting replication", server.ID) + s := &followerReplication{ + peer: server, + commitment: r.leaderState.commitment, + stopCh: make(chan uint64, 1), + triggerCh: make(chan struct{}, 1), + currentTerm: r.getCurrentTerm(), + nextIndex: lastIdx + 1, + lastContact: time.Now(), + notifyCh: make(chan struct{}, 1), + stepDown: r.leaderState.stepDown, + } + r.leaderState.replState[server.ID] = s + r.goFunc(func() { r.replicate(s) }) + asyncNotifyCh(s.triggerCh) + } + } + + // Stop replication goroutines that need stopping + for serverID, repl := range r.leaderState.replState { + if inConfig[serverID] { + continue + } + // Replicate up to lastIdx and stop + r.logger.Printf("[INFO] raft: Removed peer %v, stopping replication after %v", serverID, lastIdx) + repl.stopCh <- lastIdx + close(repl.stopCh) + delete(r.leaderState.replState, serverID) + } +} + +// configurationChangeChIfStable returns r.configurationChangeCh if it's safe +// to process requests from it, or nil otherwise. This must only be called +// from the main thread. 
+// +// Note that if the conditions here were to change outside of leaderLoop to take +// this from nil to non-nil, we would need leaderLoop to be kicked. +func (r *Raft) configurationChangeChIfStable() chan *configurationChangeFuture { + // Have to wait until: + // 1. The latest configuration is committed, and + // 2. This leader has committed some entry (the noop) in this term + // https://groups.google.com/forum/#!msg/raft-dev/t4xj6dJTP6E/d2D9LrWRza8J + if r.configurations.latestIndex == r.configurations.committedIndex && + r.getCommitIndex() >= r.leaderState.commitment.startIndex { + return r.configurationChangeCh + } + return nil +} + +// leaderLoop is the hot loop for a leader. It is invoked +// after all the various leader setup is done. +func (r *Raft) leaderLoop() { + // stepDown is used to track if there is an inflight log that + // would cause us to lose leadership (specifically a RemovePeer of + // ourselves). If this is the case, we must not allow any logs to + // be processed in parallel, otherwise we are basing commit on + // only a single peer (ourself) and replicating to an undefined set + // of peers. 
+ stepDown := false + + lease := time.After(r.conf.LeaderLeaseTimeout) + for r.getState() == Leader { + select { + case rpc := <-r.rpcCh: + r.processRPC(rpc) + + case <-r.leaderState.stepDown: + r.setState(Follower) + + case <-r.leaderState.commitCh: + // Process the newly committed entries + oldCommitIndex := r.getCommitIndex() + commitIndex := r.leaderState.commitment.getCommitIndex() + r.setCommitIndex(commitIndex) + + if r.configurations.latestIndex > oldCommitIndex && + r.configurations.latestIndex <= commitIndex { + r.configurations.committed = r.configurations.latest + r.configurations.committedIndex = r.configurations.latestIndex + if !hasVote(r.configurations.committed, r.localID) { + stepDown = true + } + } + + for { + e := r.leaderState.inflight.Front() + if e == nil { + break + } + commitLog := e.Value.(*logFuture) + idx := commitLog.log.Index + if idx > commitIndex { + break + } + // Measure the commit time + metrics.MeasureSince([]string{"raft", "commitTime"}, commitLog.dispatch) + r.processLogs(idx, commitLog) + r.leaderState.inflight.Remove(e) + } + + if stepDown { + if r.conf.ShutdownOnRemove { + r.logger.Printf("[INFO] raft: Removed ourself, shutting down") + r.Shutdown() + } else { + r.logger.Printf("[INFO] raft: Removed ourself, transitioning to follower") + r.setState(Follower) + } + } + + case v := <-r.verifyCh: + if v.quorumSize == 0 { + // Just dispatched, start the verification + r.verifyLeader(v) + + } else if v.votes < v.quorumSize { + // Early return, means there must be a new leader + r.logger.Printf("[WARN] raft: New leader elected, stepping down") + r.setState(Follower) + delete(r.leaderState.notify, v) + v.respond(ErrNotLeader) + + } else { + // Quorum of members agree, we are still leader + delete(r.leaderState.notify, v) + v.respond(nil) + } + + case future := <-r.userRestoreCh: + err := r.restoreUserSnapshot(future.meta, future.reader) + future.respond(err) + + case c := <-r.configurationsCh: + c.configurations = 
r.configurations.Clone() + c.respond(nil) + + case future := <-r.configurationChangeChIfStable(): + r.appendConfigurationEntry(future) + + case b := <-r.bootstrapCh: + b.respond(ErrCantBootstrap) + + case newLog := <-r.applyCh: + // Group commit, gather all the ready commits + ready := []*logFuture{newLog} + for i := 0; i < r.conf.MaxAppendEntries; i++ { + select { + case newLog := <-r.applyCh: + ready = append(ready, newLog) + default: + break + } + } + + // Dispatch the logs + if stepDown { + // we're in the process of stepping down as leader, don't process anything new + for i := range ready { + ready[i].respond(ErrNotLeader) + } + } else { + r.dispatchLogs(ready) + } + + case <-lease: + // Check if we've exceeded the lease, potentially stepping down + maxDiff := r.checkLeaderLease() + + // Next check interval should adjust for the last node we've + // contacted, without going negative + checkInterval := r.conf.LeaderLeaseTimeout - maxDiff + if checkInterval < minCheckInterval { + checkInterval = minCheckInterval + } + + // Renew the lease timer + lease = time.After(checkInterval) + + case <-r.shutdownCh: + return + } + } +} + +// verifyLeader must be called from the main thread for safety. +// Causes the followers to attempt an immediate heartbeat. +func (r *Raft) verifyLeader(v *verifyFuture) { + // Current leader always votes for self + v.votes = 1 + + // Set the quorum size, hot-path for single node + v.quorumSize = r.quorumSize() + if v.quorumSize == 1 { + v.respond(nil) + return + } + + // Track this request + v.notifyCh = r.verifyCh + r.leaderState.notify[v] = struct{}{} + + // Trigger immediate heartbeats + for _, repl := range r.leaderState.replState { + repl.notifyLock.Lock() + repl.notify = append(repl.notify, v) + repl.notifyLock.Unlock() + asyncNotifyCh(repl.notifyCh) + } +} + +// checkLeaderLease is used to check if we can contact a quorum of nodes +// within the last leader lease interval. 
If not, we need to step down, +// as we may have lost connectivity. Returns the maximum duration without +// contact. This must only be called from the main thread. +func (r *Raft) checkLeaderLease() time.Duration { + // Track contacted nodes, we can always contact ourself + contacted := 1 + + // Check each follower + var maxDiff time.Duration + now := time.Now() + for peer, f := range r.leaderState.replState { + diff := now.Sub(f.LastContact()) + if diff <= r.conf.LeaderLeaseTimeout { + contacted++ + if diff > maxDiff { + maxDiff = diff + } + } else { + // Log at least once at high value, then debug. Otherwise it gets very verbose. + if diff <= 3*r.conf.LeaderLeaseTimeout { + r.logger.Printf("[WARN] raft: Failed to contact %v in %v", peer, diff) + } else { + r.logger.Printf("[DEBUG] raft: Failed to contact %v in %v", peer, diff) + } + } + metrics.AddSample([]string{"raft", "leader", "lastContact"}, float32(diff/time.Millisecond)) + } + + // Verify we can contact a quorum + quorum := r.quorumSize() + if contacted < quorum { + r.logger.Printf("[WARN] raft: Failed to contact quorum of nodes, stepping down") + r.setState(Follower) + metrics.IncrCounter([]string{"raft", "transition", "leader_lease_timeout"}, 1) + } + return maxDiff +} + +// quorumSize is used to return the quorum size. This must only be called on +// the main thread. +// TODO: revisit usage +func (r *Raft) quorumSize() int { + voters := 0 + for _, server := range r.configurations.latest.Servers { + if server.Suffrage == Voter { + voters++ + } + } + return voters/2 + 1 +} + +// restoreUserSnapshot is used to manually consume an external snapshot, such +// as if restoring from a backup. We will use the current Raft configuration, +// not the one from the snapshot, so that we can restore into a new cluster. 
We +// will also use the higher of the index of the snapshot, or the current index, +// and then add 1 to that, so we force a new state with a hole in the Raft log, +// so that the snapshot will be sent to followers and used for any new joiners. +// This can only be run on the leader, and returns a future that can be used to +// block until complete. +func (r *Raft) restoreUserSnapshot(meta *SnapshotMeta, reader io.Reader) error { + defer metrics.MeasureSince([]string{"raft", "restoreUserSnapshot"}, time.Now()) + + // Sanity check the version. + version := meta.Version + if version < SnapshotVersionMin || version > SnapshotVersionMax { + return fmt.Errorf("unsupported snapshot version %d", version) + } + + // We don't support snapshots while there's a config change + // outstanding since the snapshot doesn't have a means to + // represent this state. + committedIndex := r.configurations.committedIndex + latestIndex := r.configurations.latestIndex + if committedIndex != latestIndex { + return fmt.Errorf("cannot restore snapshot now, wait until the configuration entry at %v has been applied (have applied %v)", + latestIndex, committedIndex) + } + + // Cancel any inflight requests. + for { + e := r.leaderState.inflight.Front() + if e == nil { + break + } + e.Value.(*logFuture).respond(ErrAbortedByRestore) + r.leaderState.inflight.Remove(e) + } + + // We will overwrite the snapshot metadata with the current term, + // an index that's greater than the current index, or the last + // index in the snapshot. It's important that we leave a hole in + // the index so we know there's nothing in the Raft log there and + // replication will fault and send the snapshot. + term := r.getCurrentTerm() + lastIndex := r.getLastIndex() + if meta.Index > lastIndex { + lastIndex = meta.Index + } + lastIndex++ + + // Dump the snapshot. Note that we use the latest configuration, + // not the one that came with the snapshot. 
+ sink, err := r.snapshots.Create(version, lastIndex, term, + r.configurations.latest, r.configurations.latestIndex, r.trans) + if err != nil { + return fmt.Errorf("failed to create snapshot: %v", err) + } + n, err := io.Copy(sink, reader) + if err != nil { + sink.Cancel() + return fmt.Errorf("failed to write snapshot: %v", err) + } + if n != meta.Size { + sink.Cancel() + return fmt.Errorf("failed to write snapshot, size didn't match (%d != %d)", n, meta.Size) + } + if err := sink.Close(); err != nil { + return fmt.Errorf("failed to close snapshot: %v", err) + } + r.logger.Printf("[INFO] raft: Copied %d bytes to local snapshot", n) + + // Restore the snapshot into the FSM. If this fails we are in a + // bad state so we panic to take ourselves out. + fsm := &restoreFuture{ID: sink.ID()} + fsm.init() + select { + case r.fsmMutateCh <- fsm: + case <-r.shutdownCh: + return ErrRaftShutdown + } + if err := fsm.Error(); err != nil { + panic(fmt.Errorf("failed to restore snapshot: %v", err)) + } + + // We set the last log so it looks like we've stored the empty + // index we burned. The last applied is set because we made the + // FSM take the snapshot state, and we store the last snapshot + // in the stable store since we created a snapshot as part of + // this process. + r.setLastLog(lastIndex, term) + r.setLastApplied(lastIndex) + r.setLastSnapshot(lastIndex, term) + + r.logger.Printf("[INFO] raft: Restored user snapshot (index %d)", lastIndex) + return nil +} + +// appendConfigurationEntry changes the configuration and adds a new +// configuration entry to the log. This must only be called from the +// main thread. 
+func (r *Raft) appendConfigurationEntry(future *configurationChangeFuture) { + configuration, err := nextConfiguration(r.configurations.latest, r.configurations.latestIndex, future.req) + if err != nil { + future.respond(err) + return + } + + r.logger.Printf("[INFO] raft: Updating configuration with %s (%v, %v) to %+v", + future.req.command, future.req.serverID, future.req.serverAddress, configuration.Servers) + + // In pre-ID compatibility mode we translate all configuration changes + // in to an old remove peer message, which can handle all supported + // cases for peer changes in the pre-ID world (adding and removing + // voters). Both add peer and remove peer log entries are handled + // similarly on old Raft servers, but remove peer does extra checks to + // see if a leader needs to step down. Since they both assert the full + // configuration, then we can safely call remove peer for everything. + if r.protocolVersion < 2 { + future.log = Log{ + Type: LogRemovePeerDeprecated, + Data: encodePeers(configuration, r.trans), + } + } else { + future.log = Log{ + Type: LogConfiguration, + Data: encodeConfiguration(configuration), + } + } + + r.dispatchLogs([]*logFuture{&future.logFuture}) + index := future.Index() + r.configurations.latest = configuration + r.configurations.latestIndex = index + r.leaderState.commitment.setConfiguration(configuration) + r.startStopReplication() +} + +// dispatchLog is called on the leader to push a log to disk, mark it +// as inflight and begin replication of it. 
+func (r *Raft) dispatchLogs(applyLogs []*logFuture) { + now := time.Now() + defer metrics.MeasureSince([]string{"raft", "leader", "dispatchLog"}, now) + + term := r.getCurrentTerm() + lastIndex := r.getLastIndex() + logs := make([]*Log, len(applyLogs)) + + for idx, applyLog := range applyLogs { + applyLog.dispatch = now + lastIndex++ + applyLog.log.Index = lastIndex + applyLog.log.Term = term + logs[idx] = &applyLog.log + r.leaderState.inflight.PushBack(applyLog) + } + + // Write the log entry locally + if err := r.logs.StoreLogs(logs); err != nil { + r.logger.Printf("[ERR] raft: Failed to commit logs: %v", err) + for _, applyLog := range applyLogs { + applyLog.respond(err) + } + r.setState(Follower) + return + } + r.leaderState.commitment.match(r.localID, lastIndex) + + // Update the last log since it's on disk now + r.setLastLog(lastIndex, term) + + // Notify the replicators of the new log + for _, f := range r.leaderState.replState { + asyncNotifyCh(f.triggerCh) + } +} + +// processLogs is used to apply all the committed entires that haven't been +// applied up to the given index limit. +// This can be called from both leaders and followers. +// Followers call this from AppendEntires, for n entires at a time, and always +// pass future=nil. +// Leaders call this once per inflight when entries are committed. They pass +// the future from inflights. 
+func (r *Raft) processLogs(index uint64, future *logFuture) { + // Reject logs we've applied already + lastApplied := r.getLastApplied() + if index <= lastApplied { + r.logger.Printf("[WARN] raft: Skipping application of old log: %d", index) + return + } + + // Apply all the preceding logs + for idx := r.getLastApplied() + 1; idx <= index; idx++ { + // Get the log, either from the future or from our log store + if future != nil && future.log.Index == idx { + r.processLog(&future.log, future) + + } else { + l := new(Log) + if err := r.logs.GetLog(idx, l); err != nil { + r.logger.Printf("[ERR] raft: Failed to get log at %d: %v", idx, err) + panic(err) + } + r.processLog(l, nil) + } + + // Update the lastApplied index and term + r.setLastApplied(idx) + } +} + +// processLog is invoked to process the application of a single committed log entry. +func (r *Raft) processLog(l *Log, future *logFuture) { + switch l.Type { + case LogBarrier: + // Barrier is handled by the FSM + fallthrough + + case LogCommand: + // Forward to the fsm handler + select { + case r.fsmMutateCh <- &commitTuple{l, future}: + case <-r.shutdownCh: + if future != nil { + future.respond(ErrRaftShutdown) + } + } + + // Return so that the future is only responded to + // by the FSM handler when the application is done + return + + case LogConfiguration: + case LogAddPeerDeprecated: + case LogRemovePeerDeprecated: + case LogNoop: + // Ignore the no-op + + default: + panic(fmt.Errorf("unrecognized log type: %#v", l)) + } + + // Invoke the future if given + if future != nil { + future.respond(nil) + } +} + +// processRPC is called to handle an incoming RPC request. This must only be +// called from the main thread. 
+func (r *Raft) processRPC(rpc RPC) { + if err := r.checkRPCHeader(rpc); err != nil { + rpc.Respond(nil, err) + return + } + + switch cmd := rpc.Command.(type) { + case *AppendEntriesRequest: + r.appendEntries(rpc, cmd) + case *RequestVoteRequest: + r.requestVote(rpc, cmd) + case *InstallSnapshotRequest: + r.installSnapshot(rpc, cmd) + default: + r.logger.Printf("[ERR] raft: Got unexpected command: %#v", rpc.Command) + rpc.Respond(nil, fmt.Errorf("unexpected command")) + } +} + +// processHeartbeat is a special handler used just for heartbeat requests +// so that they can be fast-pathed if a transport supports it. This must only +// be called from the main thread. +func (r *Raft) processHeartbeat(rpc RPC) { + defer metrics.MeasureSince([]string{"raft", "rpc", "processHeartbeat"}, time.Now()) + + // Check if we are shutdown, just ignore the RPC + select { + case <-r.shutdownCh: + return + default: + } + + // Ensure we are only handling a heartbeat + switch cmd := rpc.Command.(type) { + case *AppendEntriesRequest: + r.appendEntries(rpc, cmd) + default: + r.logger.Printf("[ERR] raft: Expected heartbeat, got command: %#v", rpc.Command) + rpc.Respond(nil, fmt.Errorf("unexpected command")) + } +} + +// appendEntries is invoked when we get an append entries RPC call. This must +// only be called from the main thread. 
+func (r *Raft) appendEntries(rpc RPC, a *AppendEntriesRequest) { + defer metrics.MeasureSince([]string{"raft", "rpc", "appendEntries"}, time.Now()) + // Setup a response + resp := &AppendEntriesResponse{ + RPCHeader: r.getRPCHeader(), + Term: r.getCurrentTerm(), + LastLog: r.getLastIndex(), + Success: false, + NoRetryBackoff: false, + } + var rpcErr error + defer func() { + rpc.Respond(resp, rpcErr) + }() + + // Ignore an older term + if a.Term < r.getCurrentTerm() { + return + } + + // Increase the term if we see a newer one, also transition to follower + // if we ever get an appendEntries call + if a.Term > r.getCurrentTerm() || r.getState() != Follower { + // Ensure transition to follower + r.setState(Follower) + r.setCurrentTerm(a.Term) + resp.Term = a.Term + } + + // Save the current leader + r.setLeader(ServerAddress(r.trans.DecodePeer(a.Leader))) + + // Verify the last log entry + if a.PrevLogEntry > 0 { + lastIdx, lastTerm := r.getLastEntry() + + var prevLogTerm uint64 + if a.PrevLogEntry == lastIdx { + prevLogTerm = lastTerm + + } else { + var prevLog Log + if err := r.logs.GetLog(a.PrevLogEntry, &prevLog); err != nil { + r.logger.Printf("[WARN] raft: Failed to get previous log: %d %v (last: %d)", + a.PrevLogEntry, err, lastIdx) + resp.NoRetryBackoff = true + return + } + prevLogTerm = prevLog.Term + } + + if a.PrevLogTerm != prevLogTerm { + r.logger.Printf("[WARN] raft: Previous log term mis-match: ours: %d remote: %d", + prevLogTerm, a.PrevLogTerm) + resp.NoRetryBackoff = true + return + } + } + + // Process any new entries + if len(a.Entries) > 0 { + start := time.Now() + + // Delete any conflicting entries, skip any duplicates + lastLogIdx, _ := r.getLastLog() + var newEntries []*Log + for i, entry := range a.Entries { + if entry.Index > lastLogIdx { + newEntries = a.Entries[i:] + break + } + var storeEntry Log + if err := r.logs.GetLog(entry.Index, &storeEntry); err != nil { + r.logger.Printf("[WARN] raft: Failed to get log entry %d: %v", + 
entry.Index, err) + return + } + if entry.Term != storeEntry.Term { + r.logger.Printf("[WARN] raft: Clearing log suffix from %d to %d", entry.Index, lastLogIdx) + if err := r.logs.DeleteRange(entry.Index, lastLogIdx); err != nil { + r.logger.Printf("[ERR] raft: Failed to clear log suffix: %v", err) + return + } + if entry.Index <= r.configurations.latestIndex { + r.configurations.latest = r.configurations.committed + r.configurations.latestIndex = r.configurations.committedIndex + } + newEntries = a.Entries[i:] + break + } + } + + if n := len(newEntries); n > 0 { + // Append the new entries + if err := r.logs.StoreLogs(newEntries); err != nil { + r.logger.Printf("[ERR] raft: Failed to append to logs: %v", err) + // TODO: leaving r.getLastLog() in the wrong + // state if there was a truncation above + return + } + + // Handle any new configuration changes + for _, newEntry := range newEntries { + r.processConfigurationLogEntry(newEntry) + } + + // Update the lastLog + last := newEntries[n-1] + r.setLastLog(last.Index, last.Term) + } + + metrics.MeasureSince([]string{"raft", "rpc", "appendEntries", "storeLogs"}, start) + } + + // Update the commit index + if a.LeaderCommitIndex > 0 && a.LeaderCommitIndex > r.getCommitIndex() { + start := time.Now() + idx := min(a.LeaderCommitIndex, r.getLastIndex()) + r.setCommitIndex(idx) + if r.configurations.latestIndex <= idx { + r.configurations.committed = r.configurations.latest + r.configurations.committedIndex = r.configurations.latestIndex + } + r.processLogs(idx, nil) + metrics.MeasureSince([]string{"raft", "rpc", "appendEntries", "processLogs"}, start) + } + + // Everything went well, set success + resp.Success = true + r.setLastContact() + return +} + +// processConfigurationLogEntry takes a log entry and updates the latest +// configuration if the entry results in a new configuration. This must only be +// called from the main thread, or from NewRaft() before any threads have begun. 
+func (r *Raft) processConfigurationLogEntry(entry *Log) { + if entry.Type == LogConfiguration { + r.configurations.committed = r.configurations.latest + r.configurations.committedIndex = r.configurations.latestIndex + r.configurations.latest = decodeConfiguration(entry.Data) + r.configurations.latestIndex = entry.Index + } else if entry.Type == LogAddPeerDeprecated || entry.Type == LogRemovePeerDeprecated { + r.configurations.committed = r.configurations.latest + r.configurations.committedIndex = r.configurations.latestIndex + r.configurations.latest = decodePeers(entry.Data, r.trans) + r.configurations.latestIndex = entry.Index + } +} + +// requestVote is invoked when we get an request vote RPC call. +func (r *Raft) requestVote(rpc RPC, req *RequestVoteRequest) { + defer metrics.MeasureSince([]string{"raft", "rpc", "requestVote"}, time.Now()) + r.observe(*req) + + // Setup a response + resp := &RequestVoteResponse{ + RPCHeader: r.getRPCHeader(), + Term: r.getCurrentTerm(), + Granted: false, + } + var rpcErr error + defer func() { + rpc.Respond(resp, rpcErr) + }() + + // Version 0 servers will panic unless the peers is present. It's only + // used on them to produce a warning message. 
+ if r.protocolVersion < 2 { + resp.Peers = encodePeers(r.configurations.latest, r.trans) + } + + // Check if we have an existing leader [who's not the candidate] + candidate := r.trans.DecodePeer(req.Candidate) + if leader := r.Leader(); leader != "" && leader != candidate { + r.logger.Printf("[WARN] raft: Rejecting vote request from %v since we have a leader: %v", + candidate, leader) + return + } + + // Ignore an older term + if req.Term < r.getCurrentTerm() { + return + } + + // Increase the term if we see a newer one + if req.Term > r.getCurrentTerm() { + // Ensure transition to follower + r.setState(Follower) + r.setCurrentTerm(req.Term) + resp.Term = req.Term + } + + // Check if we have voted yet + lastVoteTerm, err := r.stable.GetUint64(keyLastVoteTerm) + if err != nil && err.Error() != "not found" { + r.logger.Printf("[ERR] raft: Failed to get last vote term: %v", err) + return + } + lastVoteCandBytes, err := r.stable.Get(keyLastVoteCand) + if err != nil && err.Error() != "not found" { + r.logger.Printf("[ERR] raft: Failed to get last vote candidate: %v", err) + return + } + + // Check if we've voted in this election before + if lastVoteTerm == req.Term && lastVoteCandBytes != nil { + r.logger.Printf("[INFO] raft: Duplicate RequestVote for same term: %d", req.Term) + if bytes.Compare(lastVoteCandBytes, req.Candidate) == 0 { + r.logger.Printf("[WARN] raft: Duplicate RequestVote from candidate: %s", req.Candidate) + resp.Granted = true + } + return + } + + // Reject if their term is older + lastIdx, lastTerm := r.getLastEntry() + if lastTerm > req.LastLogTerm { + r.logger.Printf("[WARN] raft: Rejecting vote request from %v since our last term is greater (%d, %d)", + candidate, lastTerm, req.LastLogTerm) + return + } + + if lastTerm == req.LastLogTerm && lastIdx > req.LastLogIndex { + r.logger.Printf("[WARN] raft: Rejecting vote request from %v since our last index is greater (%d, %d)", + candidate, lastIdx, req.LastLogIndex) + return + } + + // Persist a 
vote for safety + if err := r.persistVote(req.Term, req.Candidate); err != nil { + r.logger.Printf("[ERR] raft: Failed to persist vote: %v", err) + return + } + + resp.Granted = true + r.setLastContact() + return +} + +// installSnapshot is invoked when we get a InstallSnapshot RPC call. +// We must be in the follower state for this, since it means we are +// too far behind a leader for log replay. This must only be called +// from the main thread. +func (r *Raft) installSnapshot(rpc RPC, req *InstallSnapshotRequest) { + defer metrics.MeasureSince([]string{"raft", "rpc", "installSnapshot"}, time.Now()) + // Setup a response + resp := &InstallSnapshotResponse{ + Term: r.getCurrentTerm(), + Success: false, + } + var rpcErr error + defer func() { + io.Copy(ioutil.Discard, rpc.Reader) // ensure we always consume all the snapshot data from the stream [see issue #212] + rpc.Respond(resp, rpcErr) + }() + + // Sanity check the version + if req.SnapshotVersion < SnapshotVersionMin || + req.SnapshotVersion > SnapshotVersionMax { + rpcErr = fmt.Errorf("unsupported snapshot version %d", req.SnapshotVersion) + return + } + + // Ignore an older term + if req.Term < r.getCurrentTerm() { + r.logger.Printf("[INFO] raft: Ignoring installSnapshot request with older term of %d vs currentTerm %d", req.Term, r.getCurrentTerm()) + return + } + + // Increase the term if we see a newer one + if req.Term > r.getCurrentTerm() { + // Ensure transition to follower + r.setState(Follower) + r.setCurrentTerm(req.Term) + resp.Term = req.Term + } + + // Save the current leader + r.setLeader(ServerAddress(r.trans.DecodePeer(req.Leader))) + + // Create a new snapshot + var reqConfiguration Configuration + var reqConfigurationIndex uint64 + if req.SnapshotVersion > 0 { + reqConfiguration = decodeConfiguration(req.Configuration) + reqConfigurationIndex = req.ConfigurationIndex + } else { + reqConfiguration = decodePeers(req.Peers, r.trans) + reqConfigurationIndex = req.LastLogIndex + } + version := 
getSnapshotVersion(r.protocolVersion) + sink, err := r.snapshots.Create(version, req.LastLogIndex, req.LastLogTerm, + reqConfiguration, reqConfigurationIndex, r.trans) + if err != nil { + r.logger.Printf("[ERR] raft: Failed to create snapshot to install: %v", err) + rpcErr = fmt.Errorf("failed to create snapshot: %v", err) + return + } + + // Spill the remote snapshot to disk + n, err := io.Copy(sink, rpc.Reader) + if err != nil { + sink.Cancel() + r.logger.Printf("[ERR] raft: Failed to copy snapshot: %v", err) + rpcErr = err + return + } + + // Check that we received it all + if n != req.Size { + sink.Cancel() + r.logger.Printf("[ERR] raft: Failed to receive whole snapshot: %d / %d", n, req.Size) + rpcErr = fmt.Errorf("short read") + return + } + + // Finalize the snapshot + if err := sink.Close(); err != nil { + r.logger.Printf("[ERR] raft: Failed to finalize snapshot: %v", err) + rpcErr = err + return + } + r.logger.Printf("[INFO] raft: Copied %d bytes to local snapshot", n) + + // Restore snapshot + future := &restoreFuture{ID: sink.ID()} + future.init() + select { + case r.fsmMutateCh <- future: + case <-r.shutdownCh: + future.respond(ErrRaftShutdown) + return + } + + // Wait for the restore to happen + if err := future.Error(); err != nil { + r.logger.Printf("[ERR] raft: Failed to restore snapshot: %v", err) + rpcErr = err + return + } + + // Update the lastApplied so we don't replay old logs + r.setLastApplied(req.LastLogIndex) + + // Update the last stable snapshot info + r.setLastSnapshot(req.LastLogIndex, req.LastLogTerm) + + // Restore the peer set + r.configurations.latest = reqConfiguration + r.configurations.latestIndex = reqConfigurationIndex + r.configurations.committed = reqConfiguration + r.configurations.committedIndex = reqConfigurationIndex + + // Compact logs, continue even if this fails + if err := r.compactLogs(req.LastLogIndex); err != nil { + r.logger.Printf("[ERR] raft: Failed to compact logs: %v", err) + } + + r.logger.Printf("[INFO] 
raft: Installed remote snapshot") + resp.Success = true + r.setLastContact() + return +} + +// setLastContact is used to set the last contact time to now +func (r *Raft) setLastContact() { + r.lastContactLock.Lock() + r.lastContact = time.Now() + r.lastContactLock.Unlock() +} + +type voteResult struct { + RequestVoteResponse + voterID ServerID +} + +// electSelf is used to send a RequestVote RPC to all peers, and vote for +// ourself. This has the side affecting of incrementing the current term. The +// response channel returned is used to wait for all the responses (including a +// vote for ourself). This must only be called from the main thread. +func (r *Raft) electSelf() <-chan *voteResult { + // Create a response channel + respCh := make(chan *voteResult, len(r.configurations.latest.Servers)) + + // Increment the term + r.setCurrentTerm(r.getCurrentTerm() + 1) + + // Construct the request + lastIdx, lastTerm := r.getLastEntry() + req := &RequestVoteRequest{ + RPCHeader: r.getRPCHeader(), + Term: r.getCurrentTerm(), + Candidate: r.trans.EncodePeer(r.localID, r.localAddr), + LastLogIndex: lastIdx, + LastLogTerm: lastTerm, + } + + // Construct a function to ask for a vote + askPeer := func(peer Server) { + r.goFunc(func() { + defer metrics.MeasureSince([]string{"raft", "candidate", "electSelf"}, time.Now()) + resp := &voteResult{voterID: peer.ID} + err := r.trans.RequestVote(peer.ID, peer.Address, req, &resp.RequestVoteResponse) + if err != nil { + r.logger.Printf("[ERR] raft: Failed to make RequestVote RPC to %v: %v", peer, err) + resp.Term = req.Term + resp.Granted = false + } + respCh <- resp + }) + } + + // For each peer, request a vote + for _, server := range r.configurations.latest.Servers { + if server.Suffrage == Voter { + if server.ID == r.localID { + // Persist a vote for ourselves + if err := r.persistVote(req.Term, req.Candidate); err != nil { + r.logger.Printf("[ERR] raft: Failed to persist vote : %v", err) + return nil + } + // Include our own vote 
+ respCh <- &voteResult{ + RequestVoteResponse: RequestVoteResponse{ + RPCHeader: r.getRPCHeader(), + Term: req.Term, + Granted: true, + }, + voterID: r.localID, + } + } else { + askPeer(server) + } + } + } + + return respCh +} + +// persistVote is used to persist our vote for safety. +func (r *Raft) persistVote(term uint64, candidate []byte) error { + if err := r.stable.SetUint64(keyLastVoteTerm, term); err != nil { + return err + } + if err := r.stable.Set(keyLastVoteCand, candidate); err != nil { + return err + } + return nil +} + +// setCurrentTerm is used to set the current term in a durable manner. +func (r *Raft) setCurrentTerm(t uint64) { + // Persist to disk first + if err := r.stable.SetUint64(keyCurrentTerm, t); err != nil { + panic(fmt.Errorf("failed to save current term: %v", err)) + } + r.raftState.setCurrentTerm(t) +} + +// setState is used to update the current state. Any state +// transition causes the known leader to be cleared. This means +// that leader should be set only after updating the state. +func (r *Raft) setState(state RaftState) { + r.setLeader("") + oldState := r.raftState.getState() + r.raftState.setState(state) + if oldState != state { + r.observe(state) + } +} diff --git a/vendor/github.com/hashicorp/raft/replication.go b/vendor/github.com/hashicorp/raft/replication.go new file mode 100644 index 00000000000..e631b5a09ba --- /dev/null +++ b/vendor/github.com/hashicorp/raft/replication.go @@ -0,0 +1,561 @@ +package raft + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/armon/go-metrics" +) + +const ( + maxFailureScale = 12 + failureWait = 10 * time.Millisecond +) + +var ( + // ErrLogNotFound indicates a given log entry is not available. + ErrLogNotFound = errors.New("log not found") + + // ErrPipelineReplicationNotSupported can be returned by the transport to + // signal that pipeline replication is not supported in general, and that + // no error message should be produced. 
+ ErrPipelineReplicationNotSupported = errors.New("pipeline replication not supported") +) + +// followerReplication is in charge of sending snapshots and log entries from +// this leader during this particular term to a remote follower. +type followerReplication struct { + // peer contains the network address and ID of the remote follower. + peer Server + + // commitment tracks the entries acknowledged by followers so that the + // leader's commit index can advance. It is updated on successsful + // AppendEntries responses. + commitment *commitment + + // stopCh is notified/closed when this leader steps down or the follower is + // removed from the cluster. In the follower removed case, it carries a log + // index; replication should be attempted with a best effort up through that + // index, before exiting. + stopCh chan uint64 + // triggerCh is notified every time new entries are appended to the log. + triggerCh chan struct{} + + // currentTerm is the term of this leader, to be included in AppendEntries + // requests. + currentTerm uint64 + // nextIndex is the index of the next log entry to send to the follower, + // which may fall past the end of the log. + nextIndex uint64 + + // lastContact is updated to the current time whenever any response is + // received from the follower (successful or not). This is used to check + // whether the leader should step down (Raft.checkLeaderLease()). + lastContact time.Time + // lastContactLock protects 'lastContact'. + lastContactLock sync.RWMutex + + // failures counts the number of failed RPCs since the last success, which is + // used to apply backoff. + failures uint64 + + // notifyCh is notified to send out a heartbeat, which is used to check that + // this server is still leader. + notifyCh chan struct{} + // notify is a list of futures to be resolved upon receipt of an + // acknowledgement, then cleared from this list. + notify []*verifyFuture + // notifyLock protects 'notify'. 
+ notifyLock sync.Mutex + + // stepDown is used to indicate to the leader that we + // should step down based on information from a follower. + stepDown chan struct{} + + // allowPipeline is used to determine when to pipeline the AppendEntries RPCs. + // It is private to this replication goroutine. + allowPipeline bool +} + +// notifyAll is used to notify all the waiting verify futures +// if the follower believes we are still the leader. +func (s *followerReplication) notifyAll(leader bool) { + // Clear the waiting notifies minimizing lock time + s.notifyLock.Lock() + n := s.notify + s.notify = nil + s.notifyLock.Unlock() + + // Submit our votes + for _, v := range n { + v.vote(leader) + } +} + +// LastContact returns the time of last contact. +func (s *followerReplication) LastContact() time.Time { + s.lastContactLock.RLock() + last := s.lastContact + s.lastContactLock.RUnlock() + return last +} + +// setLastContact sets the last contact to the current time. +func (s *followerReplication) setLastContact() { + s.lastContactLock.Lock() + s.lastContact = time.Now() + s.lastContactLock.Unlock() +} + +// replicate is a long running routine that replicates log entries to a single +// follower. +func (r *Raft) replicate(s *followerReplication) { + // Start an async heartbeating routing + stopHeartbeat := make(chan struct{}) + defer close(stopHeartbeat) + r.goFunc(func() { r.heartbeat(s, stopHeartbeat) }) + +RPC: + shouldStop := false + for !shouldStop { + select { + case maxIndex := <-s.stopCh: + // Make a best effort to replicate up to this index + if maxIndex > 0 { + r.replicateTo(s, maxIndex) + } + return + case <-s.triggerCh: + lastLogIdx, _ := r.getLastLog() + shouldStop = r.replicateTo(s, lastLogIdx) + case <-randomTimeout(r.conf.CommitTimeout): // TODO: what is this? 
+ lastLogIdx, _ := r.getLastLog() + shouldStop = r.replicateTo(s, lastLogIdx) + } + + // If things looks healthy, switch to pipeline mode + if !shouldStop && s.allowPipeline { + goto PIPELINE + } + } + return + +PIPELINE: + // Disable until re-enabled + s.allowPipeline = false + + // Replicates using a pipeline for high performance. This method + // is not able to gracefully recover from errors, and so we fall back + // to standard mode on failure. + if err := r.pipelineReplicate(s); err != nil { + if err != ErrPipelineReplicationNotSupported { + r.logger.Printf("[ERR] raft: Failed to start pipeline replication to %s: %s", s.peer, err) + } + } + goto RPC +} + +// replicateTo is a helper to replicate(), used to replicate the logs up to a +// given last index. +// If the follower log is behind, we take care to bring them up to date. +func (r *Raft) replicateTo(s *followerReplication, lastIndex uint64) (shouldStop bool) { + // Create the base request + var req AppendEntriesRequest + var resp AppendEntriesResponse + var start time.Time +START: + // Prevent an excessive retry rate on errors + if s.failures > 0 { + select { + case <-time.After(backoff(failureWait, s.failures, maxFailureScale)): + case <-r.shutdownCh: + } + } + + // Setup the request + if err := r.setupAppendEntries(s, &req, s.nextIndex, lastIndex); err == ErrLogNotFound { + goto SEND_SNAP + } else if err != nil { + return + } + + // Make the RPC call + start = time.Now() + if err := r.trans.AppendEntries(s.peer.ID, s.peer.Address, &req, &resp); err != nil { + r.logger.Printf("[ERR] raft: Failed to AppendEntries to %v: %v", s.peer, err) + s.failures++ + return + } + appendStats(string(s.peer.ID), start, float32(len(req.Entries))) + + // Check for a newer term, stop running + if resp.Term > req.Term { + r.handleStaleTerm(s) + return true + } + + // Update the last contact + s.setLastContact() + + // Update s based on success + if resp.Success { + // Update our replication state + updateLastAppended(s, 
&req) + + // Clear any failures, allow pipelining + s.failures = 0 + s.allowPipeline = true + } else { + s.nextIndex = max(min(s.nextIndex-1, resp.LastLog+1), 1) + if resp.NoRetryBackoff { + s.failures = 0 + } else { + s.failures++ + } + r.logger.Printf("[WARN] raft: AppendEntries to %v rejected, sending older logs (next: %d)", s.peer, s.nextIndex) + } + +CHECK_MORE: + // Poll the stop channel here in case we are looping and have been asked + // to stop, or have stepped down as leader. Even for the best effort case + // where we are asked to replicate to a given index and then shutdown, + // it's better to not loop in here to send lots of entries to a straggler + // that's leaving the cluster anyways. + select { + case <-s.stopCh: + return true + default: + } + + // Check if there are more logs to replicate + if s.nextIndex <= lastIndex { + goto START + } + return + + // SEND_SNAP is used when we fail to get a log, usually because the follower + // is too far behind, and we must ship a snapshot down instead +SEND_SNAP: + if stop, err := r.sendLatestSnapshot(s); stop { + return true + } else if err != nil { + r.logger.Printf("[ERR] raft: Failed to send snapshot to %v: %v", s.peer, err) + return + } + + // Check if there is more to replicate + goto CHECK_MORE +} + +// sendLatestSnapshot is used to send the latest snapshot we have +// down to our follower. 
+func (r *Raft) sendLatestSnapshot(s *followerReplication) (bool, error) { + // Get the snapshots + snapshots, err := r.snapshots.List() + if err != nil { + r.logger.Printf("[ERR] raft: Failed to list snapshots: %v", err) + return false, err + } + + // Check we have at least a single snapshot + if len(snapshots) == 0 { + return false, fmt.Errorf("no snapshots found") + } + + // Open the most recent snapshot + snapID := snapshots[0].ID + meta, snapshot, err := r.snapshots.Open(snapID) + if err != nil { + r.logger.Printf("[ERR] raft: Failed to open snapshot %v: %v", snapID, err) + return false, err + } + defer snapshot.Close() + + // Setup the request + req := InstallSnapshotRequest{ + RPCHeader: r.getRPCHeader(), + SnapshotVersion: meta.Version, + Term: s.currentTerm, + Leader: r.trans.EncodePeer(r.localID, r.localAddr), + LastLogIndex: meta.Index, + LastLogTerm: meta.Term, + Peers: meta.Peers, + Size: meta.Size, + Configuration: encodeConfiguration(meta.Configuration), + ConfigurationIndex: meta.ConfigurationIndex, + } + + // Make the call + start := time.Now() + var resp InstallSnapshotResponse + if err := r.trans.InstallSnapshot(s.peer.ID, s.peer.Address, &req, &resp, snapshot); err != nil { + r.logger.Printf("[ERR] raft: Failed to install snapshot %v: %v", snapID, err) + s.failures++ + return false, err + } + metrics.MeasureSince([]string{"raft", "replication", "installSnapshot", string(s.peer.ID)}, start) + + // Check for a newer term, stop running + if resp.Term > req.Term { + r.handleStaleTerm(s) + return true, nil + } + + // Update the last contact + s.setLastContact() + + // Check for success + if resp.Success { + // Update the indexes + s.nextIndex = meta.Index + 1 + s.commitment.match(s.peer.ID, meta.Index) + + // Clear any failures + s.failures = 0 + + // Notify we are still leader + s.notifyAll(true) + } else { + s.failures++ + r.logger.Printf("[WARN] raft: InstallSnapshot to %v rejected", s.peer) + } + return false, nil +} + +// heartbeat is used to 
periodically invoke AppendEntries on a peer +// to ensure they don't time out. This is done async of replicate(), +// since that routine could potentially be blocked on disk IO. +func (r *Raft) heartbeat(s *followerReplication, stopCh chan struct{}) { + var failures uint64 + req := AppendEntriesRequest{ + RPCHeader: r.getRPCHeader(), + Term: s.currentTerm, + Leader: r.trans.EncodePeer(r.localID, r.localAddr), + } + var resp AppendEntriesResponse + for { + // Wait for the next heartbeat interval or forced notify + select { + case <-s.notifyCh: + case <-randomTimeout(r.conf.HeartbeatTimeout / 10): + case <-stopCh: + return + } + + start := time.Now() + if err := r.trans.AppendEntries(s.peer.ID, s.peer.Address, &req, &resp); err != nil { + r.logger.Printf("[ERR] raft: Failed to heartbeat to %v: %v", s.peer.Address, err) + failures++ + select { + case <-time.After(backoff(failureWait, failures, maxFailureScale)): + case <-stopCh: + } + } else { + s.setLastContact() + failures = 0 + metrics.MeasureSince([]string{"raft", "replication", "heartbeat", string(s.peer.ID)}, start) + s.notifyAll(resp.Success) + } + } +} + +// pipelineReplicate is used when we have synchronized our state with the follower, +// and want to switch to a higher performance pipeline mode of replication. +// We only pipeline AppendEntries commands, and if we ever hit an error, we fall +// back to the standard replication which can handle more complex situations. 
+func (r *Raft) pipelineReplicate(s *followerReplication) error { + // Create a new pipeline + pipeline, err := r.trans.AppendEntriesPipeline(s.peer.ID, s.peer.Address) + if err != nil { + return err + } + defer pipeline.Close() + + // Log start and stop of pipeline + r.logger.Printf("[INFO] raft: pipelining replication to peer %v", s.peer) + defer r.logger.Printf("[INFO] raft: aborting pipeline replication to peer %v", s.peer) + + // Create a shutdown and finish channel + stopCh := make(chan struct{}) + finishCh := make(chan struct{}) + + // Start a dedicated decoder + r.goFunc(func() { r.pipelineDecode(s, pipeline, stopCh, finishCh) }) + + // Start pipeline sends at the last good nextIndex + nextIndex := s.nextIndex + + shouldStop := false +SEND: + for !shouldStop { + select { + case <-finishCh: + break SEND + case maxIndex := <-s.stopCh: + // Make a best effort to replicate up to this index + if maxIndex > 0 { + r.pipelineSend(s, pipeline, &nextIndex, maxIndex) + } + break SEND + case <-s.triggerCh: + lastLogIdx, _ := r.getLastLog() + shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx) + case <-randomTimeout(r.conf.CommitTimeout): + lastLogIdx, _ := r.getLastLog() + shouldStop = r.pipelineSend(s, pipeline, &nextIndex, lastLogIdx) + } + } + + // Stop our decoder, and wait for it to finish + close(stopCh) + select { + case <-finishCh: + case <-r.shutdownCh: + } + return nil +} + +// pipelineSend is used to send data over a pipeline. It is a helper to +// pipelineReplicate. 
+func (r *Raft) pipelineSend(s *followerReplication, p AppendPipeline, nextIdx *uint64, lastIndex uint64) (shouldStop bool) { + // Create a new append request + req := new(AppendEntriesRequest) + if err := r.setupAppendEntries(s, req, *nextIdx, lastIndex); err != nil { + return true + } + + // Pipeline the append entries + if _, err := p.AppendEntries(req, new(AppendEntriesResponse)); err != nil { + r.logger.Printf("[ERR] raft: Failed to pipeline AppendEntries to %v: %v", s.peer, err) + return true + } + + // Increase the next send log to avoid re-sending old logs + if n := len(req.Entries); n > 0 { + last := req.Entries[n-1] + *nextIdx = last.Index + 1 + } + return false +} + +// pipelineDecode is used to decode the responses of pipelined requests. +func (r *Raft) pipelineDecode(s *followerReplication, p AppendPipeline, stopCh, finishCh chan struct{}) { + defer close(finishCh) + respCh := p.Consumer() + for { + select { + case ready := <-respCh: + req, resp := ready.Request(), ready.Response() + appendStats(string(s.peer.ID), ready.Start(), float32(len(req.Entries))) + + // Check for a newer term, stop running + if resp.Term > req.Term { + r.handleStaleTerm(s) + return + } + + // Update the last contact + s.setLastContact() + + // Abort pipeline if not successful + if !resp.Success { + return + } + + // Update our replication state + updateLastAppended(s, req) + case <-stopCh: + return + } + } +} + +// setupAppendEntries is used to setup an append entries request. 
+func (r *Raft) setupAppendEntries(s *followerReplication, req *AppendEntriesRequest, nextIndex, lastIndex uint64) error { + req.RPCHeader = r.getRPCHeader() + req.Term = s.currentTerm + req.Leader = r.trans.EncodePeer(r.localID, r.localAddr) + req.LeaderCommitIndex = r.getCommitIndex() + if err := r.setPreviousLog(req, nextIndex); err != nil { + return err + } + if err := r.setNewLogs(req, nextIndex, lastIndex); err != nil { + return err + } + return nil +} + +// setPreviousLog is used to setup the PrevLogEntry and PrevLogTerm for an +// AppendEntriesRequest given the next index to replicate. +func (r *Raft) setPreviousLog(req *AppendEntriesRequest, nextIndex uint64) error { + // Guard for the first index, since there is no 0 log entry + // Guard against the previous index being a snapshot as well + lastSnapIdx, lastSnapTerm := r.getLastSnapshot() + if nextIndex == 1 { + req.PrevLogEntry = 0 + req.PrevLogTerm = 0 + + } else if (nextIndex - 1) == lastSnapIdx { + req.PrevLogEntry = lastSnapIdx + req.PrevLogTerm = lastSnapTerm + + } else { + var l Log + if err := r.logs.GetLog(nextIndex-1, &l); err != nil { + r.logger.Printf("[ERR] raft: Failed to get log at index %d: %v", + nextIndex-1, err) + return err + } + + // Set the previous index and term (0 if nextIndex is 1) + req.PrevLogEntry = l.Index + req.PrevLogTerm = l.Term + } + return nil +} + +// setNewLogs is used to setup the logs which should be appended for a request. 
+func (r *Raft) setNewLogs(req *AppendEntriesRequest, nextIndex, lastIndex uint64) error { + // Append up to MaxAppendEntries or up to the lastIndex + req.Entries = make([]*Log, 0, r.conf.MaxAppendEntries) + maxIndex := min(nextIndex+uint64(r.conf.MaxAppendEntries)-1, lastIndex) + for i := nextIndex; i <= maxIndex; i++ { + oldLog := new(Log) + if err := r.logs.GetLog(i, oldLog); err != nil { + r.logger.Printf("[ERR] raft: Failed to get log at index %d: %v", i, err) + return err + } + req.Entries = append(req.Entries, oldLog) + } + return nil +} + +// appendStats is used to emit stats about an AppendEntries invocation. +func appendStats(peer string, start time.Time, logs float32) { + metrics.MeasureSince([]string{"raft", "replication", "appendEntries", "rpc", peer}, start) + metrics.IncrCounter([]string{"raft", "replication", "appendEntries", "logs", peer}, logs) +} + +// handleStaleTerm is used when a follower indicates that we have a stale term. +func (r *Raft) handleStaleTerm(s *followerReplication) { + r.logger.Printf("[ERR] raft: peer %v has newer term, stopping replication", s.peer) + s.notifyAll(false) // No longer leader + asyncNotifyCh(s.stepDown) +} + +// updateLastAppended is used to update follower replication state after a +// successful AppendEntries RPC. +// TODO: This isn't used during InstallSnapshot, but the code there is similar. 
+func updateLastAppended(s *followerReplication, req *AppendEntriesRequest) { + // Mark any inflight logs as committed + if logs := req.Entries; len(logs) > 0 { + last := logs[len(logs)-1] + s.nextIndex = last.Index + 1 + s.commitment.match(s.peer.ID, last.Index) + } + + // Notify still leader + s.notifyAll(true) +} diff --git a/vendor/github.com/hashicorp/raft/snapshot.go b/vendor/github.com/hashicorp/raft/snapshot.go new file mode 100644 index 00000000000..5287ebc4183 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/snapshot.go @@ -0,0 +1,239 @@ +package raft + +import ( + "fmt" + "io" + "time" + + "github.com/armon/go-metrics" +) + +// SnapshotMeta is for metadata of a snapshot. +type SnapshotMeta struct { + // Version is the version number of the snapshot metadata. This does not cover + // the application's data in the snapshot, that should be versioned + // separately. + Version SnapshotVersion + + // ID is opaque to the store, and is used for opening. + ID string + + // Index and Term store when the snapshot was taken. + Index uint64 + Term uint64 + + // Peers is deprecated and used to support version 0 snapshots, but will + // be populated in version 1 snapshots as well to help with upgrades. + Peers []byte + + // Configuration and ConfigurationIndex are present in version 1 + // snapshots and later. + Configuration Configuration + ConfigurationIndex uint64 + + // Size is the size of the snapshot in bytes. + Size int64 +} + +// SnapshotStore interface is used to allow for flexible implementations +// of snapshot storage and retrieval. For example, a client could implement +// a shared state store such as S3, allowing new nodes to restore snapshots +// without streaming from the leader. +type SnapshotStore interface { + // Create is used to begin a snapshot at a given index and term, and with + // the given committed configuration. The version parameter controls + // which snapshot version to create. 
+ Create(version SnapshotVersion, index, term uint64, configuration Configuration, + configurationIndex uint64, trans Transport) (SnapshotSink, error) + + // List is used to list the available snapshots in the store. + // It should return then in descending order, with the highest index first. + List() ([]*SnapshotMeta, error) + + // Open takes a snapshot ID and provides a ReadCloser. Once close is + // called it is assumed the snapshot is no longer needed. + Open(id string) (*SnapshotMeta, io.ReadCloser, error) +} + +// SnapshotSink is returned by StartSnapshot. The FSM will Write state +// to the sink and call Close on completion. On error, Cancel will be invoked. +type SnapshotSink interface { + io.WriteCloser + ID() string + Cancel() error +} + +// runSnapshots is a long running goroutine used to manage taking +// new snapshots of the FSM. It runs in parallel to the FSM and +// main goroutines, so that snapshots do not block normal operation. +func (r *Raft) runSnapshots() { + for { + select { + case <-randomTimeout(r.conf.SnapshotInterval): + // Check if we should snapshot + if !r.shouldSnapshot() { + continue + } + + // Trigger a snapshot + if _, err := r.takeSnapshot(); err != nil { + r.logger.Printf("[ERR] raft: Failed to take snapshot: %v", err) + } + + case future := <-r.userSnapshotCh: + // User-triggered, run immediately + id, err := r.takeSnapshot() + if err != nil { + r.logger.Printf("[ERR] raft: Failed to take snapshot: %v", err) + } else { + future.opener = func() (*SnapshotMeta, io.ReadCloser, error) { + return r.snapshots.Open(id) + } + } + future.respond(err) + + case <-r.shutdownCh: + return + } + } +} + +// shouldSnapshot checks if we meet the conditions to take +// a new snapshot. 
+func (r *Raft) shouldSnapshot() bool { + // Check the last snapshot index + lastSnap, _ := r.getLastSnapshot() + + // Check the last log index + lastIdx, err := r.logs.LastIndex() + if err != nil { + r.logger.Printf("[ERR] raft: Failed to get last log index: %v", err) + return false + } + + // Compare the delta to the threshold + delta := lastIdx - lastSnap + return delta >= r.conf.SnapshotThreshold +} + +// takeSnapshot is used to take a new snapshot. This must only be called from +// the snapshot thread, never the main thread. This returns the ID of the new +// snapshot, along with an error. +func (r *Raft) takeSnapshot() (string, error) { + defer metrics.MeasureSince([]string{"raft", "snapshot", "takeSnapshot"}, time.Now()) + + // Create a request for the FSM to perform a snapshot. + snapReq := &reqSnapshotFuture{} + snapReq.init() + + // Wait for dispatch or shutdown. + select { + case r.fsmSnapshotCh <- snapReq: + case <-r.shutdownCh: + return "", ErrRaftShutdown + } + + // Wait until we get a response + if err := snapReq.Error(); err != nil { + if err != ErrNothingNewToSnapshot { + err = fmt.Errorf("failed to start snapshot: %v", err) + } + return "", err + } + defer snapReq.snapshot.Release() + + // Make a request for the configurations and extract the committed info. + // We have to use the future here to safely get this information since + // it is owned by the main thread. + configReq := &configurationsFuture{} + configReq.init() + select { + case r.configurationsCh <- configReq: + case <-r.shutdownCh: + return "", ErrRaftShutdown + } + if err := configReq.Error(); err != nil { + return "", err + } + committed := configReq.configurations.committed + committedIndex := configReq.configurations.committedIndex + + // We don't support snapshots while there's a config change outstanding + // since the snapshot doesn't have a means to represent this state. 
This + // is a little weird because we need the FSM to apply an index that's + // past the configuration change, even though the FSM itself doesn't see + // the configuration changes. It should be ok in practice with normal + // application traffic flowing through the FSM. If there's none of that + // then it's not crucial that we snapshot, since there's not much going + // on Raft-wise. + if snapReq.index < committedIndex { + return "", fmt.Errorf("cannot take snapshot now, wait until the configuration entry at %v has been applied (have applied %v)", + committedIndex, snapReq.index) + } + + // Create a new snapshot. + r.logger.Printf("[INFO] raft: Starting snapshot up to %d", snapReq.index) + start := time.Now() + version := getSnapshotVersion(r.protocolVersion) + sink, err := r.snapshots.Create(version, snapReq.index, snapReq.term, committed, committedIndex, r.trans) + if err != nil { + return "", fmt.Errorf("failed to create snapshot: %v", err) + } + metrics.MeasureSince([]string{"raft", "snapshot", "create"}, start) + + // Try to persist the snapshot. + start = time.Now() + if err := snapReq.snapshot.Persist(sink); err != nil { + sink.Cancel() + return "", fmt.Errorf("failed to persist snapshot: %v", err) + } + metrics.MeasureSince([]string{"raft", "snapshot", "persist"}, start) + + // Close and check for error. + if err := sink.Close(); err != nil { + return "", fmt.Errorf("failed to close snapshot: %v", err) + } + + // Update the last stable snapshot info. + r.setLastSnapshot(snapReq.index, snapReq.term) + + // Compact the logs. + if err := r.compactLogs(snapReq.index); err != nil { + return "", err + } + + r.logger.Printf("[INFO] raft: Snapshot to %d complete", snapReq.index) + return sink.ID(), nil +} + +// compactLogs takes the last inclusive index of a snapshot +// and trims the logs that are no longer needed. 
+func (r *Raft) compactLogs(snapIdx uint64) error { + defer metrics.MeasureSince([]string{"raft", "compactLogs"}, time.Now()) + // Determine log ranges to compact + minLog, err := r.logs.FirstIndex() + if err != nil { + return fmt.Errorf("failed to get first log index: %v", err) + } + + // Check if we have enough logs to truncate + lastLogIdx, _ := r.getLastLog() + if lastLogIdx <= r.conf.TrailingLogs { + return nil + } + + // Truncate up to the end of the snapshot, or `TrailingLogs` + // back from the head, which ever is further back. This ensures + // at least `TrailingLogs` entries, but does not allow logs + // after the snapshot to be removed. + maxLog := min(snapIdx, lastLogIdx-r.conf.TrailingLogs) + + // Log this + r.logger.Printf("[INFO] raft: Compacting logs from %d to %d", minLog, maxLog) + + // Compact the logs + if err := r.logs.DeleteRange(minLog, maxLog); err != nil { + return fmt.Errorf("log compaction failed: %v", err) + } + return nil +} diff --git a/vendor/github.com/hashicorp/raft/stable.go b/vendor/github.com/hashicorp/raft/stable.go new file mode 100644 index 00000000000..ff59a8c570a --- /dev/null +++ b/vendor/github.com/hashicorp/raft/stable.go @@ -0,0 +1,15 @@ +package raft + +// StableStore is used to provide stable storage +// of key configurations to ensure safety. +type StableStore interface { + Set(key []byte, val []byte) error + + // Get returns the value for key, or an empty byte slice if key was not found. + Get(key []byte) ([]byte, error) + + SetUint64(key []byte, val uint64) error + + // GetUint64 returns the uint64 value for key, or 0 if key was not found. 
+ GetUint64(key []byte) (uint64, error) +} diff --git a/vendor/github.com/hashicorp/raft/state.go b/vendor/github.com/hashicorp/raft/state.go new file mode 100644 index 00000000000..a58cd0d19e6 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/state.go @@ -0,0 +1,171 @@ +package raft + +import ( + "sync" + "sync/atomic" +) + +// RaftState captures the state of a Raft node: Follower, Candidate, Leader, +// or Shutdown. +type RaftState uint32 + +const ( + // Follower is the initial state of a Raft node. + Follower RaftState = iota + + // Candidate is one of the valid states of a Raft node. + Candidate + + // Leader is one of the valid states of a Raft node. + Leader + + // Shutdown is the terminal state of a Raft node. + Shutdown +) + +func (s RaftState) String() string { + switch s { + case Follower: + return "Follower" + case Candidate: + return "Candidate" + case Leader: + return "Leader" + case Shutdown: + return "Shutdown" + default: + return "Unknown" + } +} + +// raftState is used to maintain various state variables +// and provides an interface to set/get the variables in a +// thread safe manner. +type raftState struct { + // currentTerm commitIndex, lastApplied, must be kept at the top of + // the struct so they're 64 bit aligned which is a requirement for + // atomic ops on 32 bit platforms. 
+ + // The current term, cache of StableStore + currentTerm uint64 + + // Highest committed log entry + commitIndex uint64 + + // Last applied log to the FSM + lastApplied uint64 + + // protects 4 next fields + lastLock sync.Mutex + + // Cache the latest snapshot index/term + lastSnapshotIndex uint64 + lastSnapshotTerm uint64 + + // Cache the latest log from LogStore + lastLogIndex uint64 + lastLogTerm uint64 + + // Tracks running goroutines + routinesGroup sync.WaitGroup + + // The current state + state RaftState +} + +func (r *raftState) getState() RaftState { + stateAddr := (*uint32)(&r.state) + return RaftState(atomic.LoadUint32(stateAddr)) +} + +func (r *raftState) setState(s RaftState) { + stateAddr := (*uint32)(&r.state) + atomic.StoreUint32(stateAddr, uint32(s)) +} + +func (r *raftState) getCurrentTerm() uint64 { + return atomic.LoadUint64(&r.currentTerm) +} + +func (r *raftState) setCurrentTerm(term uint64) { + atomic.StoreUint64(&r.currentTerm, term) +} + +func (r *raftState) getLastLog() (index, term uint64) { + r.lastLock.Lock() + index = r.lastLogIndex + term = r.lastLogTerm + r.lastLock.Unlock() + return +} + +func (r *raftState) setLastLog(index, term uint64) { + r.lastLock.Lock() + r.lastLogIndex = index + r.lastLogTerm = term + r.lastLock.Unlock() +} + +func (r *raftState) getLastSnapshot() (index, term uint64) { + r.lastLock.Lock() + index = r.lastSnapshotIndex + term = r.lastSnapshotTerm + r.lastLock.Unlock() + return +} + +func (r *raftState) setLastSnapshot(index, term uint64) { + r.lastLock.Lock() + r.lastSnapshotIndex = index + r.lastSnapshotTerm = term + r.lastLock.Unlock() +} + +func (r *raftState) getCommitIndex() uint64 { + return atomic.LoadUint64(&r.commitIndex) +} + +func (r *raftState) setCommitIndex(index uint64) { + atomic.StoreUint64(&r.commitIndex, index) +} + +func (r *raftState) getLastApplied() uint64 { + return atomic.LoadUint64(&r.lastApplied) +} + +func (r *raftState) setLastApplied(index uint64) { + 
atomic.StoreUint64(&r.lastApplied, index) +} + +// Start a goroutine and properly handle the race between a routine +// starting and incrementing, and exiting and decrementing. +func (r *raftState) goFunc(f func()) { + r.routinesGroup.Add(1) + go func() { + defer r.routinesGroup.Done() + f() + }() +} + +func (r *raftState) waitShutdown() { + r.routinesGroup.Wait() +} + +// getLastIndex returns the last index in stable storage. +// Either from the last log or from the last snapshot. +func (r *raftState) getLastIndex() uint64 { + r.lastLock.Lock() + defer r.lastLock.Unlock() + return max(r.lastLogIndex, r.lastSnapshotIndex) +} + +// getLastEntry returns the last index and term in stable storage. +// Either from the last log or from the last snapshot. +func (r *raftState) getLastEntry() (uint64, uint64) { + r.lastLock.Lock() + defer r.lastLock.Unlock() + if r.lastLogIndex >= r.lastSnapshotIndex { + return r.lastLogIndex, r.lastLogTerm + } + return r.lastSnapshotIndex, r.lastSnapshotTerm +} diff --git a/vendor/github.com/hashicorp/raft/tcp_transport.go b/vendor/github.com/hashicorp/raft/tcp_transport.go new file mode 100644 index 00000000000..29b2740f624 --- /dev/null +++ b/vendor/github.com/hashicorp/raft/tcp_transport.go @@ -0,0 +1,116 @@ +package raft + +import ( + "errors" + "io" + "log" + "net" + "time" +) + +var ( + errNotAdvertisable = errors.New("local bind address is not advertisable") + errNotTCP = errors.New("local address is not a TCP address") +) + +// TCPStreamLayer implements StreamLayer interface for plain TCP. +type TCPStreamLayer struct { + advertise net.Addr + listener *net.TCPListener +} + +// NewTCPTransport returns a NetworkTransport that is built on top of +// a TCP streaming transport layer. 
+func NewTCPTransport(
+	bindAddr string,
+	advertise net.Addr,
+	maxPool int,
+	timeout time.Duration,
+	logOutput io.Writer,
+) (*NetworkTransport, error) {
+	return newTCPTransport(bindAddr, advertise, func(stream StreamLayer) *NetworkTransport {
+		return NewNetworkTransport(stream, maxPool, timeout, logOutput)
+	})
+}
+
+// NewTCPTransportWithLogger returns a NetworkTransport that is built on top of
+// a TCP streaming transport layer, with log output going to the supplied Logger
+func NewTCPTransportWithLogger(
+	bindAddr string,
+	advertise net.Addr,
+	maxPool int,
+	timeout time.Duration,
+	logger *log.Logger,
+) (*NetworkTransport, error) {
+	return newTCPTransport(bindAddr, advertise, func(stream StreamLayer) *NetworkTransport {
+		return NewNetworkTransportWithLogger(stream, maxPool, timeout, logger)
+	})
+}
+
+// NewTCPTransportWithConfig returns a NetworkTransport that is built on top of
+// a TCP streaming transport layer, using a default logger and the address provider
+func NewTCPTransportWithConfig(
+	bindAddr string,
+	advertise net.Addr,
+	config *NetworkTransportConfig,
+) (*NetworkTransport, error) {
+	return newTCPTransport(bindAddr, advertise, func(stream StreamLayer) *NetworkTransport {
+		config.Stream = stream
+		return NewNetworkTransportWithConfig(config)
+	})
+}
+
+func newTCPTransport(bindAddr string,
+	advertise net.Addr,
+	transportCreator func(stream StreamLayer) *NetworkTransport) (*NetworkTransport, error) {
+	// Try to bind
+	list, err := net.Listen("tcp", bindAddr)
+	if err != nil {
+		return nil, err
+	}
+
+	// Create stream
+	stream := &TCPStreamLayer{
+		advertise: advertise,
+		listener:  list.(*net.TCPListener),
+	}
+
+	// Verify that we have a usable advertise address
+	addr, ok := stream.Addr().(*net.TCPAddr)
+	if !ok {
+		list.Close()
+		return nil, errNotTCP
+	}
+	if addr.IP.IsUnspecified() {
+		list.Close()
+		return nil, errNotAdvertisable
+	}
+
+	// Create the network transport
+	trans := transportCreator(stream)
+	return trans, nil
+}
+
+// Dial implements the StreamLayer interface.
+func (t *TCPStreamLayer) Dial(address ServerAddress, timeout time.Duration) (net.Conn, error) {
+	return net.DialTimeout("tcp", string(address), timeout)
+}
+
+// Accept implements the net.Listener interface.
+func (t *TCPStreamLayer) Accept() (c net.Conn, err error) {
+	return t.listener.Accept()
+}
+
+// Close implements the net.Listener interface.
+func (t *TCPStreamLayer) Close() (err error) {
+	return t.listener.Close()
+}
+
+// Addr implements the net.Listener interface.
+func (t *TCPStreamLayer) Addr() net.Addr {
+	// Use an advertise addr if provided
+	if t.advertise != nil {
+		return t.advertise
+	}
+	return t.listener.Addr()
+}
diff --git a/vendor/github.com/hashicorp/raft/transport.go b/vendor/github.com/hashicorp/raft/transport.go
new file mode 100644
index 00000000000..85459b221d1
--- /dev/null
+++ b/vendor/github.com/hashicorp/raft/transport.go
@@ -0,0 +1,124 @@
+package raft
+
+import (
+	"io"
+	"time"
+)
+
+// RPCResponse captures both a response and a potential error.
+type RPCResponse struct {
+	Response interface{}
+	Error    error
+}
+
+// RPC has a command, and provides a response mechanism.
+type RPC struct {
+	Command  interface{}
+	Reader   io.Reader // Set only for InstallSnapshot
+	RespChan chan<- RPCResponse
+}
+
+// Respond is used to respond with a response, error or both
+func (r *RPC) Respond(resp interface{}, err error) {
+	r.RespChan <- RPCResponse{resp, err}
+}
+
+// Transport provides an interface for network transports
+// to allow Raft to communicate with other nodes.
+type Transport interface {
+	// Consumer returns a channel that can be used to
+	// consume and respond to RPC requests.
+	Consumer() <-chan RPC
+
+	// LocalAddr is used to return our local address to distinguish from our peers.
+	LocalAddr() ServerAddress
+
+	// AppendEntriesPipeline returns an interface that can be used to pipeline
+	// AppendEntries requests.
+ AppendEntriesPipeline(id ServerID, target ServerAddress) (AppendPipeline, error) + + // AppendEntries sends the appropriate RPC to the target node. + AppendEntries(id ServerID, target ServerAddress, args *AppendEntriesRequest, resp *AppendEntriesResponse) error + + // RequestVote sends the appropriate RPC to the target node. + RequestVote(id ServerID, target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error + + // InstallSnapshot is used to push a snapshot down to a follower. The data is read from + // the ReadCloser and streamed to the client. + InstallSnapshot(id ServerID, target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error + + // EncodePeer is used to serialize a peer's address. + EncodePeer(id ServerID, addr ServerAddress) []byte + + // DecodePeer is used to deserialize a peer's address. + DecodePeer([]byte) ServerAddress + + // SetHeartbeatHandler is used to setup a heartbeat handler + // as a fast-pass. This is to avoid head-of-line blocking from + // disk IO. If a Transport does not support this, it can simply + // ignore the call, and push the heartbeat onto the Consumer channel. + SetHeartbeatHandler(cb func(rpc RPC)) +} + +// WithClose is an interface that a transport may provide which +// allows a transport to be shut down cleanly when a Raft instance +// shuts down. +// +// It is defined separately from Transport as unfortunately it wasn't in the +// original interface specification. +type WithClose interface { + // Close permanently closes a transport, stopping + // any associated goroutines and freeing other resources. + Close() error +} + +// LoopbackTransport is an interface that provides a loopback transport suitable for testing +// e.g. InmemTransport. It's there so we don't have to rewrite tests. 
+type LoopbackTransport interface { + Transport // Embedded transport reference + WithPeers // Embedded peer management + WithClose // with a close routine +} + +// WithPeers is an interface that a transport may provide which allows for connection and +// disconnection. Unless the transport is a loopback transport, the transport specified to +// "Connect" is likely to be nil. +type WithPeers interface { + Connect(peer ServerAddress, t Transport) // Connect a peer + Disconnect(peer ServerAddress) // Disconnect a given peer + DisconnectAll() // Disconnect all peers, possibly to reconnect them later +} + +// AppendPipeline is used for pipelining AppendEntries requests. It is used +// to increase the replication throughput by masking latency and better +// utilizing bandwidth. +type AppendPipeline interface { + // AppendEntries is used to add another request to the pipeline. + // The send may block which is an effective form of back-pressure. + AppendEntries(args *AppendEntriesRequest, resp *AppendEntriesResponse) (AppendFuture, error) + + // Consumer returns a channel that can be used to consume + // response futures when they are ready. + Consumer() <-chan AppendFuture + + // Close closes the pipeline and cancels all inflight RPCs + Close() error +} + +// AppendFuture is used to return information about a pipelined AppendEntries request. +type AppendFuture interface { + Future + + // Start returns the time that the append request was started. + // It is always OK to call this method. + Start() time.Time + + // Request holds the parameters of the AppendEntries call. + // It is always OK to call this method. + Request() *AppendEntriesRequest + + // Response holds the results of the AppendEntries call. + // This method must only be called after the Error + // method returns, and will only be valid on success. 
+ Response() *AppendEntriesResponse +} diff --git a/vendor/github.com/hashicorp/raft/util.go b/vendor/github.com/hashicorp/raft/util.go new file mode 100644 index 00000000000..90428d7437e --- /dev/null +++ b/vendor/github.com/hashicorp/raft/util.go @@ -0,0 +1,133 @@ +package raft + +import ( + "bytes" + crand "crypto/rand" + "fmt" + "math" + "math/big" + "math/rand" + "time" + + "github.com/hashicorp/go-msgpack/codec" +) + +func init() { + // Ensure we use a high-entropy seed for the psuedo-random generator + rand.Seed(newSeed()) +} + +// returns an int64 from a crypto random source +// can be used to seed a source for a math/rand. +func newSeed() int64 { + r, err := crand.Int(crand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + panic(fmt.Errorf("failed to read random bytes: %v", err)) + } + return r.Int64() +} + +// randomTimeout returns a value that is between the minVal and 2x minVal. +func randomTimeout(minVal time.Duration) <-chan time.Time { + if minVal == 0 { + return nil + } + extra := (time.Duration(rand.Int63()) % minVal) + return time.After(minVal + extra) +} + +// min returns the minimum. +func min(a, b uint64) uint64 { + if a <= b { + return a + } + return b +} + +// max returns the maximum. +func max(a, b uint64) uint64 { + if a >= b { + return a + } + return b +} + +// generateUUID is used to generate a random UUID. +func generateUUID() string { + buf := make([]byte, 16) + if _, err := crand.Read(buf); err != nil { + panic(fmt.Errorf("failed to read random bytes: %v", err)) + } + + return fmt.Sprintf("%08x-%04x-%04x-%04x-%12x", + buf[0:4], + buf[4:6], + buf[6:8], + buf[8:10], + buf[10:16]) +} + +// asyncNotifyCh is used to do an async channel send +// to a single channel without blocking. +func asyncNotifyCh(ch chan struct{}) { + select { + case ch <- struct{}{}: + default: + } +} + +// drainNotifyCh empties out a single-item notification channel without +// blocking, and returns whether it received anything. 
+func drainNotifyCh(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} + +// asyncNotifyBool is used to do an async notification +// on a bool channel. +func asyncNotifyBool(ch chan bool, v bool) { + select { + case ch <- v: + default: + } +} + +// Decode reverses the encode operation on a byte slice input. +func decodeMsgPack(buf []byte, out interface{}) error { + r := bytes.NewBuffer(buf) + hd := codec.MsgpackHandle{} + dec := codec.NewDecoder(r, &hd) + return dec.Decode(out) +} + +// Encode writes an encoded object to a new bytes buffer. +func encodeMsgPack(in interface{}) (*bytes.Buffer, error) { + buf := bytes.NewBuffer(nil) + hd := codec.MsgpackHandle{} + enc := codec.NewEncoder(buf, &hd) + err := enc.Encode(in) + return buf, err +} + +// backoff is used to compute an exponential backoff +// duration. Base time is scaled by the current round, +// up to some maximum scale factor. +func backoff(base time.Duration, round, limit uint64) time.Duration { + power := min(round, limit) + for power > 2 { + base *= 2 + power-- + } + return base +} + +// Needed for sorting []uint64, used to determine commitment +type uint64Slice []uint64 + +func (p uint64Slice) Len() int { return len(p) } +func (p uint64Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p uint64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/vendor/github.com/nats-io/gnatsd/LICENSE b/vendor/github.com/nats-io/gnatsd/LICENSE new file mode 100644 index 00000000000..261eeb9e9f8 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/nats-io/gnatsd/conf/lex.go b/vendor/github.com/nats-io/gnatsd/conf/lex.go new file mode 100644 index 00000000000..f9603a992aa --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/conf/lex.go @@ -0,0 +1,1141 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Customized heavily from +// https://github.com/BurntSushi/toml/blob/master/lex.go, which is based on +// Rob Pike's talk: http://cuddle.googlecode.com/hg/talk/lex.html + +// The format supported is less restrictive than today's formats. +// Supports mixed Arrays [], nested Maps {}, multiple comment types (# and //) +// Also supports key value assigments using '=' or ':' or whiteSpace() +// e.g. foo = 2, foo : 2, foo 2 +// maps can be assigned with no key separator as well +// semicolons as value terminators in key/value assignments are optional +// +// see lex_test.go for more examples. + +package conf + +import ( + "encoding/hex" + "fmt" + "strings" + "unicode" + "unicode/utf8" +) + +type itemType int + +const ( + itemError itemType = iota + itemNIL // used in the parser to indicate no type + itemEOF + itemKey + itemText + itemString + itemBool + itemInteger + itemFloat + itemDatetime + itemArrayStart + itemArrayEnd + itemMapStart + itemMapEnd + itemCommentStart + itemVariable + itemInclude +) + +const ( + eof = 0 + mapStart = '{' + mapEnd = '}' + keySepEqual = '=' + keySepColon = ':' + arrayStart = '[' + arrayEnd = ']' + arrayValTerm = ',' + mapValTerm = ',' + commentHashStart = '#' + commentSlashStart = '/' + dqStringStart = '"' + dqStringEnd = '"' + sqStringStart = '\'' + sqStringEnd = '\'' + optValTerm = ';' + topOptStart = '{' + topOptValTerm = ',' + topOptTerm = '}' + blockStart = '(' + blockEnd = ')' +) + +type stateFn func(lx *lexer) stateFn + +type lexer struct { + input string + start int + pos int + width int + line int + state stateFn + items chan item + + // A stack of state functions used to maintain context. + // The idea is to reuse parts of the state machine in various places. + // For example, values can appear at the top level or within arbitrarily + // nested arrays. 
The last state on the stack is used after a value has + // been lexed. Similarly for comments. + stack []stateFn + + // Used for processing escapable substrings in double-quoted and raw strings + stringParts []string + stringStateFn stateFn +} + +type item struct { + typ itemType + val string + line int +} + +func (lx *lexer) nextItem() item { + for { + select { + case item := <-lx.items: + return item + default: + lx.state = lx.state(lx) + } + } +} + +func lex(input string) *lexer { + lx := &lexer{ + input: input, + state: lexTop, + line: 1, + items: make(chan item, 10), + stack: make([]stateFn, 0, 10), + stringParts: []string{}, + } + return lx +} + +func (lx *lexer) push(state stateFn) { + lx.stack = append(lx.stack, state) +} + +func (lx *lexer) pop() stateFn { + if len(lx.stack) == 0 { + return lx.errorf("BUG in lexer: no states to pop.") + } + li := len(lx.stack) - 1 + last := lx.stack[li] + lx.stack = lx.stack[0:li] + return last +} + +func (lx *lexer) emit(typ itemType) { + lx.items <- item{typ, strings.Join(lx.stringParts, "") + lx.input[lx.start:lx.pos], lx.line} + lx.start = lx.pos +} + +func (lx *lexer) emitString() { + var finalString string + if len(lx.stringParts) > 0 { + finalString = strings.Join(lx.stringParts, "") + lx.input[lx.start:lx.pos] + lx.stringParts = []string{} + } else { + finalString = lx.input[lx.start:lx.pos] + } + lx.items <- item{itemString, finalString, lx.line} + lx.start = lx.pos +} + +func (lx *lexer) addCurrentStringPart(offset int) { + lx.stringParts = append(lx.stringParts, lx.input[lx.start:lx.pos-offset]) + lx.start = lx.pos +} + +func (lx *lexer) addStringPart(s string) stateFn { + lx.stringParts = append(lx.stringParts, s) + lx.start = lx.pos + return lx.stringStateFn +} + +func (lx *lexer) hasEscapedParts() bool { + return len(lx.stringParts) > 0 +} + +func (lx *lexer) next() (r rune) { + if lx.pos >= len(lx.input) { + lx.width = 0 + return eof + } + + if lx.input[lx.pos] == '\n' { + lx.line++ + } + r, lx.width = 
utf8.DecodeRuneInString(lx.input[lx.pos:]) + lx.pos += lx.width + return r +} + +// ignore skips over the pending input before this point. +func (lx *lexer) ignore() { + lx.start = lx.pos +} + +// backup steps back one rune. Can be called only once per call of next. +func (lx *lexer) backup() { + lx.pos -= lx.width + if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' { + lx.line-- + } +} + +// peek returns but does not consume the next rune in the input. +func (lx *lexer) peek() rune { + r := lx.next() + lx.backup() + return r +} + +// errorf stops all lexing by emitting an error and returning `nil`. +// Note that any value that is a character is escaped if it's a special +// character (new lines, tabs, etc.). +func (lx *lexer) errorf(format string, values ...interface{}) stateFn { + for i, value := range values { + if v, ok := value.(rune); ok { + values[i] = escapeSpecial(v) + } + } + lx.items <- item{ + itemError, + fmt.Sprintf(format, values...), + lx.line, + } + return nil +} + +// lexTop consumes elements at the top level of data structure. +func lexTop(lx *lexer) stateFn { + r := lx.next() + if unicode.IsSpace(r) { + return lexSkip(lx, lexTop) + } + + switch r { + case topOptStart: + return lexSkip(lx, lexTop) + case commentHashStart: + lx.push(lexTop) + return lexCommentStart + case commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexTop) + return lexCommentStart + } + lx.backup() + fallthrough + case eof: + if lx.pos > lx.start { + return lx.errorf("Unexpected EOF.") + } + lx.emit(itemEOF) + return nil + } + + // At this point, the only valid item can be a key, so we back up + // and let the key lexer do the rest. + lx.backup() + lx.push(lexTopValueEnd) + return lexKeyStart +} + +// lexTopValueEnd is entered whenever a top-level value has been consumed. +// It must see only whitespace, and will turn back to lexTop upon a new line. +// If it sees EOF, it will quit the lexer successfully. 
+func lexTopValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case r == commentHashStart: + // a comment will read to a new line for us. + lx.push(lexTop) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexTop) + return lexCommentStart + } + lx.backup() + fallthrough + case isWhitespace(r): + return lexTopValueEnd + case isNL(r) || r == eof || r == optValTerm || r == topOptValTerm || r == topOptTerm: + lx.ignore() + return lexTop + } + return lx.errorf("Expected a top-level value to end with a new line, "+ + "comment or EOF, but got '%v' instead.", r) +} + +// lexKeyStart consumes a key name up until the first non-whitespace character. +// lexKeyStart will ignore whitespace. It will also eat enclosing quotes. +func lexKeyStart(lx *lexer) stateFn { + r := lx.peek() + switch { + case isKeySeparator(r): + return lx.errorf("Unexpected key separator '%v'", r) + case unicode.IsSpace(r): + lx.next() + return lexSkip(lx, lexKeyStart) + case r == dqStringStart: + lx.next() + return lexSkip(lx, lexDubQuotedKey) + case r == sqStringStart: + lx.next() + return lexSkip(lx, lexQuotedKey) + } + lx.ignore() + lx.next() + return lexKey +} + +// lexDubQuotedKey consumes the text of a key between quotes. +func lexDubQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == dqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexKeyEnd) + } + lx.next() + return lexDubQuotedKey +} + +// lexQuotedKey consumes the text of a key between quotes. +func lexQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == sqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexKeyEnd) + } + lx.next() + return lexQuotedKey +} + +// keyCheckKeyword will check for reserved keywords as the key value when the key is +// separated with a space. 
+func (lx *lexer) keyCheckKeyword(fallThrough, push stateFn) stateFn { + key := strings.ToLower(lx.input[lx.start:lx.pos]) + switch key { + case "include": + lx.ignore() + if push != nil { + lx.push(push) + } + return lexIncludeStart + } + lx.emit(itemKey) + return fallThrough +} + +// lexIncludeStart will consume the whitespace til the start of the value. +func lexIncludeStart(lx *lexer) stateFn { + r := lx.next() + if isWhitespace(r) { + return lexSkip(lx, lexIncludeStart) + } + lx.backup() + return lexInclude +} + +// lexIncludeQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexIncludeQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == sqStringEnd: + lx.backup() + lx.emit(itemInclude) + lx.next() + lx.ignore() + return lx.pop() + } + return lexIncludeQuotedString +} + +// lexIncludeDubQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexIncludeDubQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == dqStringEnd: + lx.backup() + lx.emit(itemInclude) + lx.next() + lx.ignore() + return lx.pop() + } + return lexIncludeDubQuotedString +} + +// lexIncludeString consumes the inner contents of a raw string. +func lexIncludeString(lx *lexer) stateFn { + r := lx.next() + switch { + case isNL(r) || r == eof || r == optValTerm || r == mapEnd || isWhitespace(r): + lx.backup() + lx.emit(itemInclude) + return lx.pop() + case r == sqStringEnd: + lx.backup() + lx.emit(itemInclude) + lx.next() + lx.ignore() + return lx.pop() + } + return lexIncludeString +} + +// lexInclude will consume the include value. 
+func lexInclude(lx *lexer) stateFn { + r := lx.next() + switch { + case r == sqStringStart: + lx.ignore() // ignore the " or ' + return lexIncludeQuotedString + case r == dqStringStart: + lx.ignore() // ignore the " or ' + return lexIncludeDubQuotedString + case r == arrayStart: + return lx.errorf("Expected include value but found start of an array") + case r == mapStart: + return lx.errorf("Expected include value but found start of a map") + case r == blockStart: + return lx.errorf("Expected include value but found start of a block") + case unicode.IsDigit(r), r == '-': + return lx.errorf("Expected include value but found start of a number") + case r == '\\': + return lx.errorf("Expected include value but found escape sequence") + case isNL(r): + return lx.errorf("Expected include value but found new line") + } + lx.backup() + return lexIncludeString +} + +// lexKey consumes the text of a key. Assumes that the first character (which +// is not whitespace) has already been consumed. +func lexKey(lx *lexer) stateFn { + r := lx.peek() + if unicode.IsSpace(r) { + // Spaces signal we could be looking at a keyword, e.g. include. + // Keywords will eat the keyword and set the appropriate return stateFn. + return lx.keyCheckKeyword(lexKeyEnd, nil) + } else if isKeySeparator(r) || r == eof { + lx.emit(itemKey) + return lexKeyEnd + } + lx.next() + return lexKey +} + +// lexKeyEnd consumes the end of a key (up to the key separator). +// Assumes that the first whitespace character after a key (or the '=' or ':' +// separator) has NOT been consumed. +func lexKeyEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexKeyEnd) + case isKeySeparator(r): + return lexSkip(lx, lexValue) + case r == eof: + lx.emit(itemEOF) + return nil + } + // We start the value here + lx.backup() + return lexValue +} + +// lexValue starts the consumption of a value anywhere a value is expected. +// lexValue will ignore whitespace. 
+// After a value is lexed, the last state on the next is popped and returned. +func lexValue(lx *lexer) stateFn { + // We allow whitespace to precede a value, but NOT new lines. + // In array syntax, the array states are responsible for ignoring new lines. + r := lx.next() + if isWhitespace(r) { + return lexSkip(lx, lexValue) + } + + switch { + case r == arrayStart: + lx.ignore() + lx.emit(itemArrayStart) + return lexArrayValue + case r == mapStart: + lx.ignore() + lx.emit(itemMapStart) + return lexMapKeyStart + case r == sqStringStart: + lx.ignore() // ignore the " or ' + return lexQuotedString + case r == dqStringStart: + lx.ignore() // ignore the " or ' + lx.stringStateFn = lexDubQuotedString + return lexDubQuotedString + case r == '-': + return lexNegNumberStart + case r == blockStart: + lx.ignore() + return lexBlock + case unicode.IsDigit(r): + lx.backup() // avoid an extra state and use the same as above + return lexNumberOrDateOrIPStart + case r == '.': // special error case, be kind to users + return lx.errorf("Floats must start with a digit") + case isNL(r): + return lx.errorf("Expected value but found new line") + } + lx.backup() + lx.stringStateFn = lexString + return lexString +} + +// lexArrayValue consumes one value in an array. It assumes that '[' or ',' +// have already been consumed. All whitespace and new lines are ignored. 
+func lexArrayValue(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexArrayValue) + case r == commentHashStart: + lx.push(lexArrayValue) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexArrayValue) + return lexCommentStart + } + lx.backup() + fallthrough + case r == arrayValTerm: + return lx.errorf("Unexpected array value terminator '%v'.", arrayValTerm) + case r == arrayEnd: + return lexArrayEnd + } + + lx.backup() + lx.push(lexArrayValueEnd) + return lexValue +} + +// lexArrayValueEnd consumes the cruft between values of an array. Namely, +// it ignores whitespace and expects either a ',' or a ']'. +func lexArrayValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexArrayValueEnd) + case r == commentHashStart: + lx.push(lexArrayValueEnd) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexArrayValueEnd) + return lexCommentStart + } + lx.backup() + fallthrough + case r == arrayValTerm || isNL(r): + return lexSkip(lx, lexArrayValue) // Move onto next + case r == arrayEnd: + return lexArrayEnd + } + return lx.errorf("Expected an array value terminator %q or an array "+ + "terminator %q, but got '%v' instead.", arrayValTerm, arrayEnd, r) +} + +// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has +// just been consumed. +func lexArrayEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemArrayEnd) + return lx.pop() +} + +// lexMapKeyStart consumes a key name up until the first non-whitespace +// character. +// lexMapKeyStart will ignore whitespace. 
+func lexMapKeyStart(lx *lexer) stateFn { + r := lx.peek() + switch { + case isKeySeparator(r): + return lx.errorf("Unexpected key separator '%v'.", r) + case unicode.IsSpace(r): + lx.next() + return lexSkip(lx, lexMapKeyStart) + case r == mapEnd: + lx.next() + return lexSkip(lx, lexMapEnd) + case r == commentHashStart: + lx.next() + lx.push(lexMapKeyStart) + return lexCommentStart + case r == commentSlashStart: + lx.next() + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexMapKeyStart) + return lexCommentStart + } + lx.backup() + case r == sqStringStart: + lx.next() + return lexSkip(lx, lexMapQuotedKey) + case r == dqStringStart: + lx.next() + return lexSkip(lx, lexMapDubQuotedKey) + } + lx.ignore() + lx.next() + return lexMapKey +} + +// lexMapQuotedKey consumes the text of a key between quotes. +func lexMapQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == sqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexMapKeyEnd) + } + lx.next() + return lexMapQuotedKey +} + +// lexMapQuotedKey consumes the text of a key between quotes. +func lexMapDubQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == dqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexMapKeyEnd) + } + lx.next() + return lexMapDubQuotedKey +} + +// lexMapKey consumes the text of a key. Assumes that the first character (which +// is not whitespace) has already been consumed. +func lexMapKey(lx *lexer) stateFn { + r := lx.peek() + if unicode.IsSpace(r) { + // Spaces signal we could be looking at a keyword, e.g. include. + // Keywords will eat the keyword and set the appropriate return stateFn. + return lx.keyCheckKeyword(lexMapKeyEnd, lexMapValueEnd) + } else if isKeySeparator(r) { + lx.emit(itemKey) + return lexMapKeyEnd + } + lx.next() + return lexMapKey +} + +// lexMapKeyEnd consumes the end of a key (up to the key separator). +// Assumes that the first whitespace character after a key (or the '=' +// separator) has NOT been consumed. 
+func lexMapKeyEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexMapKeyEnd) + case isKeySeparator(r): + return lexSkip(lx, lexMapValue) + } + // We start the value here + lx.backup() + return lexMapValue +} + +// lexMapValue consumes one value in a map. It assumes that '{' or ',' +// have already been consumed. All whitespace and new lines are ignored. +// Map values can be separated by ',' or simple NLs. +func lexMapValue(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexMapValue) + case r == mapValTerm: + return lx.errorf("Unexpected map value terminator %q.", mapValTerm) + case r == mapEnd: + return lexSkip(lx, lexMapEnd) + } + lx.backup() + lx.push(lexMapValueEnd) + return lexValue +} + +// lexMapValueEnd consumes the cruft between values of a map. Namely, +// it ignores whitespace and expects either a ',' or a '}'. +func lexMapValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexMapValueEnd) + case r == commentHashStart: + lx.push(lexMapValueEnd) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexMapValueEnd) + return lexCommentStart + } + lx.backup() + fallthrough + case r == optValTerm || r == mapValTerm || isNL(r): + return lexSkip(lx, lexMapKeyStart) // Move onto next + case r == mapEnd: + return lexSkip(lx, lexMapEnd) + } + return lx.errorf("Expected a map value terminator %q or a map "+ + "terminator %q, but got '%v' instead.", mapValTerm, mapEnd, r) +} + +// lexMapEnd finishes the lexing of a map. It assumes that a '}' has +// just been consumed. 
+func lexMapEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemMapEnd) + return lx.pop() +} + +// Checks if the unquoted string was actually a boolean +func (lx *lexer) isBool() bool { + str := strings.ToLower(lx.input[lx.start:lx.pos]) + return str == "true" || str == "false" || + str == "on" || str == "off" || + str == "yes" || str == "no" +} + +// Check if the unquoted string is a variable reference, starting with $. +func (lx *lexer) isVariable() bool { + if lx.input[lx.start] == '$' { + lx.start += 1 + return true + } + return false +} + +// lexQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == sqStringEnd: + lx.backup() + lx.emit(itemString) + lx.next() + lx.ignore() + return lx.pop() + } + return lexQuotedString +} + +// lexDubQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexDubQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '\\': + lx.addCurrentStringPart(1) + return lexStringEscape + case r == dqStringEnd: + lx.backup() + lx.emitString() + lx.next() + lx.ignore() + return lx.pop() + } + return lexDubQuotedString +} + +// lexString consumes the inner contents of a raw string. 
+func lexString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '\\': + lx.addCurrentStringPart(1) + return lexStringEscape + // Termination of non-quoted strings + case isNL(r) || r == eof || r == optValTerm || + r == arrayValTerm || r == arrayEnd || r == mapEnd || + isWhitespace(r): + + lx.backup() + if lx.hasEscapedParts() { + lx.emitString() + } else if lx.isBool() { + lx.emit(itemBool) + } else if lx.isVariable() { + lx.emit(itemVariable) + } else { + lx.emitString() + } + return lx.pop() + case r == sqStringEnd: + lx.backup() + lx.emitString() + lx.next() + lx.ignore() + return lx.pop() + } + return lexString +} + +// lexBlock consumes the inner contents as a string. It assumes that the +// beginning '(' has already been consumed and ignored. It will continue +// processing until it finds a ')' on a new line by itself. +func lexBlock(lx *lexer) stateFn { + r := lx.next() + switch { + case r == blockEnd: + lx.backup() + lx.backup() + + // Looking for a ')' character on a line by itself, if the previous + // character isn't a new line, then break so we keep processing the block. + if lx.next() != '\n' { + lx.next() + break + } + lx.next() + + // Make sure the next character is a new line or an eof. We want a ')' on a + // bare line by itself. + switch lx.next() { + case '\n', eof: + lx.backup() + lx.backup() + lx.emit(itemString) + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + return lexBlock +} + +// lexStringEscape consumes an escaped character. It assumes that the preceding +// '\\' has already been consumed. +func lexStringEscape(lx *lexer) stateFn { + r := lx.next() + switch r { + case 'x': + return lexStringBinary + case 't': + return lx.addStringPart("\t") + case 'n': + return lx.addStringPart("\n") + case 'r': + return lx.addStringPart("\r") + case '"': + return lx.addStringPart("\"") + case '\\': + return lx.addStringPart("\\") + } + return lx.errorf("Invalid escape character '%v'. 
Only the following "+ + "escape characters are allowed: \\xXX, \\t, \\n, \\r, \\\", \\\\.", r) +} + +// lexStringBinary consumes two hexadecimal digits following '\x'. It assumes +// that the '\x' has already been consumed. +func lexStringBinary(lx *lexer) stateFn { + r := lx.next() + if isNL(r) { + return lx.errorf("Expected two hexadecimal digits after '\\x', but hit end of line") + } + r = lx.next() + if isNL(r) { + return lx.errorf("Expected two hexadecimal digits after '\\x', but hit end of line") + } + offset := lx.pos - 2 + byteString, err := hex.DecodeString(lx.input[offset:lx.pos]) + if err != nil { + return lx.errorf("Expected two hexadecimal digits after '\\x', but got '%s'", lx.input[offset:lx.pos]) + } + lx.addStringPart(string(byteString)) + return lx.stringStateFn +} + +// lexNumberOrDateStart consumes either a (positive) integer, a float, a datetime, or IP. +// It assumes that NO negative sign has been consumed, that is triggered above. +func lexNumberOrDateOrIPStart(lx *lexer) stateFn { + r := lx.next() + if !unicode.IsDigit(r) { + if r == '.' { + return lx.errorf("Floats must start with a digit, not '.'.") + } + return lx.errorf("Expected a digit but got '%v'.", r) + } + return lexNumberOrDateOrIP +} + +// lexNumberOrDateOrIP consumes either a (positive) integer, float, datetime or IP. +func lexNumberOrDateOrIP(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '-': + if lx.pos-lx.start != 5 { + return lx.errorf("All ISO8601 dates must be in full Zulu form.") + } + return lexDateAfterYear + case unicode.IsDigit(r): + return lexNumberOrDateOrIP + case r == '.': + return lexFloatStart // Assume float at first, but could be IP + case isNumberSuffix(r): + return lexConvenientNumber + } + + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexConvenientNumber is when we have a suffix, e.g. 
1k or 1Mb +func lexConvenientNumber(lx *lexer) stateFn { + r := lx.next() + switch { + case r == 'b' || r == 'B': + return lexConvenientNumber + } + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format. +// It assumes that "YYYY-" has already been consumed. +func lexDateAfterYear(lx *lexer) stateFn { + formats := []rune{ + // digits are '0'. + // everything else is direct equality. + '0', '0', '-', '0', '0', + 'T', + '0', '0', ':', '0', '0', ':', '0', '0', + 'Z', + } + for _, f := range formats { + r := lx.next() + if f == '0' { + if !unicode.IsDigit(r) { + return lx.errorf("Expected digit in ISO8601 datetime, "+ + "but found '%v' instead.", r) + } + } else if f != r { + return lx.errorf("Expected '%v' in ISO8601 datetime, "+ + "but found '%v' instead.", f, r) + } + } + lx.emit(itemDatetime) + return lx.pop() +} + +// lexNegNumberStart consumes either an integer or a float. It assumes that a +// negative sign has already been read, but that *no* digits have been consumed. +// lexNegNumberStart will move to the appropriate integer or float states. +func lexNegNumberStart(lx *lexer) stateFn { + // we MUST see a digit. Even floats have to start with a digit. + r := lx.next() + if !unicode.IsDigit(r) { + if r == '.' { + return lx.errorf("Floats must start with a digit, not '.'.") + } + return lx.errorf("Expected a digit but got '%v'.", r) + } + return lexNegNumber +} + +// lexNumber consumes a negative integer or a float after seeing the first digit. +func lexNegNumber(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsDigit(r): + return lexNegNumber + case r == '.': + return lexFloatStart + case isNumberSuffix(r): + return lexConvenientNumber + } + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexFloatStart starts the consumption of digits of a float after a '.'. +// Namely, at least one digit is required. 
+func lexFloatStart(lx *lexer) stateFn { + r := lx.next() + if !unicode.IsDigit(r) { + return lx.errorf("Floats must have a digit after the '.', but got "+ + "'%v' instead.", r) + } + return lexFloat +} + +// lexFloat consumes the digits of a float after a '.'. +// Assumes that one digit has been consumed after a '.' already. +func lexFloat(lx *lexer) stateFn { + r := lx.next() + if unicode.IsDigit(r) { + return lexFloat + } + + // Not a digit, if its another '.', need to see if we falsely assumed a float. + if r == '.' { + return lexIPAddr + } + + lx.backup() + lx.emit(itemFloat) + return lx.pop() +} + +// lexIPAddr consumes IP addrs, like 127.0.0.1:4222 +func lexIPAddr(lx *lexer) stateFn { + r := lx.next() + if unicode.IsDigit(r) || r == '.' || r == ':' || r == '-' { + return lexIPAddr + } + lx.backup() + lx.emit(itemString) + return lx.pop() +} + +// lexCommentStart begins the lexing of a comment. It will emit +// itemCommentStart and consume no characters, passing control to lexComment. +func lexCommentStart(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemCommentStart) + return lexComment +} + +// lexComment lexes an entire comment. It assumes that '#' has been consumed. +// It will consume *up to* the first new line character, and pass control +// back to the last state on the stack. +func lexComment(lx *lexer) stateFn { + r := lx.peek() + if isNL(r) || r == eof { + lx.emit(itemText) + return lx.pop() + } + lx.next() + return lexComment +} + +// lexSkip ignores all slurped input and moves on to the next state. 
+func lexSkip(lx *lexer, nextState stateFn) stateFn { + return func(lx *lexer) stateFn { + lx.ignore() + return nextState + } +} + +// Tests to see if we have a number suffix +func isNumberSuffix(r rune) bool { + return r == 'k' || r == 'K' || r == 'm' || r == 'M' || r == 'g' || r == 'G' +} + +// Tests for both key separators +func isKeySeparator(r rune) bool { + return r == keySepEqual || r == keySepColon +} + +// isWhitespace returns true if `r` is a whitespace character according +// to the spec. +func isWhitespace(r rune) bool { + return r == '\t' || r == ' ' +} + +func isNL(r rune) bool { + return r == '\n' || r == '\r' +} + +func (itype itemType) String() string { + switch itype { + case itemError: + return "Error" + case itemNIL: + return "NIL" + case itemEOF: + return "EOF" + case itemText: + return "Text" + case itemString: + return "String" + case itemBool: + return "Bool" + case itemInteger: + return "Integer" + case itemFloat: + return "Float" + case itemDatetime: + return "DateTime" + case itemKey: + return "Key" + case itemArrayStart: + return "ArrayStart" + case itemArrayEnd: + return "ArrayEnd" + case itemMapStart: + return "MapStart" + case itemMapEnd: + return "MapEnd" + case itemCommentStart: + return "CommentStart" + case itemVariable: + return "Variable" + case itemInclude: + return "Include" + } + panic(fmt.Sprintf("BUG: Unknown type '%s'.", itype.String())) +} + +func (item item) String() string { + return fmt.Sprintf("(%s, '%s', %d)", item.typ.String(), item.val, item.line) +} + +func escapeSpecial(c rune) string { + switch c { + case '\n': + return "\\n" + } + return string(c) +} diff --git a/vendor/github.com/nats-io/gnatsd/conf/parse.go b/vendor/github.com/nats-io/gnatsd/conf/parse.go new file mode 100644 index 00000000000..09205ae0bec --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/conf/parse.go @@ -0,0 +1,295 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may 
not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conf supports a configuration file format used by gnatsd. It is +// a flexible format that combines the best of traditional +// configuration formats and newer styles such as JSON and YAML. +package conf + +// The format supported is less restrictive than today's formats. +// Supports mixed Arrays [], nested Maps {}, multiple comment types (# and //) +// Also supports key value assigments using '=' or ':' or whiteSpace() +// e.g. foo = 2, foo : 2, foo 2 +// maps can be assigned with no key separator as well +// semicolons as value terminators in key/value assignments are optional +// +// see parse_test.go for more examples. + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + "time" + "unicode" +) + +type parser struct { + mapping map[string]interface{} + lx *lexer + + // The current scoped context, can be array or map + ctx interface{} + + // stack of contexts, either map or array/slice stack + ctxs []interface{} + + // Keys stack + keys []string + + // The config file path, empty by default. + fp string +} + +// Parse will return a map of keys to interface{}, although concrete types +// underly them. The values supported are string, bool, int64, float64, DateTime. +// Arrays and nested Maps are also supported. +func Parse(data string) (map[string]interface{}, error) { + p, err := parse(data, "") + if err != nil { + return nil, err + } + return p.mapping, nil +} + +// ParseFile is a helper to open file, etc. and parse the contents. 
+func ParseFile(fp string) (map[string]interface{}, error) { + data, err := ioutil.ReadFile(fp) + if err != nil { + return nil, fmt.Errorf("error opening config file: %v", err) + } + + p, err := parse(string(data), filepath.Dir(fp)) + if err != nil { + return nil, err + } + return p.mapping, nil +} + +func parse(data, fp string) (p *parser, err error) { + p = &parser{ + mapping: make(map[string]interface{}), + lx: lex(data), + ctxs: make([]interface{}, 0, 4), + keys: make([]string, 0, 4), + fp: fp, + } + p.pushContext(p.mapping) + + for { + it := p.next() + if it.typ == itemEOF { + break + } + if err := p.processItem(it); err != nil { + return nil, err + } + } + + return p, nil +} + +func (p *parser) next() item { + return p.lx.nextItem() +} + +func (p *parser) pushContext(ctx interface{}) { + p.ctxs = append(p.ctxs, ctx) + p.ctx = ctx +} + +func (p *parser) popContext() interface{} { + if len(p.ctxs) == 0 { + panic("BUG in parser, context stack empty") + } + li := len(p.ctxs) - 1 + last := p.ctxs[li] + p.ctxs = p.ctxs[0:li] + p.ctx = p.ctxs[len(p.ctxs)-1] + return last +} + +func (p *parser) pushKey(key string) { + p.keys = append(p.keys, key) +} + +func (p *parser) popKey() string { + if len(p.keys) == 0 { + panic("BUG in parser, keys stack empty") + } + li := len(p.keys) - 1 + last := p.keys[li] + p.keys = p.keys[0:li] + return last +} + +func (p *parser) processItem(it item) error { + switch it.typ { + case itemError: + return fmt.Errorf("Parse error on line %d: '%s'", it.line, it.val) + case itemKey: + p.pushKey(it.val) + case itemMapStart: + newCtx := make(map[string]interface{}) + p.pushContext(newCtx) + case itemMapEnd: + p.setValue(p.popContext()) + case itemString: + p.setValue(it.val) // FIXME(dlc) sanitize string? 
+ case itemInteger: + lastDigit := 0 + for _, r := range it.val { + if !unicode.IsDigit(r) && r != '-' { + break + } + lastDigit++ + } + numStr := it.val[:lastDigit] + num, err := strconv.ParseInt(numStr, 10, 64) + if err != nil { + if e, ok := err.(*strconv.NumError); ok && + e.Err == strconv.ErrRange { + return fmt.Errorf("Integer '%s' is out of the range.", it.val) + } + return fmt.Errorf("Expected integer, but got '%s'.", it.val) + } + // Process a suffix + suffix := strings.ToLower(strings.TrimSpace(it.val[lastDigit:])) + switch suffix { + case "": + p.setValue(num) + case "k": + p.setValue(num * 1000) + case "kb": + p.setValue(num * 1024) + case "m": + p.setValue(num * 1000 * 1000) + case "mb": + p.setValue(num * 1024 * 1024) + case "g": + p.setValue(num * 1000 * 1000 * 1000) + case "gb": + p.setValue(num * 1024 * 1024 * 1024) + } + case itemFloat: + num, err := strconv.ParseFloat(it.val, 64) + if err != nil { + if e, ok := err.(*strconv.NumError); ok && + e.Err == strconv.ErrRange { + return fmt.Errorf("Float '%s' is out of the range.", it.val) + } + return fmt.Errorf("Expected float, but got '%s'.", it.val) + } + p.setValue(num) + case itemBool: + switch strings.ToLower(it.val) { + case "true", "yes", "on": + p.setValue(true) + case "false", "no", "off": + p.setValue(false) + default: + return fmt.Errorf("Expected boolean value, but got '%s'.", it.val) + } + case itemDatetime: + dt, err := time.Parse("2006-01-02T15:04:05Z", it.val) + if err != nil { + return fmt.Errorf( + "Expected Zulu formatted DateTime, but got '%s'.", it.val) + } + p.setValue(dt) + case itemArrayStart: + var array = make([]interface{}, 0) + p.pushContext(array) + case itemArrayEnd: + array := p.ctx + p.popContext() + p.setValue(array) + case itemVariable: + if value, ok := p.lookupVariable(it.val); ok { + p.setValue(value) + } else { + return fmt.Errorf("Variable reference for '%s' on line %d can not be found.", + it.val, it.line) + } + case itemInclude: + m, err := 
ParseFile(filepath.Join(p.fp, it.val)) + if err != nil { + return fmt.Errorf("Error parsing include file '%s', %v.", it.val, err) + } + for k, v := range m { + p.pushKey(k) + p.setValue(v) + } + } + + return nil +} + +// Used to map an environment value into a temporary map to pass to secondary Parse call. +const pkey = "pk" + +// We special case raw strings here that are bcrypt'd. This allows us not to force quoting the strings +const bcryptPrefix = "2a$" + +// lookupVariable will lookup a variable reference. It will use block scoping on keys +// it has seen before, with the top level scoping being the environment variables. We +// ignore array contexts and only process the map contexts.. +// +// Returns true for ok if it finds something, similar to map. +func (p *parser) lookupVariable(varReference string) (interface{}, bool) { + // Do special check to see if it is a raw bcrypt string. + if strings.HasPrefix(varReference, bcryptPrefix) { + return "$" + varReference, true + } + + // Loop through contexts currently on the stack. + for i := len(p.ctxs) - 1; i >= 0; i -= 1 { + ctx := p.ctxs[i] + // Process if it is a map context + if m, ok := ctx.(map[string]interface{}); ok { + if v, ok := m[varReference]; ok { + return v, ok + } + } + } + + // If we are here, we have exhausted our context maps and still not found anything. + // Parse from the environment. + if vStr, ok := os.LookupEnv(varReference); ok { + // Everything we get here will be a string value, so we need to process as a parser would. 
+ if vmap, err := Parse(fmt.Sprintf("%s=%s", pkey, vStr)); err == nil { + v, ok := vmap[pkey] + return v, ok + } + } + return nil, false +} + +func (p *parser) setValue(val interface{}) { + // Test to see if we are on an array or a map + + // Array processing + if ctx, ok := p.ctx.([]interface{}); ok { + p.ctx = append(ctx, val) + p.ctxs[len(p.ctxs)-1] = p.ctx + } + + // Map processing + if ctx, ok := p.ctx.(map[string]interface{}); ok { + key := p.popKey() + // FIXME(dlc), make sure to error if redefining same key? + ctx[key] = val + } +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/log.go b/vendor/github.com/nats-io/gnatsd/logger/log.go new file mode 100644 index 00000000000..132cb42a61c --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/log.go @@ -0,0 +1,152 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//Package logger provides logging facilities for the NATS server +package logger + +import ( + "fmt" + "log" + "os" +) + +// Logger is the server logger +type Logger struct { + logger *log.Logger + debug bool + trace bool + infoLabel string + errorLabel string + fatalLabel string + debugLabel string + traceLabel string + logFile *os.File // file pointer for the file logger. 
+} + +// NewStdLogger creates a logger with output directed to Stderr +func NewStdLogger(time, debug, trace, colors, pid bool) *Logger { + flags := 0 + if time { + flags = log.LstdFlags | log.Lmicroseconds + } + + pre := "" + if pid { + pre = pidPrefix() + } + + l := &Logger{ + logger: log.New(os.Stderr, pre, flags), + debug: debug, + trace: trace, + } + + if colors { + setColoredLabelFormats(l) + } else { + setPlainLabelFormats(l) + } + + return l +} + +// NewFileLogger creates a logger with output directed to a file +func NewFileLogger(filename string, time, debug, trace, pid bool) *Logger { + fileflags := os.O_WRONLY | os.O_APPEND | os.O_CREATE + f, err := os.OpenFile(filename, fileflags, 0660) + if err != nil { + log.Fatalf("error opening file: %v", err) + } + + flags := 0 + if time { + flags = log.LstdFlags | log.Lmicroseconds + } + + pre := "" + if pid { + pre = pidPrefix() + } + + l := &Logger{ + logger: log.New(f, pre, flags), + debug: debug, + trace: trace, + logFile: f, + } + + setPlainLabelFormats(l) + return l +} + +// Close implements the io.Closer interface to clean up +// resources in the server's logger implementation. +// Caller must ensure threadsafety. 
+func (l *Logger) Close() error { + if f := l.logFile; f != nil { + l.logFile = nil + return f.Close() + } + return nil +} + +// Generate the pid prefix string +func pidPrefix() string { + return fmt.Sprintf("[%d] ", os.Getpid()) +} + +func setPlainLabelFormats(l *Logger) { + l.infoLabel = "[INF] " + l.debugLabel = "[DBG] " + l.errorLabel = "[ERR] " + l.fatalLabel = "[FTL] " + l.traceLabel = "[TRC] " +} + +func setColoredLabelFormats(l *Logger) { + colorFormat := "[\x1b[%dm%s\x1b[0m] " + l.infoLabel = fmt.Sprintf(colorFormat, 32, "INF") + l.debugLabel = fmt.Sprintf(colorFormat, 36, "DBG") + l.errorLabel = fmt.Sprintf(colorFormat, 31, "ERR") + l.fatalLabel = fmt.Sprintf(colorFormat, 31, "FTL") + l.traceLabel = fmt.Sprintf(colorFormat, 33, "TRC") +} + +// Noticef logs a notice statement +func (l *Logger) Noticef(format string, v ...interface{}) { + l.logger.Printf(l.infoLabel+format, v...) +} + +// Errorf logs an error statement +func (l *Logger) Errorf(format string, v ...interface{}) { + l.logger.Printf(l.errorLabel+format, v...) +} + +// Fatalf logs a fatal error +func (l *Logger) Fatalf(format string, v ...interface{}) { + l.logger.Fatalf(l.fatalLabel+format, v...) +} + +// Debugf logs a debug statement +func (l *Logger) Debugf(format string, v ...interface{}) { + if l.debug { + l.logger.Printf(l.debugLabel+format, v...) + } +} + +// Tracef logs a trace statement +func (l *Logger) Tracef(format string, v ...interface{}) { + if l.trace { + l.logger.Printf(l.traceLabel+format, v...) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/syslog.go b/vendor/github.com/nats-io/gnatsd/logger/syslog.go new file mode 100644 index 00000000000..7d7713b3df6 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/syslog.go @@ -0,0 +1,127 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package logger + +import ( + "fmt" + "log" + "log/syslog" + "net/url" + "os" + "strings" +) + +// SysLogger provides a system logger facility +type SysLogger struct { + writer *syslog.Writer + debug bool + trace bool +} + +// SetSyslogName sets the name to use for the syslog. +// Currently used only on Windows. +func SetSyslogName(name string) {} + +// GetSysLoggerTag generates the tag name for use in syslog statements. If +// the executable is linked, the name of the link will be used as the tag, +// otherwise, the name of the executable is used. "gnatsd" is the default +// for the NATS server. 
+func GetSysLoggerTag() string { + procName := os.Args[0] + if strings.ContainsRune(procName, os.PathSeparator) { + parts := strings.FieldsFunc(procName, func(c rune) bool { + return c == os.PathSeparator + }) + procName = parts[len(parts)-1] + } + return procName +} + +// NewSysLogger creates a new system logger +func NewSysLogger(debug, trace bool) *SysLogger { + w, err := syslog.New(syslog.LOG_DAEMON|syslog.LOG_NOTICE, GetSysLoggerTag()) + if err != nil { + log.Fatalf("error connecting to syslog: %q", err.Error()) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +// NewRemoteSysLogger creates a new remote system logger +func NewRemoteSysLogger(fqn string, debug, trace bool) *SysLogger { + network, addr := getNetworkAndAddr(fqn) + w, err := syslog.Dial(network, addr, syslog.LOG_DEBUG, GetSysLoggerTag()) + if err != nil { + log.Fatalf("error connecting to syslog: %q", err.Error()) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +func getNetworkAndAddr(fqn string) (network, addr string) { + u, err := url.Parse(fqn) + if err != nil { + log.Fatal(err) + } + + network = u.Scheme + if network == "udp" || network == "tcp" { + addr = u.Host + } else if network == "unix" { + addr = u.Path + } else { + log.Fatalf("error invalid network type: %q", u.Scheme) + } + + return +} + +// Noticef logs a notice statement +func (l *SysLogger) Noticef(format string, v ...interface{}) { + l.writer.Notice(fmt.Sprintf(format, v...)) +} + +// Fatalf logs a fatal error +func (l *SysLogger) Fatalf(format string, v ...interface{}) { + l.writer.Crit(fmt.Sprintf(format, v...)) +} + +// Errorf logs an error statement +func (l *SysLogger) Errorf(format string, v ...interface{}) { + l.writer.Err(fmt.Sprintf(format, v...)) +} + +// Debugf logs a debug statement +func (l *SysLogger) Debugf(format string, v ...interface{}) { + if l.debug { + l.writer.Debug(fmt.Sprintf(format, v...)) + } +} + +// Tracef logs a trace statement +func (l 
*SysLogger) Tracef(format string, v ...interface{}) { + if l.trace { + l.writer.Notice(fmt.Sprintf(format, v...)) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/syslog_windows.go b/vendor/github.com/nats-io/gnatsd/logger/syslog_windows.go new file mode 100644 index 00000000000..54972062e30 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/syslog_windows.go @@ -0,0 +1,107 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package logger logs to the windows event log +package logger + +import ( + "fmt" + "os" + "strings" + + "golang.org/x/sys/windows/svc/eventlog" +) + +var natsEventSource = "NATS-Server" + +// SetSyslogName sets the name to use for the system log event source +func SetSyslogName(name string) { + natsEventSource = name +} + +// SysLogger logs to the windows event logger +type SysLogger struct { + writer *eventlog.Log + debug bool + trace bool +} + +// NewSysLogger creates a log using the windows event logger +func NewSysLogger(debug, trace bool) *SysLogger { + if err := eventlog.InstallAsEventCreate(natsEventSource, eventlog.Info|eventlog.Error|eventlog.Warning); err != nil { + if !strings.Contains(err.Error(), "registry key already exists") { + panic(fmt.Sprintf("could not access event log: %v", err)) + } + } + + w, err := eventlog.Open(natsEventSource) + if err != nil { + panic(fmt.Sprintf("could not open event log: %v", err)) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +// NewRemoteSysLogger creates a remote event logger +func NewRemoteSysLogger(fqn string, debug, trace bool) *SysLogger { + w, err := eventlog.OpenRemote(fqn, natsEventSource) + if err != nil { + panic(fmt.Sprintf("could not open event log: %v", err)) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +func formatMsg(tag, format string, v ...interface{}) string { + orig := fmt.Sprintf(format, v...) + return fmt.Sprintf("pid[%d][%s]: %s", os.Getpid(), tag, orig) +} + +// Noticef logs a notice statement +func (l *SysLogger) Noticef(format string, v ...interface{}) { + l.writer.Info(1, formatMsg("NOTICE", format, v...)) +} + +// Fatalf logs a fatal error +func (l *SysLogger) Fatalf(format string, v ...interface{}) { + msg := formatMsg("FATAL", format, v...) 
+ l.writer.Error(5, msg) + panic(msg) +} + +// Errorf logs an error statement +func (l *SysLogger) Errorf(format string, v ...interface{}) { + l.writer.Error(2, formatMsg("ERROR", format, v...)) +} + +// Debugf logs a debug statement +func (l *SysLogger) Debugf(format string, v ...interface{}) { + if l.debug { + l.writer.Info(3, formatMsg("DEBUG", format, v...)) + } +} + +// Tracef logs a trace statement +func (l *SysLogger) Tracef(format string, v ...interface{}) { + if l.trace { + l.writer.Info(4, formatMsg("TRACE", format, v...)) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/auth.go b/vendor/github.com/nats-io/gnatsd/server/auth.go new file mode 100644 index 00000000000..724e0cb77cc --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/auth.go @@ -0,0 +1,271 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "crypto/tls" + "fmt" + "strings" + + "golang.org/x/crypto/bcrypt" +) + +// Authentication is an interface for implementing authentication +type Authentication interface { + // Check if a client is authorized to connect + Check(c ClientAuthentication) bool +} + +// ClientAuthentication is an interface for client authentication +type ClientAuthentication interface { + // Get options associated with a client + GetOpts() *clientOpts + // If TLS is enabled, TLS ConnectionState, nil otherwise + GetTLSConnectionState() *tls.ConnectionState + // Optionally map a user after auth. + RegisterUser(*User) +} + +// User is for multiple accounts/users. +type User struct { + Username string `json:"user"` + Password string `json:"password"` + Permissions *Permissions `json:"permissions"` +} + +// clone performs a deep copy of the User struct, returning a new clone with +// all values copied. +func (u *User) clone() *User { + if u == nil { + return nil + } + clone := &User{} + *clone = *u + clone.Permissions = u.Permissions.clone() + return clone +} + +// SubjectPermission is an individual allow and deny struct for publish +// and subscribe authorizations. +type SubjectPermission struct { + Allow []string `json:"allow"` + Deny []string `json:"deny"` +} + +// Permissions are the allowed subjects on a per +// publish or subscribe basis. +type Permissions struct { + Publish *SubjectPermission `json:"publish"` + Subscribe *SubjectPermission `json:"subscribe"` +} + +// RoutePermissions are similar to user permissions +// but describe what a server can import/export from and to +// another server. +type RoutePermissions struct { + Import *SubjectPermission `json:"import"` + Export *SubjectPermission `json:"export"` +} + +// clone will clone an individual subject permission. 
+func (p *SubjectPermission) clone() *SubjectPermission { + if p == nil { + return nil + } + clone := &SubjectPermission{} + if p.Allow != nil { + clone.Allow = make([]string, len(p.Allow)) + copy(clone.Allow, p.Allow) + } + if p.Deny != nil { + clone.Deny = make([]string, len(p.Deny)) + copy(clone.Deny, p.Deny) + } + return clone +} + +// clone performs a deep copy of the Permissions struct, returning a new clone +// with all values copied. +func (p *Permissions) clone() *Permissions { + if p == nil { + return nil + } + clone := &Permissions{} + if p.Publish != nil { + clone.Publish = p.Publish.clone() + } + if p.Subscribe != nil { + clone.Subscribe = p.Subscribe.clone() + } + return clone +} + +// configureAuthorization will do any setup needed for authorization. +// Lock is assumed held. +func (s *Server) configureAuthorization() { + if s.opts == nil { + return + } + + // Snapshot server options. + opts := s.getOpts() + + // Check for multiple users first + // This just checks and sets up the user map if we have multiple users. + if opts.CustomClientAuthentication != nil { + s.info.AuthRequired = true + } else if opts.Users != nil { + s.users = make(map[string]*User) + for _, u := range opts.Users { + s.users[u.Username] = u + } + s.info.AuthRequired = true + } else if opts.Username != "" || opts.Authorization != "" { + s.info.AuthRequired = true + } else { + s.users = nil + s.info.AuthRequired = false + } +} + +// checkAuthorization will check authorization based on client type and +// return boolean indicating if client is authorized. +func (s *Server) checkAuthorization(c *client) bool { + switch c.typ { + case CLIENT: + return s.isClientAuthorized(c) + case ROUTER: + return s.isRouterAuthorized(c) + default: + return false + } +} + +// hasUsers leyt's us know if we have a users array. 
+func (s *Server) hasUsers() bool { + s.mu.Lock() + hu := s.users != nil + s.mu.Unlock() + return hu +} + +// isClientAuthorized will check the client against the proper authorization method and data. +// This could be token or username/password based. +func (s *Server) isClientAuthorized(c *client) bool { + // Snapshot server options. + opts := s.getOpts() + + // Check custom auth first, then multiple users, then token, then single user/pass. + if opts.CustomClientAuthentication != nil { + return opts.CustomClientAuthentication.Check(c) + } else if s.hasUsers() { + s.mu.Lock() + user, ok := s.users[c.opts.Username] + s.mu.Unlock() + + if !ok { + return false + } + ok = comparePasswords(user.Password, c.opts.Password) + // If we are authorized, register the user which will properly setup any permissions + // for pub/sub authorizations. + if ok { + c.RegisterUser(user) + } + return ok + + } else if opts.Authorization != "" { + return comparePasswords(opts.Authorization, c.opts.Authorization) + + } else if opts.Username != "" { + if opts.Username != c.opts.Username { + return false + } + return comparePasswords(opts.Password, c.opts.Password) + } + + return true +} + +// checkRouterAuth checks optional router authorization which can be nil or username/password. +func (s *Server) isRouterAuthorized(c *client) bool { + // Snapshot server options. + opts := s.getOpts() + + if s.opts.CustomRouterAuthentication != nil { + return s.opts.CustomRouterAuthentication.Check(c) + } + + if opts.Cluster.Username == "" { + return true + } + + if opts.Cluster.Username != c.opts.Username { + return false + } + if !comparePasswords(opts.Cluster.Password, c.opts.Password) { + return false + } + c.setRoutePermissions(opts.Cluster.Permissions) + return true +} + +// removeUnauthorizedSubs removes any subscriptions the client has that are no +// longer authorized, e.g. due to a config reload. 
+func (s *Server) removeUnauthorizedSubs(c *client) { + c.mu.Lock() + if c.perms == nil { + c.mu.Unlock() + return + } + + subs := make(map[string]*subscription, len(c.subs)) + for sid, sub := range c.subs { + subs[sid] = sub + } + c.mu.Unlock() + + for sid, sub := range subs { + if !c.canSubscribe(sub.subject) { + _ = s.sl.Remove(sub) + c.mu.Lock() + delete(c.subs, sid) + c.mu.Unlock() + c.sendErr(fmt.Sprintf("Permissions Violation for Subscription to %q (sid %s)", + sub.subject, sub.sid)) + s.Noticef("Removed sub %q for user %q - not authorized", + string(sub.subject), c.opts.Username) + } + } +} + +// Support for bcrypt stored passwords and tokens. +const bcryptPrefix = "$2a$" + +// isBcrypt checks whether the given password or token is bcrypted. +func isBcrypt(password string) bool { + return strings.HasPrefix(password, bcryptPrefix) +} + +func comparePasswords(serverPassword, clientPassword string) bool { + // Check to see if the server password is a bcrypt hash + if isBcrypt(serverPassword) { + if err := bcrypt.CompareHashAndPassword([]byte(serverPassword), []byte(clientPassword)); err != nil { + return false + } + } else if serverPassword != clientPassword { + return false + } + return true +} diff --git a/vendor/github.com/nats-io/gnatsd/server/ciphersuites.go b/vendor/github.com/nats-io/gnatsd/server/ciphersuites.go new file mode 100644 index 00000000000..cbc5a2fff83 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/ciphersuites.go @@ -0,0 +1,97 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" +) + +// Where we maintain all of the available ciphers +var cipherMap = map[string]uint16{ + "TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA, + "TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, + "TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, + "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, +} + +var cipherMapByID = map[uint16]string{ + tls.TLS_RSA_WITH_RC4_128_SHA: 
"TLS_RSA_WITH_RC4_128_SHA", + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA: "TLS_RSA_WITH_AES_128_CBC_SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "TLS_RSA_WITH_AES_128_CBC_SHA256", + tls.TLS_RSA_WITH_AES_256_CBC_SHA: "TLS_RSA_WITH_AES_256_CBC_SHA", + tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "TLS_RSA_WITH_AES_256_GCM_SHA384", + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "TLS_ECDHE_RSA_WITH_RC4_128_SHA", + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", +} + +func defaultCipherSuites() []uint16 { + return []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } +} + +// Where we 
maintain available curve preferences +var curvePreferenceMap = map[string]tls.CurveID{ + "CurveP256": tls.CurveP256, + "CurveP384": tls.CurveP384, + "CurveP521": tls.CurveP521, + "X25519": tls.X25519, +} + +// reorder to default to the highest level of security. See: +// https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go +func defaultCurvePreferences() []tls.CurveID { + return []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.X25519, // faster than P256, arguably more secure + tls.CurveP256, + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/client.go b/vendor/github.com/nats-io/gnatsd/server/client.go new file mode 100644 index 00000000000..23a507eabcd --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/client.go @@ -0,0 +1,1865 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "math/rand" + "net" + "sync" + "sync/atomic" + "time" +) + +// Type of client connection. +const ( + // CLIENT is an end user. + CLIENT = iota + // ROUTER is another router in the cluster. + ROUTER +) + +const ( + // Original Client protocol from 2009. + // http://nats.io/documentation/internals/nats-protocol/ + ClientProtoZero = iota + // This signals a client can receive more then the original INFO block. + // This can be used to update clients on other cluster members, etc. 
+ ClientProtoInfo +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +const ( + // Scratch buffer size for the processMsg() calls. + msgScratchSize = 512 + msgHeadProto = "MSG " + msgHeadProtoLen = len(msgHeadProto) +) + +// For controlling dynamic buffer sizes. +const ( + startBufSize = 512 // For INFO/CONNECT block + minBufSize = 64 // Smallest to shrink to for PING/PONG + maxBufSize = 65536 // 64k + shortsToShrink = 2 +) + +// Represent client booleans with a bitmask +type clientFlag byte + +// Some client state represented as flags +const ( + connectReceived clientFlag = 1 << iota // The CONNECT proto has been received + firstPongSent // The first PONG has been sent + handshakeComplete // For TLS clients, indicate that the handshake is complete + clearConnection // Marks that clearConnection has already been called. + flushOutbound // Marks client as having a flushOutbound call in progress. +) + +// set the flag (would be equivalent to set the boolean to true) +func (cf *clientFlag) set(c clientFlag) { + *cf |= c +} + +// clear the flag (would be equivalent to set the boolean to false) +func (cf *clientFlag) clear(c clientFlag) { + *cf &= ^c +} + +// isSet returns true if the flag is set, false otherwise +func (cf clientFlag) isSet(c clientFlag) bool { + return cf&c != 0 +} + +// setIfNotSet will set the flag `c` only if that flag was not already +// set and return true to indicate that the flag has been set. Returns +// false otherwise. +func (cf *clientFlag) setIfNotSet(c clientFlag) bool { + if *cf&c == 0 { + *cf |= c + return true + } + return false +} + +// Reason client was closed. This will be passed into +// calls to clearConnection, but will only be stored +// in ConnInfo for monitoring. 
+type ClosedState int + +const ( + ClientClosed = ClosedState(iota + 1) + AuthenticationTimeout + AuthenticationViolation + TLSHandshakeError + SlowConsumerPendingBytes + SlowConsumerWriteDeadline + WriteError + ReadError + ParseError + StaleConnection + ProtocolViolation + BadClientProtocolVersion + WrongPort + MaxConnectionsExceeded + MaxPayloadExceeded + MaxControlLineExceeded + DuplicateRoute + RouteRemoved + ServerShutdown +) + +type client struct { + // Here first because of use of atomics, and memory alignment. + stats + mpay int64 + msubs int + mu sync.Mutex + typ int + cid uint64 + opts clientOpts + start time.Time + nc net.Conn + ncs string + out outbound + srv *Server + subs map[string]*subscription + perms *permissions + in readCache + pcd map[*client]struct{} + atmr *time.Timer + ping pinfo + msgb [msgScratchSize]byte + last time.Time + parseState + + rtt time.Duration + rttStart time.Time + + route *route + + debug bool + trace bool + echo bool + + flags clientFlag // Compact booleans into a single field. Size will be increased when needed. +} + +// Struct for PING initiation from the server. +type pinfo struct { + tmr *time.Timer + out int +} + +// outbound holds pending data for a socket. +type outbound struct { + p []byte // Primary write buffer + s []byte // Secondary for use post flush + nb net.Buffers // net.Buffers for writev IO + sz int // limit size per []byte, uses variable BufSize constants, start, min, max. + sws int // Number of short writes, used for dyanmic resizing. + pb int64 // Total pending/queued bytes. + pm int64 // Total pending/queued messages. + sg *sync.Cond // Flusher conditional for signaling. + fsp int // Flush signals that are pending from readLoop's pcd. + mp int64 // snapshot of max pending. + wdl time.Duration // Snapshot fo write deadline. + lft time.Duration // Last flush time. 
+} + +type perm struct { + allow *Sublist + deny *Sublist +} +type permissions struct { + sub perm + pub perm + pcache map[string]bool +} + +const ( + maxResultCacheSize = 512 + maxPermCacheSize = 32 + pruneSize = 16 +) + +// Used in readloop to cache hot subject lookups and group statistics. +type readCache struct { + genid uint64 + results map[string]*SublistResult + prand *rand.Rand + msgs int + bytes int + subs int + rsz int // Read buffer size + srs int // Short reads, used for dynamic buffer resizing. +} + +func (c *client) String() (id string) { + return c.ncs +} + +func (c *client) GetOpts() *clientOpts { + return &c.opts +} + +// GetTLSConnectionState returns the TLS ConnectionState if TLS is enabled, nil +// otherwise. Implements the ClientAuth interface. +func (c *client) GetTLSConnectionState() *tls.ConnectionState { + tc, ok := c.nc.(*tls.Conn) + if !ok { + return nil + } + state := tc.ConnectionState() + return &state +} + +type subscription struct { + client *client + subject []byte + queue []byte + sid []byte + nm int64 + max int64 +} + +type clientOpts struct { + Echo bool `json:"echo"` + Verbose bool `json:"verbose"` + Pedantic bool `json:"pedantic"` + TLSRequired bool `json:"tls_required"` + Authorization string `json:"auth_token"` + Username string `json:"user"` + Password string `json:"pass"` + Name string `json:"name"` + Lang string `json:"lang"` + Version string `json:"version"` + Protocol int `json:"protocol"` +} + +var defaultOpts = clientOpts{Verbose: true, Pedantic: true, Echo: true} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +// Lock should be held +func (c *client) initClient() { + s := c.srv + c.cid = atomic.AddUint64(&s.gcid, 1) + + // Outbound data structure setup + c.out.sz = startBufSize + c.out.sg = sync.NewCond(&c.mu) + opts := s.getOpts() + // Snapshots to avoid mutex access in fast paths. 
+ c.out.wdl = opts.WriteDeadline + c.out.mp = opts.MaxPending + + c.subs = make(map[string]*subscription) + c.echo = true + + c.debug = (atomic.LoadInt32(&c.srv.logging.debug) != 0) + c.trace = (atomic.LoadInt32(&c.srv.logging.trace) != 0) + + // This is a scratch buffer used for processMsg() + // The msg header starts with "MSG ", + // in bytes that is [77 83 71 32]. + c.msgb = [msgScratchSize]byte{77, 83, 71, 32} + + // This is to track pending clients that have data to be flushed + // after we process inbound msgs from our own connection. + c.pcd = make(map[*client]struct{}) + + // snapshot the string version of the connection + conn := "-" + if ip, ok := c.nc.(*net.TCPConn); ok { + addr := ip.RemoteAddr().(*net.TCPAddr) + conn = fmt.Sprintf("%s:%d", addr.IP, addr.Port) + } + + switch c.typ { + case CLIENT: + c.ncs = fmt.Sprintf("%s - cid:%d", conn, c.cid) + case ROUTER: + c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid) + } +} + +// RegisterUser allows auth to call back into a new client +// with the authenticated user. This is used to map any permissions +// into the client. +func (c *client) RegisterUser(user *User) { + if user.Permissions == nil { + // Reset perms to nil in case client previously had them. + c.mu.Lock() + c.perms = nil + c.mu.Unlock() + return + } + + // Process Permissions and map into client connection structures. + c.mu.Lock() + defer c.mu.Unlock() + + c.setPermissions(user.Permissions) +} + +// Initializes client.perms structure. +// Lock is held on entry. 
+func (c *client) setPermissions(perms *Permissions) { + if perms == nil { + return + } + c.perms = &permissions{} + c.perms.pcache = make(map[string]bool) + + // Loop over publish permissions + if perms.Publish != nil { + if len(perms.Publish.Allow) > 0 { + c.perms.pub.allow = NewSublist() + } + for _, pubSubject := range perms.Publish.Allow { + sub := &subscription{subject: []byte(pubSubject)} + c.perms.pub.allow.Insert(sub) + } + if len(perms.Publish.Deny) > 0 { + c.perms.pub.deny = NewSublist() + } + for _, pubSubject := range perms.Publish.Deny { + sub := &subscription{subject: []byte(pubSubject)} + c.perms.pub.deny.Insert(sub) + } + } + + // Loop over subscribe permissions + if perms.Subscribe != nil { + if len(perms.Subscribe.Allow) > 0 { + c.perms.sub.allow = NewSublist() + } + for _, subSubject := range perms.Subscribe.Allow { + sub := &subscription{subject: []byte(subSubject)} + c.perms.sub.allow.Insert(sub) + } + if len(perms.Subscribe.Deny) > 0 { + c.perms.sub.deny = NewSublist() + } + for _, subSubject := range perms.Subscribe.Deny { + sub := &subscription{subject: []byte(subSubject)} + c.perms.sub.deny.Insert(sub) + } + } +} + +// writeLoop is the main socket write functionality. +// Runs in its own Go routine. +func (c *client) writeLoop() { + defer c.srv.grWG.Done() + + // Used to check that we did flush from last wake up. + waitOk := true + + // Main loop. Will wait to be signaled and then will use + // buffered outbound structure for efficient writev to the underlying socket. + for { + c.mu.Lock() + if waitOk && (c.out.pb == 0 || c.out.fsp > 0) && len(c.out.nb) == 0 && !c.flags.isSet(clearConnection) { + // Wait on pending data. + c.out.sg.Wait() + } + // Flush data + waitOk = c.flushOutbound() + isClosed := c.flags.isSet(clearConnection) + c.mu.Unlock() + + if isClosed { + return + } + } +} + +// readLoop is the main socket read functionality. +// Runs in its own Go routine. 
+func (c *client) readLoop() { + // Grab the connection off the client, it will be cleared on a close. + // We check for that after the loop, but want to avoid a nil dereference + c.mu.Lock() + nc := c.nc + s := c.srv + c.in.rsz = startBufSize + defer s.grWG.Done() + c.mu.Unlock() + + if nc == nil { + return + } + + // Start read buffer. + + b := make([]byte, c.in.rsz) + + for { + n, err := nc.Read(b) + if err != nil { + if err == io.EOF { + c.closeConnection(ClientClosed) + } else { + c.closeConnection(ReadError) + } + return + } + + // Grab for updates for last activity. + last := time.Now() + + // Clear inbound stats cache + c.in.msgs = 0 + c.in.bytes = 0 + c.in.subs = 0 + + // Main call into parser for inbound data. This will generate callouts + // to process messages, etc. + if err := c.parse(b[:n]); err != nil { + // handled inline + if err != ErrMaxPayload && err != ErrAuthorization { + c.Errorf("%s", err.Error()) + c.closeConnection(ProtocolViolation) + } + return + } + + // Updates stats for client and server that were collected + // from parsing through the buffer. + if c.in.msgs > 0 { + atomic.AddInt64(&c.inMsgs, int64(c.in.msgs)) + atomic.AddInt64(&c.inBytes, int64(c.in.bytes)) + atomic.AddInt64(&s.inMsgs, int64(c.in.msgs)) + atomic.AddInt64(&s.inBytes, int64(c.in.bytes)) + } + + // Budget to spend in place flushing outbound data. + // Client will be checked on several fronts to see + // if applicable. Routes will never wait in place. + budget := 500 * time.Microsecond + if c.typ == ROUTER { + budget = 0 + } + + // Check pending clients for flush. + for cp := range c.pcd { + // Queue up a flush for those in the set + cp.mu.Lock() + // Update last activity for message delivery + cp.last = last + cp.out.fsp-- + if budget > 0 && cp.flushOutbound() { + budget -= cp.out.lft + } else { + cp.flushSignal() + } + cp.mu.Unlock() + delete(c.pcd, cp) + } + + // Update activity, check read buffer size. 
+ c.mu.Lock() + nc := c.nc + + // Activity based on interest changes or data/msgs. + if c.in.msgs > 0 || c.in.subs > 0 { + c.last = last + } + + if n >= cap(b) { + c.in.srs = 0 + } else if n < cap(b)/2 { // divide by 2 b/c we want less than what we would shrink to. + c.in.srs++ + } + + // Update read buffer size as/if needed. + if n >= cap(b) && cap(b) < maxBufSize { + // Grow + c.in.rsz = cap(b) * 2 + b = make([]byte, c.in.rsz) + } else if n < cap(b) && cap(b) > minBufSize && c.in.srs > shortsToShrink { + // Shrink, for now don't accelerate, ping/pong will eventually sort it out. + c.in.rsz = cap(b) / 2 + b = make([]byte, c.in.rsz) + } + c.mu.Unlock() + + // Check to see if we got closed, e.g. slow consumer + if nc == nil { + return + } + } +} + +// collapsePtoNB will place primary onto nb buffer as needed in prep for WriteTo. +// This will return a copy on purpose. +func (c *client) collapsePtoNB() net.Buffers { + if c.out.p != nil { + p := c.out.p + c.out.p = nil + return append(c.out.nb, p) + } + return c.out.nb +} + +// This will handle the fixup needed on a partial write. +// Assume pending has been already calculated correctly. +func (c *client) handlePartialWrite(pnb net.Buffers) { + nb := c.collapsePtoNB() + // The partial needs to be first, so append nb to pnb + c.out.nb = append(pnb, nb...) +} + +// flushOutbound will flush outbound buffer to a client. +// Will return if data was attempted to be written. +// Lock must be held +func (c *client) flushOutbound() bool { + if c.flags.isSet(flushOutbound) { + return false + } + c.flags.set(flushOutbound) + defer c.flags.clear(flushOutbound) + + // Check for nothing to do. + if c.nc == nil || c.srv == nil || c.out.pb == 0 { + return true // true because no need to queue a signal. + } + + // Snapshot opts + srv := c.srv + + // Place primary on nb, assign primary to secondary, nil out nb and secondary. 
+ nb := c.collapsePtoNB() + c.out.p, c.out.nb, c.out.s = c.out.s, nil, nil + + // For selecting primary replacement. + cnb := nb + + // In case it goes away after releasing the lock. + nc := c.nc + attempted := c.out.pb + apm := c.out.pm + + // Do NOT hold lock during actual IO + c.mu.Unlock() + + // flush here + now := time.Now() + // FIXME(dlc) - writev will do multiple IOs past 1024 on + // most platforms, need to account for that with deadline? + nc.SetWriteDeadline(now.Add(c.out.wdl)) + // Actual write to the socket. + n, err := nb.WriteTo(nc) + nc.SetWriteDeadline(time.Time{}) + lft := time.Since(now) + + // Re-acquire client lock + c.mu.Lock() + + // Update flush time statistics + c.out.lft = lft + + // Subtract from pending bytes and messages. + c.out.pb -= n + c.out.pm -= apm // FIXME(dlc) - this will not be accurate. + + // Check for partial writes + if n != attempted && n > 0 { + c.handlePartialWrite(nb) + } else if n >= int64(c.out.sz) { + c.out.sws = 0 + } + + if err != nil { + if n == 0 { + c.out.pb -= attempted + } + if ne, ok := err.(net.Error); ok && ne.Timeout() { + atomic.AddInt64(&srv.slowConsumers, 1) + c.clearConnection(SlowConsumerWriteDeadline) + c.Noticef("Slow Consumer Detected: WriteDeadline of %v Exceeded", c.out.wdl) + } else { + c.clearConnection(WriteError) + c.Debugf("Error flushing: %v", err) + } + return true + } + + // Adjust based on what we wrote plus any pending. + pt := int(n + c.out.pb) + + // Adjust sz as needed downward, keeping power of 2. + // We do this at a slower rate, hence the pt*4. + if pt < c.out.sz && c.out.sz > minBufSize { + c.out.sws++ + if c.out.sws > shortsToShrink { + c.out.sz >>= 1 + } + } + // Adjust sz as needed upward, keeping power of 2. + if pt > c.out.sz && c.out.sz < maxBufSize { + c.out.sz <<= 1 + } + + // Check to see if we can reuse buffers. + if len(cnb) > 0 { + oldp := cnb[0][:0] + if cap(oldp) >= c.out.sz { + // Replace primary or secondary if they are nil, reusing same buffer. 
+ if c.out.p == nil { + c.out.p = oldp + } else if c.out.s == nil || cap(c.out.s) < c.out.sz { + c.out.s = oldp + } + } + } + return true +} + +// flushSignal will use server to queue the flush IO operation to a pool of flushers. +// Lock must be held. +func (c *client) flushSignal() { + c.out.sg.Signal() +} + +func (c *client) traceMsg(msg []byte) { + if !c.trace { + return + } + // FIXME(dlc), allow limits to printable payload + c.Tracef("->> MSG_PAYLOAD: [%s]", string(msg[:len(msg)-LEN_CR_LF])) +} + +func (c *client) traceInOp(op string, arg []byte) { + c.traceOp("->> %s", op, arg) +} + +func (c *client) traceOutOp(op string, arg []byte) { + c.traceOp("<<- %s", op, arg) +} + +func (c *client) traceOp(format, op string, arg []byte) { + if !c.trace { + return + } + + opa := []interface{}{} + if op != "" { + opa = append(opa, op) + } + if arg != nil { + opa = append(opa, string(arg)) + } + c.Tracef(format, opa) +} + +// Process the information messages from Clients and other Routes. +func (c *client) processInfo(arg []byte) error { + info := Info{} + if err := json.Unmarshal(arg, &info); err != nil { + return err + } + if c.typ == ROUTER { + c.processRouteInfo(&info) + } + return nil +} + +func (c *client) processErr(errStr string) { + switch c.typ { + case CLIENT: + c.Errorf("Client Error %s", errStr) + case ROUTER: + c.Errorf("Route Error %s", errStr) + } + c.closeConnection(ParseError) +} + +func (c *client) processConnect(arg []byte) error { + c.traceInOp("CONNECT", arg) + + c.mu.Lock() + // If we can't stop the timer because the callback is in progress... + if !c.clearAuthTimer() { + // wait for it to finish and handle sending the failure back to + // the client. + for c.nc != nil { + c.mu.Unlock() + time.Sleep(25 * time.Millisecond) + c.mu.Lock() + } + c.mu.Unlock() + return nil + } + c.last = time.Now() + typ := c.typ + r := c.route + srv := c.srv + // Moved unmarshalling of clients' Options under the lock. 
+ // The client has already been added to the server map, so it is possible + // that other routines lookup the client, and access its options under + // the client's lock, so unmarshalling the options outside of the lock + // would cause data RACEs. + if err := json.Unmarshal(arg, &c.opts); err != nil { + c.mu.Unlock() + return err + } + // Indicate that the CONNECT protocol has been received, and that the + // server now knows which protocol this client supports. + c.flags.set(connectReceived) + // Capture these under lock + c.echo = c.opts.Echo + proto := c.opts.Protocol + verbose := c.opts.Verbose + lang := c.opts.Lang + c.mu.Unlock() + + if srv != nil { + // As soon as c.opts is unmarshalled and if the proto is at + // least ClientProtoInfo, we need to increment the following counter. + // This is decremented when client is removed from the server's + // clients map. + if proto >= ClientProtoInfo { + srv.mu.Lock() + srv.cproto++ + srv.mu.Unlock() + } + + // Check for Auth + if ok := srv.checkAuthorization(c); !ok { + c.authViolation() + return ErrAuthorization + } + } + + // Check client protocol request if it exists. + if typ == CLIENT && (proto < ClientProtoZero || proto > ClientProtoInfo) { + c.sendErr(ErrBadClientProtocol.Error()) + c.closeConnection(BadClientProtocolVersion) + return ErrBadClientProtocol + } else if typ == ROUTER && lang != "" { + // Way to detect clients that incorrectly connect to the route listen + // port. Client provide Lang in the CONNECT protocol while ROUTEs don't. + c.sendErr(ErrClientConnectedToRoutePort.Error()) + c.closeConnection(WrongPort) + return ErrClientConnectedToRoutePort + } + + // Grab connection name of remote route. 
+ if typ == ROUTER && r != nil { + c.mu.Lock() + c.route.remoteID = c.opts.Name + c.mu.Unlock() + } + + if verbose { + c.sendOK() + } + return nil +} + +func (c *client) authTimeout() { + c.sendErr(ErrAuthTimeout.Error()) + c.Debugf("Authorization Timeout") + c.closeConnection(AuthenticationTimeout) +} + +func (c *client) authViolation() { + if c.srv != nil && c.srv.getOpts().Users != nil { + c.Errorf("%s - User %q", + ErrAuthorization.Error(), + c.opts.Username) + } else { + c.Errorf(ErrAuthorization.Error()) + } + c.sendErr("Authorization Violation") + c.closeConnection(AuthenticationViolation) +} + +func (c *client) maxConnExceeded() { + c.Errorf(ErrTooManyConnections.Error()) + c.sendErr(ErrTooManyConnections.Error()) + c.closeConnection(MaxConnectionsExceeded) +} + +func (c *client) maxSubsExceeded() { + c.Errorf(ErrTooManySubs.Error()) + c.sendErr(ErrTooManySubs.Error()) +} + +func (c *client) maxPayloadViolation(sz int, max int64) { + c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, max) + c.sendErr("Maximum Payload Violation") + c.closeConnection(MaxPayloadExceeded) +} + +// queueOutbound queues data for client/route connections. +// Return pending length. +// Lock should be held. +func (c *client) queueOutbound(data []byte) { + // Add to pending bytes total. + c.out.pb += int64(len(data)) + + // Check for slow consumer via pending bytes limit. + // ok to return here, client is going away. + if c.out.pb > c.out.mp { + c.clearConnection(SlowConsumerPendingBytes) + atomic.AddInt64(&c.srv.slowConsumers, 1) + c.Noticef("Slow Consumer Detected: MaxPending of %d Exceeded", c.out.mp) + return + } + + if c.out.p == nil && len(data) < maxBufSize { + if c.out.sz == 0 { + c.out.sz = startBufSize + } + if c.out.s != nil && cap(c.out.s) >= c.out.sz { + c.out.p = c.out.s + c.out.s = nil + } else { + // FIXME(dlc) - make power of 2 if less than maxBufSize? 
+ c.out.p = make([]byte, 0, c.out.sz) + } + } + // Determine if we copy or reference + available := cap(c.out.p) - len(c.out.p) + if len(data) > available { + // We can fit into existing primary, but message will fit in next one + // we allocate or utilize from the secondary. So copy what we can. + if available > 0 && len(data) < c.out.sz { + c.out.p = append(c.out.p, data[:available]...) + data = data[available:] + } + // Put the primary on the nb if it has a payload + if len(c.out.p) > 0 { + c.out.nb = append(c.out.nb, c.out.p) + c.out.p = nil + } + // Check for a big message, and if found place directly on nb + // FIXME(dlc) - do we need signaling of ownership here if we want len(data) < + if len(data) > maxBufSize { + c.out.nb = append(c.out.nb, data) + } else { + // We will copy to primary. + if c.out.p == nil { + // Grow here + if (c.out.sz << 1) <= maxBufSize { + c.out.sz <<= 1 + } + if len(data) > c.out.sz { + c.out.p = make([]byte, 0, len(data)) + } else { + if c.out.s != nil && cap(c.out.s) >= c.out.sz { // TODO(dlc) - Size mismatch? + c.out.p = c.out.s + c.out.s = nil + } else { + c.out.p = make([]byte, 0, c.out.sz) + } + } + } + c.out.p = append(c.out.p, data...) + } + } else { + c.out.p = append(c.out.p, data...) + } +} + +// Assume the lock is held upon entry. +func (c *client) sendProto(info []byte, doFlush bool) { + if c.nc == nil { + return + } + c.queueOutbound(info) + if !(doFlush && c.flushOutbound()) { + c.flushSignal() + } +} + +// Assume the lock is held upon entry. +func (c *client) sendPong() { + c.traceOutOp("PONG", nil) + c.sendProto([]byte("PONG\r\n"), true) +} + +// Assume the lock is held upon entry. +func (c *client) sendPing() { + c.rttStart = time.Now() + c.ping.out++ + c.traceOutOp("PING", nil) + c.sendProto([]byte("PING\r\n"), true) +} + +// Generates the INFO to be sent to the client with the client ID included. +// info arg will be copied since passed by value. +// Assume lock is held. 
+func (c *client) generateClientInfoJSON(info Info) []byte { + info.CID = c.cid + // Generate the info json + b, _ := json.Marshal(info) + pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)} + return bytes.Join(pcs, []byte(" ")) +} + +// Assume the lock is held upon entry. +func (c *client) sendInfo(info []byte) { + c.sendProto(info, true) +} + +func (c *client) sendErr(err string) { + c.mu.Lock() + c.traceOutOp("-ERR", []byte(err)) + c.sendProto([]byte(fmt.Sprintf("-ERR '%s'\r\n", err)), true) + c.mu.Unlock() +} + +func (c *client) sendOK() { + c.mu.Lock() + c.traceOutOp("OK", nil) + // Can not autoflush this one, needs to be async. + c.sendProto([]byte("+OK\r\n"), false) + // FIXME(dlc) - ?? + c.pcd[c] = needFlush + c.mu.Unlock() +} + +func (c *client) processPing() { + c.mu.Lock() + c.traceInOp("PING", nil) + if c.nc == nil { + c.mu.Unlock() + return + } + c.sendPong() + + // The CONNECT should have been received, but make sure it + // is so before proceeding + if !c.flags.isSet(connectReceived) { + c.mu.Unlock() + return + } + // If we are here, the CONNECT has been received so we know + // if this client supports async INFO or not. + var ( + checkClusterChange bool + srv = c.srv + ) + // For older clients, just flip the firstPongSent flag if not already + // set and we are done. + if c.opts.Protocol < ClientProtoInfo || srv == nil { + c.flags.setIfNotSet(firstPongSent) + } else { + // This is a client that supports async INFO protocols. + // If this is the first PING (so firstPongSent is not set yet), + // we will need to check if there was a change in cluster topology. + checkClusterChange = !c.flags.isSet(firstPongSent) + } + c.mu.Unlock() + + if checkClusterChange { + srv.mu.Lock() + c.mu.Lock() + // Now that we are under both locks, we can flip the flag. + // This prevents sendAsyncInfoToClients() and and code here + // to send a double INFO protocol. 
+ c.flags.set(firstPongSent) + // If there was a cluster update since this client was created, + // send an updated INFO protocol now. + if srv.lastCURLsUpdate >= c.start.UnixNano() { + c.sendInfo(c.generateClientInfoJSON(srv.copyInfo())) + } + c.mu.Unlock() + srv.mu.Unlock() + } +} + +func (c *client) processPong() { + c.traceInOp("PONG", nil) + c.mu.Lock() + c.ping.out = 0 + c.rtt = time.Since(c.rttStart) + c.mu.Unlock() +} + +func (c *client) processMsgArgs(arg []byte) error { + if c.trace { + c.traceInOp("MSG", arg) + } + + // Unroll splitArgs to avoid runtime/heap issues + a := [MAX_MSG_ARGS][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t', '\r', '\n': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + + switch len(args) { + case 3: + c.pa.reply = nil + c.pa.szb = args[2] + c.pa.size = parseSize(args[2]) + case 4: + c.pa.reply = args[2] + c.pa.szb = args[3] + c.pa.size = parseSize(args[3]) + default: + return fmt.Errorf("processMsgArgs Parse Error: '%s'", arg) + } + if c.pa.size < 0 { + return fmt.Errorf("processMsgArgs Bad or Missing Size: '%s'", arg) + } + + // Common ones processed after check for arg length + c.pa.subject = args[0] + c.pa.sid = args[1] + + return nil +} + +func (c *client) processPub(arg []byte) error { + if c.trace { + c.traceInOp("PUB", arg) + } + + // Unroll splitArgs to avoid runtime/heap issues + a := [MAX_PUB_ARGS][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + + switch len(args) { + case 2: + c.pa.subject = args[0] + c.pa.reply = nil + c.pa.size = parseSize(args[1]) + c.pa.szb = args[1] + case 3: + c.pa.subject = args[0] + c.pa.reply 
= args[1] + c.pa.size = parseSize(args[2]) + c.pa.szb = args[2] + default: + return fmt.Errorf("processPub Parse Error: '%s'", arg) + } + if c.pa.size < 0 { + return fmt.Errorf("processPub Bad or Missing Size: '%s'", arg) + } + maxPayload := atomic.LoadInt64(&c.mpay) + if maxPayload > 0 && int64(c.pa.size) > maxPayload { + c.maxPayloadViolation(c.pa.size, maxPayload) + return ErrMaxPayload + } + + if c.opts.Pedantic && !IsValidLiteralSubject(string(c.pa.subject)) { + c.sendErr("Invalid Publish Subject") + } + return nil +} + +func splitArg(arg []byte) [][]byte { + a := [MAX_MSG_ARGS][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t', '\r', '\n': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + return args +} + +func (c *client) processSub(argo []byte) (err error) { + c.traceInOp("SUB", argo) + + // Indicate activity. + c.in.subs++ + + // Copy so we do not reference a potentially large buffer + arg := make([]byte, len(argo)) + copy(arg, argo) + args := splitArg(arg) + sub := &subscription{client: c} + switch len(args) { + case 2: + sub.subject = args[0] + sub.queue = nil + sub.sid = args[1] + case 3: + sub.subject = args[0] + sub.queue = args[1] + sub.sid = args[2] + default: + return fmt.Errorf("processSub Parse Error: '%s'", arg) + } + + shouldForward := false + + c.mu.Lock() + if c.nc == nil { + c.mu.Unlock() + return nil + } + + // Check permissions if applicable. 
+ if c.typ == ROUTER { + if !c.canExport(sub.subject) { + c.mu.Unlock() + return nil + } + } else if !c.canSubscribe(sub.subject) { + c.mu.Unlock() + c.sendErr(fmt.Sprintf("Permissions Violation for Subscription to %q", sub.subject)) + c.Errorf("Subscription Violation - User %q, Subject %q, SID %s", + c.opts.Username, sub.subject, sub.sid) + return nil + } + + if c.msubs > 0 && len(c.subs) >= c.msubs { + c.mu.Unlock() + c.maxSubsExceeded() + return nil + } + + // Check if we have a maximum on the number of subscriptions. + // We can have two SUB protocols coming from a route due to some + // race conditions. We should make sure that we process only one. + sid := string(sub.sid) + if c.subs[sid] == nil { + c.subs[sid] = sub + if c.srv != nil { + err = c.srv.sl.Insert(sub) + if err != nil { + delete(c.subs, sid) + } else { + shouldForward = c.typ != ROUTER + } + } + } + c.mu.Unlock() + if err != nil { + c.sendErr("Invalid Subject") + return nil + } else if c.opts.Verbose { + c.sendOK() + } + if shouldForward { + c.srv.broadcastSubscribe(sub) + } + + return nil +} + +// canSubscribe determines if the client is authorized to subscribe to the +// given subject. Assumes caller is holding lock. +func (c *client) canSubscribe(subject []byte) bool { + if c.perms == nil { + return true + } + + allowed := true + + // Check allow list. If no allow list that means all are allowed. Deny can overrule. + if c.perms.sub.allow != nil { + r := c.perms.sub.allow.Match(string(subject)) + allowed = len(r.psubs) != 0 + } + // If we have a deny list and we think we are allowed, check that as well. + if allowed && c.perms.sub.deny != nil { + r := c.perms.sub.deny.Match(string(subject)) + allowed = len(r.psubs) == 0 + } + return allowed +} + +// Low level unsubscribe for a given client. 
+func (c *client) unsubscribe(sub *subscription) { + c.mu.Lock() + defer c.mu.Unlock() + if sub.max > 0 && sub.nm < sub.max { + c.Debugf( + "Deferring actual UNSUB(%s): %d max, %d received\n", + string(sub.subject), sub.max, sub.nm) + return + } + c.traceOp("<-> %s", "DELSUB", sub.sid) + + delete(c.subs, string(sub.sid)) + if c.srv != nil { + c.srv.sl.Remove(sub) + } + + // If we are a queue subscriber on a client connection and we have routes, + // we will remember the remote sid and the queue group in case a route + // tries to deliver us a message. Remote queue subscribers are directed + // so we need to know what to do to avoid unnecessary message drops + // from [auto-]unsubscribe. + if c.typ == CLIENT && c.srv != nil && len(sub.queue) > 0 { + c.srv.holdRemoteQSub(sub) + } +} + +func (c *client) processUnsub(arg []byte) error { + c.traceInOp("UNSUB", arg) + args := splitArg(arg) + var sid []byte + max := -1 + + switch len(args) { + case 1: + sid = args[0] + case 2: + sid = args[0] + max = parseSize(args[1]) + default: + return fmt.Errorf("processUnsub Parse Error: '%s'", arg) + } + + // Indicate activity. + c.in.subs += 1 + + var sub *subscription + + unsub := false + shouldForward := false + ok := false + + c.mu.Lock() + if sub, ok = c.subs[string(sid)]; ok { + if max > 0 { + sub.max = int64(max) + } else { + // Clear it here to override + sub.max = 0 + } + unsub = true + shouldForward = c.typ != ROUTER && c.srv != nil + } + c.mu.Unlock() + + if unsub { + c.unsubscribe(sub) + } + if shouldForward { + c.srv.broadcastUnSubscribe(sub) + } + if c.opts.Verbose { + c.sendOK() + } + + return nil +} + +func (c *client) msgHeader(mh []byte, sub *subscription) []byte { + mh = append(mh, sub.sid...) + mh = append(mh, ' ') + if c.pa.reply != nil { + mh = append(mh, c.pa.reply...) + mh = append(mh, ' ') + } + mh = append(mh, c.pa.szb...) + mh = append(mh, "\r\n"...) 
+ return mh +} + +// Used to treat maps as efficient set +var needFlush = struct{}{} +var routeSeen = struct{}{} + +func (c *client) deliverMsg(sub *subscription, mh, msg []byte) bool { + if sub.client == nil { + return false + } + client := sub.client + client.mu.Lock() + + // Check echo + if c == client && !client.echo { + client.mu.Unlock() + return false + } + + srv := client.srv + + sub.nm++ + // Check if we should auto-unsubscribe. + if sub.max > 0 { + // For routing.. + shouldForward := client.typ != ROUTER && client.srv != nil + // If we are at the exact number, unsubscribe but + // still process the message in hand, otherwise + // unsubscribe and drop message on the floor. + if sub.nm == sub.max { + c.Debugf("Auto-unsubscribe limit of %d reached for sid '%s'\n", sub.max, string(sub.sid)) + // Due to defer, reverse the code order so that execution + // is consistent with other cases where we unsubscribe. + if shouldForward { + defer srv.broadcastUnSubscribe(sub) + } + defer client.unsubscribe(sub) + } else if sub.nm > sub.max { + c.Debugf("Auto-unsubscribe limit [%d] exceeded\n", sub.max) + client.mu.Unlock() + client.unsubscribe(sub) + if shouldForward { + srv.broadcastUnSubscribe(sub) + } + return false + } + } + + // Check for closed connection + if client.nc == nil { + client.mu.Unlock() + return false + } + + // Update statistics + + // The msg includes the CR_LF, so pull back out for accounting. + msgSize := int64(len(msg) - LEN_CR_LF) + + // No atomic needed since accessed under client lock. + // Monitor is reading those also under client's lock. + client.outMsgs++ + client.outBytes += msgSize + + atomic.AddInt64(&srv.outMsgs, 1) + atomic.AddInt64(&srv.outBytes, msgSize) + + // Queue to outbound buffer + client.queueOutbound(mh) + client.queueOutbound(msg) + + client.out.pm++ + + // Check outbound threshold and queue IO flush if needed. 
+ if client.out.pm > 1 && client.out.pb > maxBufSize*2 { + client.flushSignal() + } + + if c.trace { + client.traceOutOp(string(mh[:len(mh)-LEN_CR_LF]), nil) + } + + // Increment the flush pending signals if we are setting for the first time. + if _, ok := c.pcd[client]; !ok { + client.out.fsp++ + } + client.mu.Unlock() + + // Remember for when we return to the top of the loop. + c.pcd[client] = needFlush + + return true +} + +// pruneCache will prune the cache via randomly +// deleting items. Doing so pruneSize items at a time. +func (c *client) prunePubPermsCache() { + r := 0 + for subject := range c.perms.pcache { + delete(c.perms.pcache, subject) + if r++; r > pruneSize { + break + } + } +} + +// pubAllowed checks on publish permissioning. +func (c *client) pubAllowed(subject []byte) bool { + // Disallow publish to _SYS.>, these are reserved for internals. + if len(subject) > 4 && string(subject[:5]) == "_SYS." { + return false + } + if c.perms == nil { + return true + } + + // Check if published subject is allowed if we have permissions in place. + allowed, ok := c.perms.pcache[string(subject)] + if ok { + return allowed + } + + // Cache miss, check allow then deny as needed. + if c.perms.pub.allow != nil { + r := c.perms.pub.allow.Match(string(subject)) + allowed = len(r.psubs) != 0 + } else { + // No entries means all are allowed. Deny will overrule as needed. + allowed = true + } + // If we have a deny list and are currently allowed, check that as well. + if allowed && c.perms.pub.deny != nil { + r := c.perms.pub.deny.Match(string(subject)) + allowed = len(r.psubs) == 0 + } + + // Update our cache here. + c.perms.pcache[string(subject)] = allowed + + // Prune if needed. + if len(c.perms.pcache) > maxPermCacheSize { + c.prunePubPermsCache() + } + return allowed +} + +// prepMsgHeader will prepare the message header prefix +func (c *client) prepMsgHeader() []byte { + // Use the scratch buffer.. 
+ msgh := c.msgb[:msgHeadProtoLen] + + // msg header + msgh = append(msgh, c.pa.subject...) + return append(msgh, ' ') +} + +// processMsg is called to process an inbound msg from a client. +func (c *client) processMsg(msg []byte) { + // Snapshot server. + srv := c.srv + + // Update statistics + // The msg includes the CR_LF, so pull back out for accounting. + c.in.msgs += 1 + c.in.bytes += len(msg) - LEN_CR_LF + + if c.trace { + c.traceMsg(msg) + } + + // Check pub permissions (don't do this for routes) + if c.typ == CLIENT && !c.pubAllowed(c.pa.subject) { + c.pubPermissionViolation(c.pa.subject) + return + } + + if c.opts.Verbose { + c.sendOK() + } + + // Mostly under testing scenarios. + if srv == nil { + return + } + + // Match the subscriptions. We will use our own L1 map if + // it's still valid, avoiding contention on the shared sublist. + var r *SublistResult + var ok bool + + genid := atomic.LoadUint64(&srv.sl.genid) + + if genid == c.in.genid && c.in.results != nil { + r, ok = c.in.results[string(c.pa.subject)] + } else { + // reset our L1 completely. + c.in.results = make(map[string]*SublistResult) + c.in.genid = genid + } + + if !ok { + subject := string(c.pa.subject) + r = srv.sl.Match(subject) + c.in.results[subject] = r + // Prune the results cache. Keeps us from unbounded growth. + if len(c.in.results) > maxResultCacheSize { + n := 0 + for subject := range c.in.results { + delete(c.in.results, subject) + if n++; n > pruneSize { + break + } + } + } + } + + // This is the fanout scale. + fanout := len(r.psubs) + len(r.qsubs) + + // Check for no interest, short circuit if so. + if fanout == 0 { + return + } + + if c.typ == ROUTER { + c.processRoutedMsg(r, msg) + return + } + + // Client connection processing here. + msgh := c.prepMsgHeader() + si := len(msgh) + + // Used to only send messages once across any given route. + var rmap map[string]struct{} + + // Loop over all normal subscriptions that match. 
+ for _, sub := range r.psubs { + // Check if this is a send to a ROUTER, make sure we only send it + // once. The other side will handle the appropriate re-processing + // and fan-out. Also enforce 1-Hop semantics, so no routing to another. + if sub.client.typ == ROUTER { + // Check to see if we have already sent it here. + if rmap == nil { + rmap = make(map[string]struct{}, srv.numRoutes()) + } + sub.client.mu.Lock() + if sub.client.nc == nil || + sub.client.route == nil || + sub.client.route.remoteID == "" { + c.Debugf("Bad or Missing ROUTER Identity, not processing msg") + sub.client.mu.Unlock() + continue + } + if _, ok := rmap[sub.client.route.remoteID]; ok { + c.Debugf("Ignoring route, already processed and sent msg") + sub.client.mu.Unlock() + continue + } + rmap[sub.client.route.remoteID] = routeSeen + sub.client.mu.Unlock() + } + // Normal delivery + mh := c.msgHeader(msgh[:si], sub) + c.deliverMsg(sub, mh, msg) + } + + // Check to see if we have our own rand yet. Global rand + // has contention with lots of clients, etc. + if c.in.prand == nil { + c.in.prand = rand.New(rand.NewSource(time.Now().UnixNano())) + } + // Process queue subs + for i := 0; i < len(r.qsubs); i++ { + qsubs := r.qsubs[i] + // Find a subscription that is able to deliver this message + // starting at a random index. 
+ startIndex := c.in.prand.Intn(len(qsubs)) + for i := 0; i < len(qsubs); i++ { + index := (startIndex + i) % len(qsubs) + sub := qsubs[index] + if sub != nil { + mh := c.msgHeader(msgh[:si], sub) + if c.deliverMsg(sub, mh, msg) { + break + } + } + } + } +} + +func (c *client) pubPermissionViolation(subject []byte) { + c.sendErr(fmt.Sprintf("Permissions Violation for Publish to %q", subject)) + c.Errorf("Publish Violation - User %q, Subject %q", c.opts.Username, subject) +} + +func (c *client) processPingTimer() { + c.mu.Lock() + defer c.mu.Unlock() + c.ping.tmr = nil + // Check if connection is still opened + if c.nc == nil { + return + } + + c.Debugf("%s Ping Timer", c.typeString()) + + // Check for violation + if c.ping.out+1 > c.srv.getOpts().MaxPingsOut { + c.Debugf("Stale Client Connection - Closing") + c.sendProto([]byte(fmt.Sprintf("-ERR '%s'\r\n", "Stale Connection")), true) + c.clearConnection(StaleConnection) + return + } + + // If we have had activity within the PingInterval no + // need to send a ping. + if delta := time.Since(c.last); delta < c.srv.getOpts().PingInterval { + c.Debugf("Delaying PING due to activity %v ago", delta.Round(time.Second)) + } else { + // Send PING + c.sendPing() + } + + // Reset to fire again. 
+ c.setPingTimer() +} + +// Lock should be held +func (c *client) setPingTimer() { + if c.srv == nil { + return + } + d := c.srv.getOpts().PingInterval + c.ping.tmr = time.AfterFunc(d, c.processPingTimer) +} + +// Lock should be held +func (c *client) clearPingTimer() { + if c.ping.tmr == nil { + return + } + c.ping.tmr.Stop() + c.ping.tmr = nil +} + +// Lock should be held +func (c *client) setAuthTimer(d time.Duration) { + c.atmr = time.AfterFunc(d, func() { c.authTimeout() }) +} + +// Lock should be held +func (c *client) clearAuthTimer() bool { + if c.atmr == nil { + return true + } + stopped := c.atmr.Stop() + c.atmr = nil + return stopped +} + +func (c *client) isAuthTimerSet() bool { + c.mu.Lock() + isSet := c.atmr != nil + c.mu.Unlock() + return isSet +} + +// Lock should be held +func (c *client) clearConnection(reason ClosedState) { + if c.flags.isSet(clearConnection) { + return + } + c.flags.set(clearConnection) + + nc := c.nc + if nc == nil || c.srv == nil { + return + } + // Flush any pending. + c.flushOutbound() + + // Clear outbound here. + c.out.sg.Broadcast() + + // With TLS, Close() is sending an alert (that is doing a write). + // Need to set a deadline otherwise the server could block there + // if the peer is not reading from socket. + if c.flags.isSet(handshakeComplete) { + nc.SetWriteDeadline(time.Now().Add(c.out.wdl)) + } + nc.Close() + // Do this always to also kick out any IO writes. + nc.SetWriteDeadline(time.Time{}) + + // Save off the connection if its a client. 
+ if c.typ == CLIENT && c.srv != nil { + go c.srv.saveClosedClient(c, nc, reason) + } +} + +func (c *client) typeString() string { + switch c.typ { + case CLIENT: + return "Client" + case ROUTER: + return "Router" + } + return "Unknown Type" +} + +func (c *client) closeConnection(reason ClosedState) { + c.mu.Lock() + if c.nc == nil { + c.mu.Unlock() + return + } + + c.Debugf("%s connection closed", c.typeString()) + + c.clearAuthTimer() + c.clearPingTimer() + c.clearConnection(reason) + c.nc = nil + + // Snapshot for use. + subs := make([]*subscription, 0, len(c.subs)) + for _, sub := range c.subs { + // Auto-unsubscribe subscriptions must be unsubscribed forcibly. + sub.max = 0 + subs = append(subs, sub) + } + srv := c.srv + + var ( + routeClosed bool + retryImplicit bool + connectURLs []string + ) + if c.route != nil { + routeClosed = c.route.closed + if !routeClosed { + retryImplicit = c.route.retry + } + connectURLs = c.route.connectURLs + } + + c.mu.Unlock() + + if srv != nil { + // This is a route that disconnected... + if len(connectURLs) > 0 { + // Unless disabled, possibly update the server's INFO protcol + // and send to clients that know how to handle async INFOs. + if !srv.getOpts().Cluster.NoAdvertise { + srv.removeClientConnectURLsAndSendINFOToClients(connectURLs) + } + } + + // Unregister + srv.removeClient(c) + + // Remove clients subscriptions. + srv.sl.RemoveBatch(subs) + if c.typ != ROUTER { + for _, sub := range subs { + // Forward on unsubscribes if we are not + // a router ourselves. + srv.broadcastUnSubscribe(sub) + } + } + } + + // Don't reconnect routes that are being closed. + if routeClosed { + return + } + + // Check for a solicited route. If it was, start up a reconnect unless + // we are already connected to the other end. 
+ if c.isSolicitedRoute() || retryImplicit { + // Capture these under lock + c.mu.Lock() + rid := c.route.remoteID + rtype := c.route.routeType + rurl := c.route.url + c.mu.Unlock() + + srv.mu.Lock() + defer srv.mu.Unlock() + + // It is possible that the server is being shutdown. + // If so, don't try to reconnect + if !srv.running { + return + } + + if rid != "" && srv.remotes[rid] != nil { + c.srv.Debugf("Not attempting reconnect for solicited route, already connected to \"%s\"", rid) + return + } else if rid == srv.info.ID { + c.srv.Debugf("Detected route to self, ignoring \"%s\"", rurl) + return + } else if rtype != Implicit || retryImplicit { + c.srv.Debugf("Attempting reconnect for solicited route \"%s\"", rurl) + // Keep track of this go-routine so we can wait for it on + // server shutdown. + srv.startGoRoutine(func() { srv.reConnectToRoute(rurl, rtype) }) + } + } +} + +// If the client is a route connection, sets the `closed` flag to true +// to prevent any reconnecting attempt when c.closeConnection() is called. +func (c *client) setRouteNoReconnectOnClose() { + c.mu.Lock() + if c.route != nil { + c.route.closed = true + } + c.mu.Unlock() +} + +// Logging functionality scoped to a client or route. + +func (c *client) Errorf(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Errorf(format, v...) +} + +func (c *client) Debugf(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Debugf(format, v...) +} + +func (c *client) Noticef(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Noticef(format, v...) +} + +func (c *client) Tracef(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Tracef(format, v...) 
+} diff --git a/vendor/github.com/nats-io/gnatsd/server/const.go b/vendor/github.com/nats-io/gnatsd/server/const.go new file mode 100644 index 00000000000..5a8362acc1c --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/const.go @@ -0,0 +1,124 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "time" +) + +// Command is a signal used to control a running gnatsd process. +type Command string + +// Valid Command values. +const ( + CommandStop = Command("stop") + CommandQuit = Command("quit") + CommandReopen = Command("reopen") + CommandReload = Command("reload") +) + +var ( + // gitCommit injected at build + gitCommit string +) + +const ( + // VERSION is the current version for the server. + VERSION = "1.3.0" + + // PROTO is the currently supported protocol. + // 0 was the original + // 1 maintains proto 0, adds echo abilities for CONNECT from the client. Clients + // should not send echo unless proto in INFO is >= 1. + PROTO = 1 + + // DEFAULT_PORT is the default port for client connections. + DEFAULT_PORT = 4222 + + // RANDOM_PORT is the value for port that, when supplied, will cause the + // server to listen on a randomly-chosen available port. The resolved port + // is available via the Addr() method. + RANDOM_PORT = -1 + + // DEFAULT_HOST defaults to all interfaces. + DEFAULT_HOST = "0.0.0.0" + + // MAX_CONTROL_LINE_SIZE is the maximum allowed protocol control line size. 
+ // 1k should be plenty since payloads sans connect string are separate + MAX_CONTROL_LINE_SIZE = 1024 + + // MAX_PAYLOAD_SIZE is the maximum allowed payload size. Should be using + // something different if > 1MB payloads are needed. + MAX_PAYLOAD_SIZE = (1024 * 1024) + + // MAX_PENDING_SIZE is the maximum outbound pending bytes per client. + MAX_PENDING_SIZE = (256 * 1024 * 1024) + + // DEFAULT_MAX_CONNECTIONS is the default maximum connections allowed. + DEFAULT_MAX_CONNECTIONS = (64 * 1024) + + // TLS_TIMEOUT is the TLS wait time. + TLS_TIMEOUT = 500 * time.Millisecond + + // AUTH_TIMEOUT is the authorization wait time. + AUTH_TIMEOUT = 2 * TLS_TIMEOUT + + // DEFAULT_PING_INTERVAL is how often pings are sent to clients and routes. + DEFAULT_PING_INTERVAL = 2 * time.Minute + + // DEFAULT_PING_MAX_OUT is maximum allowed pings outstanding before disconnect. + DEFAULT_PING_MAX_OUT = 2 + + // CR_LF string + CR_LF = "\r\n" + + // LEN_CR_LF hold onto the computed size. + LEN_CR_LF = len(CR_LF) + + // DEFAULT_FLUSH_DEADLINE is the write/flush deadlines. + DEFAULT_FLUSH_DEADLINE = 2 * time.Second + + // DEFAULT_HTTP_PORT is the default monitoring port. + DEFAULT_HTTP_PORT = 8222 + + // ACCEPT_MIN_SLEEP is the minimum acceptable sleep times on temporary errors. + ACCEPT_MIN_SLEEP = 10 * time.Millisecond + + // ACCEPT_MAX_SLEEP is the maximum acceptable sleep times on temporary errors + ACCEPT_MAX_SLEEP = 1 * time.Second + + // DEFAULT_ROUTE_CONNECT Route solicitation intervals. + DEFAULT_ROUTE_CONNECT = 1 * time.Second + + // DEFAULT_ROUTE_RECONNECT Route reconnect intervals. + DEFAULT_ROUTE_RECONNECT = 1 * time.Second + + // DEFAULT_ROUTE_DIAL Route dial timeout. + DEFAULT_ROUTE_DIAL = 1 * time.Second + + // PROTO_SNIPPET_SIZE is the default size of proto to print on parse errors. + PROTO_SNIPPET_SIZE = 32 + + // MAX_MSG_ARGS Maximum possible number of arguments from MSG proto. + MAX_MSG_ARGS = 4 + + // MAX_PUB_ARGS Maximum possible number of arguments from PUB proto. 
+ MAX_PUB_ARGS = 3 + + // DEFAULT_REMOTE_QSUBS_SWEEPER + DEFAULT_REMOTE_QSUBS_SWEEPER = 30 * time.Second + + // DEFAULT_MAX_CLOSED_CLIENTS + DEFAULT_MAX_CLOSED_CLIENTS = 10000 +) diff --git a/vendor/github.com/nats-io/gnatsd/server/errors.go b/vendor/github.com/nats-io/gnatsd/server/errors.go new file mode 100644 index 00000000000..c722bfc43b2 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/errors.go @@ -0,0 +1,51 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import "errors" + +var ( + // ErrConnectionClosed represents an error condition on a closed connection. + ErrConnectionClosed = errors.New("Connection Closed") + + // ErrAuthorization represents an error condition on failed authorization. + ErrAuthorization = errors.New("Authorization Error") + + // ErrAuthTimeout represents an error condition on failed authorization due to timeout. + ErrAuthTimeout = errors.New("Authorization Timeout") + + // ErrMaxPayload represents an error condition when the payload is too big. + ErrMaxPayload = errors.New("Maximum Payload Exceeded") + + // ErrMaxControlLine represents an error condition when the control line is too big. + ErrMaxControlLine = errors.New("Maximum Control Line Exceeded") + + // ErrReservedPublishSubject represents an error condition when sending to a reserved subject, e.g. 
_SYS.> + ErrReservedPublishSubject = errors.New("Reserved Internal Subject") + + // ErrBadClientProtocol signals a client requested an invalud client protocol. + ErrBadClientProtocol = errors.New("Invalid Client Protocol") + + // ErrTooManyConnections signals a client that the maximum number of connections supported by the + // server has been reached. + ErrTooManyConnections = errors.New("Maximum Connections Exceeded") + + // ErrTooManySubs signals a client that the maximum number of subscriptions per connection + // has been reached. + ErrTooManySubs = errors.New("Maximum Subscriptions Exceeded") + + // ErrClientConnectedToRoutePort represents an error condition when a client + // attempted to connect to the route listen port. + ErrClientConnectedToRoutePort = errors.New("Attempted To Connect To Route Port") +) diff --git a/vendor/github.com/nats-io/gnatsd/server/log.go b/vendor/github.com/nats-io/gnatsd/server/log.go new file mode 100644 index 00000000000..8c2be370f25 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/log.go @@ -0,0 +1,184 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "io" + "os" + "sync/atomic" + + srvlog "github.com/nats-io/gnatsd/logger" +) + +// Logger interface of the NATS Server +type Logger interface { + + // Log a notice statement + Noticef(format string, v ...interface{}) + + // Log a fatal error + Fatalf(format string, v ...interface{}) + + // Log an error + Errorf(format string, v ...interface{}) + + // Log a debug statement + Debugf(format string, v ...interface{}) + + // Log a trace statement + Tracef(format string, v ...interface{}) +} + +// ConfigureLogger configures and sets the logger for the server. +func (s *Server) ConfigureLogger() { + var ( + log Logger + + // Snapshot server options. + opts = s.getOpts() + ) + + syslog := opts.Syslog + if isWindowsService() && opts.LogFile == "" { + // Enable syslog if no log file is specified and we're running as a + // Windows service so that logs are written to the Windows event log. + syslog = true + } + + if opts.LogFile != "" { + log = srvlog.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace, true) + } else if opts.RemoteSyslog != "" { + log = srvlog.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace) + } else if syslog { + log = srvlog.NewSysLogger(opts.Debug, opts.Trace) + } else { + colors := true + // Check to see if stderr is being redirected and if so turn off color + // Also turn off colors if we're running on Windows where os.Stderr.Stat() returns an invalid handle-error + stat, err := os.Stderr.Stat() + if err != nil || (stat.Mode()&os.ModeCharDevice) == 0 { + colors = false + } + log = srvlog.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, colors, true) + } + + s.SetLogger(log, opts.Debug, opts.Trace) +} + +// SetLogger sets the logger of the server +func (s *Server) SetLogger(logger Logger, debugFlag, traceFlag bool) { + if debugFlag { + atomic.StoreInt32(&s.logging.debug, 1) + } else { + atomic.StoreInt32(&s.logging.debug, 0) + } + if traceFlag { + atomic.StoreInt32(&s.logging.trace, 1) + } else 
{ + atomic.StoreInt32(&s.logging.trace, 0) + } + s.logging.Lock() + if s.logging.logger != nil { + // Check to see if the logger implements io.Closer. This could be a + // logger from another process embedding the NATS server or a dummy + // test logger that may not implement that interface. + if l, ok := s.logging.logger.(io.Closer); ok { + if err := l.Close(); err != nil { + s.Errorf("Error closing logger: %v", err) + } + } + } + s.logging.logger = logger + s.logging.Unlock() +} + +// If the logger is a file based logger, close and re-open the file. +// This allows for file rotation by 'mv'ing the file then signaling +// the process to trigger this function. +func (s *Server) ReOpenLogFile() { + // Check to make sure this is a file logger. + s.logging.RLock() + ll := s.logging.logger + s.logging.RUnlock() + + if ll == nil { + s.Noticef("File log re-open ignored, no logger") + return + } + + // Snapshot server options. + opts := s.getOpts() + + if opts.LogFile == "" { + s.Noticef("File log re-open ignored, not a file logger") + } else { + fileLog := srvlog.NewFileLogger(opts.LogFile, + opts.Logtime, opts.Debug, opts.Trace, true) + s.SetLogger(fileLog, opts.Debug, opts.Trace) + s.Noticef("File log re-opened") + } +} + +// Noticef logs a notice statement +func (s *Server) Noticef(format string, v ...interface{}) { + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Noticef(format, v...) + }, format, v...) +} + +// Errorf logs an error +func (s *Server) Errorf(format string, v ...interface{}) { + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Errorf(format, v...) + }, format, v...) +} + +// Fatalf logs a fatal error +func (s *Server) Fatalf(format string, v ...interface{}) { + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Fatalf(format, v...) + }, format, v...) 
+} + +// Debugf logs a debug statement +func (s *Server) Debugf(format string, v ...interface{}) { + if atomic.LoadInt32(&s.logging.debug) == 0 { + return + } + + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Debugf(format, v...) + }, format, v...) +} + +// Tracef logs a trace statement +func (s *Server) Tracef(format string, v ...interface{}) { + if atomic.LoadInt32(&s.logging.trace) == 0 { + return + } + + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Tracef(format, v...) + }, format, v...) +} + +func (s *Server) executeLogCall(f func(logger Logger, format string, v ...interface{}), format string, args ...interface{}) { + s.logging.RLock() + defer s.logging.RUnlock() + if s.logging.logger == nil { + return + } + + f(s.logging.logger, format, args...) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/monitor.go b/vendor/github.com/nats-io/gnatsd/server/monitor.go new file mode 100644 index 00000000000..550b8083e05 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/monitor.go @@ -0,0 +1,1029 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net" + "net/http" + "runtime" + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/nats-io/gnatsd/server/pse" +) + +// Snapshot this +var numCores int + +func init() { + numCores = runtime.NumCPU() +} + +// Connz represents detailed information on current client connections. +type Connz struct { + ID string `json:"server_id"` + Now time.Time `json:"now"` + NumConns int `json:"num_connections"` + Total int `json:"total"` + Offset int `json:"offset"` + Limit int `json:"limit"` + Conns []*ConnInfo `json:"connections"` +} + +// ConnzOptions are the options passed to Connz() +type ConnzOptions struct { + // Sort indicates how the results will be sorted. Check SortOpt for possible values. + // Only the sort by connection ID (ByCid) is ascending, all others are descending. + Sort SortOpt `json:"sort"` + + // Username indicates if user names should be included in the results. + Username bool `json:"auth"` + + // Subscriptions indicates if subscriptions should be included in the results. + Subscriptions bool `json:"subscriptions"` + + // Offset is used for pagination. Connz() only returns connections starting at this + // offset from the global results. + Offset int `json:"offset"` + + // Limit is the maximum number of connections that should be returned by Connz(). + Limit int `json:"limit"` + + // Filter for this explicit client connection. + CID uint64 `json:"cid"` + + // Filter by connection state. + State ConnState `json:"state"` +} + +// For filtering states of connections. We will only have two, open and closed. +type ConnState int + +const ( + ConnOpen = ConnState(iota) + ConnClosed + ConnAll +) + +// ConnInfo has detailed information on a per connection basis. 
+type ConnInfo struct { + Cid uint64 `json:"cid"` + IP string `json:"ip"` + Port int `json:"port"` + Start time.Time `json:"start"` + LastActivity time.Time `json:"last_activity"` + Stop *time.Time `json:"stop,omitempty"` + Reason string `json:"reason,omitempty"` + RTT string `json:"rtt,omitempty"` + Uptime string `json:"uptime"` + Idle string `json:"idle"` + Pending int `json:"pending_bytes"` + InMsgs int64 `json:"in_msgs"` + OutMsgs int64 `json:"out_msgs"` + InBytes int64 `json:"in_bytes"` + OutBytes int64 `json:"out_bytes"` + NumSubs uint32 `json:"subscriptions"` + Name string `json:"name,omitempty"` + Lang string `json:"lang,omitempty"` + Version string `json:"version,omitempty"` + TLSVersion string `json:"tls_version,omitempty"` + TLSCipher string `json:"tls_cipher_suite,omitempty"` + AuthorizedUser string `json:"authorized_user,omitempty"` + Subs []string `json:"subscriptions_list,omitempty"` +} + +// DefaultConnListSize is the default size of the connection list. +const DefaultConnListSize = 1024 + +// DefaultSubListSize is the default size of the subscriptions list. +const DefaultSubListSize = 1024 + +const defaultStackBufSize = 10000 + +// Connz returns a Connz struct containing inormation about connections. 
+func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) { + var ( + sortOpt = ByCid + auth bool + subs bool + offset int + limit = DefaultConnListSize + cid = uint64(0) + state = ConnOpen + ) + + if opts != nil { + // If no sort option given or sort is by uptime, then sort by cid + if opts.Sort == "" { + sortOpt = ByCid + } else { + sortOpt = opts.Sort + if !sortOpt.IsValid() { + return nil, fmt.Errorf("Invalid sorting option: %s", sortOpt) + } + } + auth = opts.Username + subs = opts.Subscriptions + offset = opts.Offset + if offset < 0 { + offset = 0 + } + limit = opts.Limit + if limit <= 0 { + limit = DefaultConnListSize + } + // state + state = opts.State + + // ByStop only makes sense on closed connections + if sortOpt == ByStop && state != ConnClosed { + return nil, fmt.Errorf("Sort by stop only valid on closed connections") + } + // ByReason is the same. + if sortOpt == ByReason && state != ConnClosed { + return nil, fmt.Errorf("Sort by reason only valid on closed connections") + } + + // If searching by CID + if opts.CID > 0 { + cid = opts.CID + limit = 1 + } + } + + c := &Connz{ + Offset: offset, + Limit: limit, + Now: time.Now(), + } + + // Open clients + var openClients []*client + // Hold for closed clients if requested. + var closedClients []*closedClient + + // Walk the open client list with server lock held. + s.mu.Lock() + + // copy the server id for monitoring + c.ID = s.info.ID + + // Number of total clients. The resulting ConnInfo array + // may be smaller if pagination is used. + switch state { + case ConnOpen: + c.Total = len(s.clients) + case ConnClosed: + c.Total = s.closed.len() + closedClients = s.closed.closedClients() + case ConnAll: + c.Total = len(s.clients) + s.closed.len() + closedClients = s.closed.closedClients() + } + + totalClients := c.Total + if cid > 0 { // Meaning we only want 1. 
+ totalClients = 1 + } + if state == ConnOpen || state == ConnAll { + openClients = make([]*client, 0, totalClients) + } + + // Data structures for results. + var conns []ConnInfo // Limits allocs for actual ConnInfos. + var pconns ConnInfos + + switch state { + case ConnOpen: + conns = make([]ConnInfo, totalClients) + pconns = make(ConnInfos, totalClients) + case ConnClosed: + pconns = make(ConnInfos, totalClients) + case ConnAll: + conns = make([]ConnInfo, cap(openClients)) + pconns = make(ConnInfos, totalClients) + } + + // Search by individual CID. + if cid > 0 { + if state == ConnClosed || state == ConnAll { + copyClosed := closedClients + closedClients = nil + for _, cc := range copyClosed { + if cc.Cid == cid { + closedClients = []*closedClient{cc} + break + } + } + } else if state == ConnOpen || state == ConnAll { + client := s.clients[cid] + if client != nil { + openClients = append(openClients, client) + } + } + } else { + // Gather all open clients. + if state == ConnOpen || state == ConnAll { + for _, client := range s.clients { + openClients = append(openClients, client) + } + } + } + s.mu.Unlock() + + // Just return with empty array if nothing here. + if len(openClients) == 0 && len(closedClients) == 0 { + c.Conns = ConnInfos{} + return c, nil + } + + // Now whip through and generate ConnInfo entries + + // Open Clients + i := 0 + for _, client := range openClients { + client.mu.Lock() + ci := &conns[i] + ci.fill(client, client.nc, c.Now) + // Fill in subscription data if requested. + if subs && len(client.subs) > 0 { + ci.Subs = make([]string, 0, len(client.subs)) + for _, sub := range client.subs { + ci.Subs = append(ci.Subs, string(sub.subject)) + } + } + // Fill in user if auth requested. 
+ if auth { + ci.AuthorizedUser = client.opts.Username + } + client.mu.Unlock() + pconns[i] = ci + i++ + } + // Closed Clients + var needCopy bool + if subs || auth { + needCopy = true + } + for _, cc := range closedClients { + // Copy if needed for any changes to the ConnInfo + if needCopy { + cx := *cc + cc = &cx + } + // Fill in subscription data if requested. + if subs && len(cc.subs) > 0 { + cc.Subs = cc.subs + } + // Fill in user if auth requested. + if auth { + cc.AuthorizedUser = cc.user + } + pconns[i] = &cc.ConnInfo + i++ + } + + switch sortOpt { + case ByCid, ByStart: + sort.Sort(byCid{pconns}) + case BySubs: + sort.Sort(sort.Reverse(bySubs{pconns})) + case ByPending: + sort.Sort(sort.Reverse(byPending{pconns})) + case ByOutMsgs: + sort.Sort(sort.Reverse(byOutMsgs{pconns})) + case ByInMsgs: + sort.Sort(sort.Reverse(byInMsgs{pconns})) + case ByOutBytes: + sort.Sort(sort.Reverse(byOutBytes{pconns})) + case ByInBytes: + sort.Sort(sort.Reverse(byInBytes{pconns})) + case ByLast: + sort.Sort(sort.Reverse(byLast{pconns})) + case ByIdle: + sort.Sort(sort.Reverse(byIdle{pconns})) + case ByUptime: + sort.Sort(byUptime{pconns, time.Now()}) + case ByStop: + sort.Sort(sort.Reverse(byStop{pconns})) + case ByReason: + sort.Sort(byReason{pconns}) + } + + minoff := c.Offset + maxoff := c.Offset + c.Limit + + maxIndex := totalClients + + // Make sure these are sane. + if minoff > maxIndex { + minoff = maxIndex + } + if maxoff > maxIndex { + maxoff = maxIndex + } + + // Now pare down to the requested size. + // TODO(dlc) - for very large number of connections we + // could save the whole list in a hash, send hash on first + // request and allow users to use has for subsequent pages. + // Low TTL, say < 1sec. + c.Conns = pconns[minoff:maxoff] + c.NumConns = len(c.Conns) + + return c, nil +} + +// Fills in the ConnInfo from the client. +// client should be locked. 
+func (ci *ConnInfo) fill(client *client, nc net.Conn, now time.Time) { + ci.Cid = client.cid + ci.Start = client.start + ci.LastActivity = client.last + ci.Uptime = myUptime(now.Sub(client.start)) + ci.Idle = myUptime(now.Sub(client.last)) + ci.RTT = client.getRTT() + ci.OutMsgs = client.outMsgs + ci.OutBytes = client.outBytes + ci.NumSubs = uint32(len(client.subs)) + ci.Pending = int(client.out.pb) + ci.Name = client.opts.Name + ci.Lang = client.opts.Lang + ci.Version = client.opts.Version + // inMsgs and inBytes are updated outside of the client's lock, so + // we need to use atomic here. + ci.InMsgs = atomic.LoadInt64(&client.inMsgs) + ci.InBytes = atomic.LoadInt64(&client.inBytes) + + // If the connection is gone, too bad, we won't set TLSVersion and TLSCipher. + // Exclude clients that are still doing handshake so we don't block in + // ConnectionState(). + if client.flags.isSet(handshakeComplete) && nc != nil { + conn := nc.(*tls.Conn) + cs := conn.ConnectionState() + ci.TLSVersion = tlsVersion(cs.Version) + ci.TLSCipher = tlsCipher(cs.CipherSuite) + } + + switch conn := nc.(type) { + case *net.TCPConn, *tls.Conn: + addr := conn.RemoteAddr().(*net.TCPAddr) + ci.Port = addr.Port + ci.IP = addr.IP.String() + } +} + +// Assume lock is held +func (c *client) getRTT() string { + if c.rtt == 0 { + // If a real client, go ahead and send ping now to get a value + // for RTT. For tests and telnet, etc skip. 
+ if c.flags.isSet(connectReceived) && c.opts.Lang != "" { + c.sendPing() + } + return "" + } + var rtt time.Duration + if c.rtt > time.Microsecond && c.rtt < time.Millisecond { + rtt = c.rtt.Truncate(time.Microsecond) + } else { + rtt = c.rtt.Truncate(time.Millisecond) + } + return rtt.String() +} + +func decodeBool(w http.ResponseWriter, r *http.Request, param string) (bool, error) { + str := r.URL.Query().Get(param) + if str == "" { + return false, nil + } + val, err := strconv.ParseBool(str) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("Error decoding boolean for '%s': %v", param, err))) + return false, err + } + return val, nil +} + +func decodeUint64(w http.ResponseWriter, r *http.Request, param string) (uint64, error) { + str := r.URL.Query().Get(param) + if str == "" { + return 0, nil + } + val, err := strconv.ParseUint(str, 10, 64) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("Error decoding uint64 for '%s': %v", param, err))) + return 0, err + } + return val, nil +} + +func decodeInt(w http.ResponseWriter, r *http.Request, param string) (int, error) { + str := r.URL.Query().Get(param) + if str == "" { + return 0, nil + } + val, err := strconv.Atoi(str) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("Error decoding int for '%s': %v", param, err))) + return 0, err + } + return val, nil +} + +func decodeState(w http.ResponseWriter, r *http.Request) (ConnState, error) { + str := r.URL.Query().Get("state") + if str == "" { + return ConnOpen, nil + } + switch strings.ToLower(str) { + case "open": + return ConnOpen, nil + case "closed": + return ConnClosed, nil + case "any", "all": + return ConnAll, nil + } + // We do not understand intended state here. 
+ w.WriteHeader(http.StatusBadRequest) + err := fmt.Errorf("Error decoding state for %s", str) + w.Write([]byte(err.Error())) + return 0, err +} + +// HandleConnz process HTTP requests for connection information. +func (s *Server) HandleConnz(w http.ResponseWriter, r *http.Request) { + sortOpt := SortOpt(r.URL.Query().Get("sort")) + auth, err := decodeBool(w, r, "auth") + if err != nil { + return + } + subs, err := decodeBool(w, r, "subs") + if err != nil { + return + } + offset, err := decodeInt(w, r, "offset") + if err != nil { + return + } + limit, err := decodeInt(w, r, "limit") + if err != nil { + return + } + cid, err := decodeUint64(w, r, "cid") + if err != nil { + return + } + state, err := decodeState(w, r) + if err != nil { + return + } + + connzOpts := &ConnzOptions{ + Sort: sortOpt, + Username: auth, + Subscriptions: subs, + Offset: offset, + Limit: limit, + CID: cid, + State: state, + } + + s.mu.Lock() + s.httpReqStats[ConnzPath]++ + s.mu.Unlock() + + c, err := s.Connz(connzOpts) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + b, err := json.MarshalIndent(c, "", " ") + if err != nil { + s.Errorf("Error marshaling response to /connz request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// Routez represents detailed information on current client connections. +type Routez struct { + ID string `json:"server_id"` + Now time.Time `json:"now"` + NumRoutes int `json:"num_routes"` + Routes []*RouteInfo `json:"routes"` +} + +// RoutezOptions are options passed to Routez +type RoutezOptions struct { + // Subscriptions indicates that Routez will return a route's subscriptions + Subscriptions bool `json:"subscriptions"` +} + +// RouteInfo has detailed information on a per connection basis. 
+type RouteInfo struct { + Rid uint64 `json:"rid"` + RemoteID string `json:"remote_id"` + DidSolicit bool `json:"did_solicit"` + IsConfigured bool `json:"is_configured"` + IP string `json:"ip"` + Port int `json:"port"` + Pending int `json:"pending_size"` + InMsgs int64 `json:"in_msgs"` + OutMsgs int64 `json:"out_msgs"` + InBytes int64 `json:"in_bytes"` + OutBytes int64 `json:"out_bytes"` + NumSubs uint32 `json:"subscriptions"` + Subs []string `json:"subscriptions_list,omitempty"` +} + +// Routez returns a Routez struct containing inormation about routes. +func (s *Server) Routez(routezOpts *RoutezOptions) (*Routez, error) { + rs := &Routez{Routes: []*RouteInfo{}} + rs.Now = time.Now() + + subs := routezOpts != nil && routezOpts.Subscriptions + + // Walk the list + s.mu.Lock() + rs.NumRoutes = len(s.routes) + + // copy the server id for monitoring + rs.ID = s.info.ID + + for _, r := range s.routes { + r.mu.Lock() + ri := &RouteInfo{ + Rid: r.cid, + RemoteID: r.route.remoteID, + DidSolicit: r.route.didSolicit, + IsConfigured: r.route.routeType == Explicit, + InMsgs: atomic.LoadInt64(&r.inMsgs), + OutMsgs: r.outMsgs, + InBytes: atomic.LoadInt64(&r.inBytes), + OutBytes: r.outBytes, + NumSubs: uint32(len(r.subs)), + } + + if subs && len(r.subs) > 0 { + ri.Subs = make([]string, 0, len(r.subs)) + for _, sub := range r.subs { + ri.Subs = append(ri.Subs, string(sub.subject)) + } + } + switch conn := r.nc.(type) { + case *net.TCPConn, *tls.Conn: + addr := conn.RemoteAddr().(*net.TCPAddr) + ri.Port = addr.Port + ri.IP = addr.IP.String() + } + r.mu.Unlock() + rs.Routes = append(rs.Routes, ri) + } + s.mu.Unlock() + return rs, nil +} + +// HandleRoutez process HTTP requests for route information. 
+func (s *Server) HandleRoutez(w http.ResponseWriter, r *http.Request) { + subs, err := decodeBool(w, r, "subs") + if err != nil { + return + } + var opts *RoutezOptions + if subs { + opts = &RoutezOptions{Subscriptions: true} + } + + s.mu.Lock() + s.httpReqStats[RoutezPath]++ + s.mu.Unlock() + + // As of now, no error is ever returned. + rs, _ := s.Routez(opts) + b, err := json.MarshalIndent(rs, "", " ") + if err != nil { + s.Errorf("Error marshaling response to /routez request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// Subsz represents detail information on current connections. +type Subsz struct { + *SublistStats + Total int `json:"total"` + Offset int `json:"offset"` + Limit int `json:"limit"` + Subs []SubDetail `json:"subscriptions_list,omitempty"` +} + +// SubszOptions are the options passed to Subsz. +// As of now, there are no options defined. +type SubszOptions struct { + // Offset is used for pagination. Subsz() only returns connections starting at this + // offset from the global results. + Offset int `json:"offset"` + + // Limit is the maximum number of subscriptions that should be returned by Subsz(). + Limit int `json:"limit"` + + // Subscriptions indicates if subscriptions should be included in the results. + Subscriptions bool `json:"subscriptions"` + + // Test the list against this subject. Needs to be literal since it signifies a publish subject. + // We will only return subscriptions that would match if a message was sent to this subject. 
+ Test string `json:"test,omitempty"` +} + +type SubDetail struct { + Subject string `json:"subject"` + Queue string `json:"qgroup,omitempty"` + Sid string `json:"sid"` + Msgs int64 `json:"msgs"` + Max int64 `json:"max,omitempty"` + Cid uint64 `json:"cid"` +} + +// Subsz returns a Subsz struct containing subjects statistics +func (s *Server) Subsz(opts *SubszOptions) (*Subsz, error) { + var ( + subdetail bool + test bool + offset int + limit = DefaultSubListSize + testSub = "" + ) + + if opts != nil { + subdetail = opts.Subscriptions + offset = opts.Offset + if offset < 0 { + offset = 0 + } + limit = opts.Limit + if limit <= 0 { + limit = DefaultSubListSize + } + if opts.Test != "" { + testSub = opts.Test + test = true + if !IsValidLiteralSubject(testSub) { + return nil, fmt.Errorf("Invalid test subject, must be valid publish subject: %s", testSub) + } + } + } + + sz := &Subsz{s.sl.Stats(), 0, offset, limit, nil} + + if subdetail { + // Now add in subscription's details + var raw [4096]*subscription + subs := raw[:0] + + s.sl.localSubs(&subs) + details := make([]SubDetail, len(subs)) + i := 0 + // TODO(dlc) - may be inefficient and could just do normal match when total subs is large and filtering. + for _, sub := range subs { + // Check for filter + if test && !matchLiteral(testSub, string(sub.subject)) { + continue + } + if sub.client == nil { + continue + } + sub.client.mu.Lock() + details[i] = SubDetail{ + Subject: string(sub.subject), + Queue: string(sub.queue), + Sid: string(sub.sid), + Msgs: sub.nm, + Max: sub.max, + Cid: sub.client.cid, + } + sub.client.mu.Unlock() + i++ + } + minoff := sz.Offset + maxoff := sz.Offset + sz.Limit + + maxIndex := i + + // Make sure these are sane. + if minoff > maxIndex { + minoff = maxIndex + } + if maxoff > maxIndex { + maxoff = maxIndex + } + sz.Subs = details[minoff:maxoff] + sz.Total = len(sz.Subs) + } + + return sz, nil +} + +// HandleSubsz processes HTTP requests for subjects stats. 
+func (s *Server) HandleSubsz(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + s.httpReqStats[SubszPath]++ + s.mu.Unlock() + + subs, err := decodeBool(w, r, "subs") + if err != nil { + return + } + offset, err := decodeInt(w, r, "offset") + if err != nil { + return + } + limit, err := decodeInt(w, r, "limit") + if err != nil { + return + } + testSub := r.URL.Query().Get("test") + + subszOpts := &SubszOptions{ + Subscriptions: subs, + Offset: offset, + Limit: limit, + Test: testSub, + } + + st, err := s.Subsz(subszOpts) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + var b []byte + + if len(st.Subs) == 0 { + b, err = json.MarshalIndent(st.SublistStats, "", " ") + } else { + b, err = json.MarshalIndent(st, "", " ") + } + if err != nil { + s.Errorf("Error marshaling response to /subscriptionsz request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// HandleStacksz processes HTTP requests for getting stacks +func (s *Server) HandleStacksz(w http.ResponseWriter, r *http.Request) { + // Do not get any lock here that would prevent getting the stacks + // if we were to have a deadlock somewhere. + var defaultBuf [defaultStackBufSize]byte + size := defaultStackBufSize + buf := defaultBuf[:size] + n := 0 + for { + n = runtime.Stack(buf, true) + if n < size { + break + } + size *= 2 + buf = make([]byte, size) + } + // Handle response + ResponseHandler(w, r, buf[:n]) +} + +// Varz will output server information on the monitoring port at /varz. 
+type Varz struct { + *Info + *Options + Port int `json:"port"` + MaxPayload int `json:"max_payload"` + Start time.Time `json:"start"` + Now time.Time `json:"now"` + Uptime string `json:"uptime"` + Mem int64 `json:"mem"` + Cores int `json:"cores"` + CPU float64 `json:"cpu"` + Connections int `json:"connections"` + TotalConnections uint64 `json:"total_connections"` + Routes int `json:"routes"` + Remotes int `json:"remotes"` + InMsgs int64 `json:"in_msgs"` + OutMsgs int64 `json:"out_msgs"` + InBytes int64 `json:"in_bytes"` + OutBytes int64 `json:"out_bytes"` + SlowConsumers int64 `json:"slow_consumers"` + MaxPending int64 `json:"max_pending"` + WriteDeadline time.Duration `json:"write_deadline"` + Subscriptions uint32 `json:"subscriptions"` + HTTPReqStats map[string]uint64 `json:"http_req_stats"` + ConfigLoadTime time.Time `json:"config_load_time"` +} + +// VarzOptions are the options passed to Varz(). +// Currently, there are no options defined. +type VarzOptions struct{} + +func myUptime(d time.Duration) string { + // Just use total seconds for uptime, and display days / years + tsecs := d / time.Second + tmins := tsecs / 60 + thrs := tmins / 60 + tdays := thrs / 24 + tyrs := tdays / 365 + + if tyrs > 0 { + return fmt.Sprintf("%dy%dd%dh%dm%ds", tyrs, tdays%365, thrs%24, tmins%60, tsecs%60) + } + if tdays > 0 { + return fmt.Sprintf("%dd%dh%dm%ds", tdays, thrs%24, tmins%60, tsecs%60) + } + if thrs > 0 { + return fmt.Sprintf("%dh%dm%ds", thrs, tmins%60, tsecs%60) + } + if tmins > 0 { + return fmt.Sprintf("%dm%ds", tmins, tsecs%60) + } + return fmt.Sprintf("%ds", tsecs) +} + +// HandleRoot will show basic info and links to others handlers. +func (s *Server) HandleRoot(w http.ResponseWriter, r *http.Request) { + // This feels dumb to me, but is required: https://code.google.com/p/go/issues/detail?id=4799 + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + s.mu.Lock() + s.httpReqStats[RootPath]++ + s.mu.Unlock() + fmt.Fprintf(w, ` + + + + + + NATS +
+ varz
+ connz
+ routez
+ subsz
+
+ help + +`) +} + +// Varz returns a Varz struct containing the server information. +func (s *Server) Varz(varzOpts *VarzOptions) (*Varz, error) { + // Snapshot server options. + opts := s.getOpts() + + v := &Varz{Info: &s.info, Options: opts, MaxPayload: opts.MaxPayload, Start: s.start} + v.Now = time.Now() + v.Uptime = myUptime(time.Since(s.start)) + v.Port = v.Info.Port + + updateUsage(v) + + s.mu.Lock() + v.Connections = len(s.clients) + v.TotalConnections = s.totalClients + v.Routes = len(s.routes) + v.Remotes = len(s.remotes) + v.InMsgs = atomic.LoadInt64(&s.inMsgs) + v.InBytes = atomic.LoadInt64(&s.inBytes) + v.OutMsgs = atomic.LoadInt64(&s.outMsgs) + v.OutBytes = atomic.LoadInt64(&s.outBytes) + v.SlowConsumers = atomic.LoadInt64(&s.slowConsumers) + v.MaxPending = opts.MaxPending + v.WriteDeadline = opts.WriteDeadline + v.Subscriptions = s.sl.Count() + v.ConfigLoadTime = s.configTime + // Need a copy here since s.httpReqStats can change while doing + // the marshaling down below. + v.HTTPReqStats = make(map[string]uint64, len(s.httpReqStats)) + for key, val := range s.httpReqStats { + v.HTTPReqStats[key] = val + } + s.mu.Unlock() + + return v, nil +} + +// HandleVarz will process HTTP requests for server information. 
+func (s *Server) HandleVarz(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + s.httpReqStats[VarzPath]++ + s.mu.Unlock() + + // As of now, no error is ever returned + v, _ := s.Varz(nil) + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + s.Errorf("Error marshaling response to /varz request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// Grab RSS and PCPU +func updateUsage(v *Varz) { + var rss, vss int64 + var pcpu float64 + + pse.ProcUsage(&pcpu, &rss, &vss) + + v.Mem = rss + v.CPU = pcpu + v.Cores = numCores +} + +// ResponseHandler handles responses for monitoring routes +func ResponseHandler(w http.ResponseWriter, r *http.Request, data []byte) { + // Get callback from request + callback := r.URL.Query().Get("callback") + // If callback is not empty then + if callback != "" { + // Response for JSONP + w.Header().Set("Content-Type", "application/javascript") + fmt.Fprintf(w, "%s(%s)", callback, data) + } else { + // Otherwise JSON + w.Header().Set("Content-Type", "application/json") + w.Write(data) + } +} + +func (reason ClosedState) String() string { + switch reason { + case ClientClosed: + return "Client" + case AuthenticationTimeout: + return "Authentication Timeout" + case AuthenticationViolation: + return "Authentication Failure" + case TLSHandshakeError: + return "TLS Handshake Failure" + case SlowConsumerPendingBytes: + return "Slow Consumer (Pending Bytes)" + case SlowConsumerWriteDeadline: + return "Slow Consumer (Write Deadline)" + case WriteError: + return "Write Error" + case ReadError: + return "Read Error" + case ParseError: + return "Parse Error" + case StaleConnection: + return "Stale Connection" + case ProtocolViolation: + return "Protocol Violation" + case BadClientProtocolVersion: + return "Bad Client Protocol Version" + case WrongPort: + return "Incorrect Port" + case MaxConnectionsExceeded: + return "Maximum Connections Exceeded" + case MaxPayloadExceeded: + return "Maximum Message Payload 
Exceeded" + case MaxControlLineExceeded: + return "Maximum Control Line Exceeded" + case DuplicateRoute: + return "Duplicate Route" + case RouteRemoved: + return "Route Removed" + case ServerShutdown: + return "Server Shutdown" + } + return "Unknown State" +} diff --git a/vendor/github.com/nats-io/gnatsd/server/monitor_sort_opts.go b/vendor/github.com/nats-io/gnatsd/server/monitor_sort_opts.go new file mode 100644 index 00000000000..926ceb26ca9 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/monitor_sort_opts.go @@ -0,0 +1,147 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "time" +) + +// Represents a connection info list. We use pointers since it will be sorted. 
+type ConnInfos []*ConnInfo + +// For sorting +func (cl ConnInfos) Len() int { return len(cl) } +func (cl ConnInfos) Swap(i, j int) { cl[i], cl[j] = cl[j], cl[i] } + +// SortOpt is a helper type to sort clients +type SortOpt string + +// Possible sort options +const ( + ByCid SortOpt = "cid" // By connection ID + ByStart SortOpt = "start" // By connection start time, same as CID + BySubs SortOpt = "subs" // By number of subscriptions + ByPending SortOpt = "pending" // By amount of data in bytes waiting to be sent to client + ByOutMsgs SortOpt = "msgs_to" // By number of messages sent + ByInMsgs SortOpt = "msgs_from" // By number of messages received + ByOutBytes SortOpt = "bytes_to" // By amount of bytes sent + ByInBytes SortOpt = "bytes_from" // By amount of bytes received + ByLast SortOpt = "last" // By the last activity + ByIdle SortOpt = "idle" // By the amount of inactivity + ByUptime SortOpt = "uptime" // By the amount of time connections exist + ByStop SortOpt = "stop" // By the stop time for a closed connection + ByReason SortOpt = "reason" // By the reason for a closed connection + +) + +// Individual sort options provide the Less for sort.Interface. Len and Swap are on cList. 
+// CID +type byCid struct{ ConnInfos } + +func (l byCid) Less(i, j int) bool { return l.ConnInfos[i].Cid < l.ConnInfos[j].Cid } + +// Number of Subscriptions +type bySubs struct{ ConnInfos } + +func (l bySubs) Less(i, j int) bool { return l.ConnInfos[i].NumSubs < l.ConnInfos[j].NumSubs } + +// Pending Bytes +type byPending struct{ ConnInfos } + +func (l byPending) Less(i, j int) bool { return l.ConnInfos[i].Pending < l.ConnInfos[j].Pending } + +// Outbound Msgs +type byOutMsgs struct{ ConnInfos } + +func (l byOutMsgs) Less(i, j int) bool { return l.ConnInfos[i].OutMsgs < l.ConnInfos[j].OutMsgs } + +// Inbound Msgs +type byInMsgs struct{ ConnInfos } + +func (l byInMsgs) Less(i, j int) bool { return l.ConnInfos[i].InMsgs < l.ConnInfos[j].InMsgs } + +// Outbound Bytes +type byOutBytes struct{ ConnInfos } + +func (l byOutBytes) Less(i, j int) bool { return l.ConnInfos[i].OutBytes < l.ConnInfos[j].OutBytes } + +// Inbound Bytes +type byInBytes struct{ ConnInfos } + +func (l byInBytes) Less(i, j int) bool { return l.ConnInfos[i].InBytes < l.ConnInfos[j].InBytes } + +// Last Activity +type byLast struct{ ConnInfos } + +func (l byLast) Less(i, j int) bool { + return l.ConnInfos[i].LastActivity.UnixNano() < l.ConnInfos[j].LastActivity.UnixNano() +} + +// Idle time +type byIdle struct{ ConnInfos } + +func (l byIdle) Less(i, j int) bool { + ii := l.ConnInfos[i].LastActivity.Sub(l.ConnInfos[i].Start) + ij := l.ConnInfos[j].LastActivity.Sub(l.ConnInfos[j].Start) + return ii < ij +} + +// Uptime +type byUptime struct { + ConnInfos + now time.Time +} + +func (l byUptime) Less(i, j int) bool { + ci := l.ConnInfos[i] + cj := l.ConnInfos[j] + var upi, upj time.Duration + if ci.Stop == nil || ci.Stop.IsZero() { + upi = l.now.Sub(ci.Start) + } else { + upi = ci.Stop.Sub(ci.Start) + } + if cj.Stop == nil || cj.Stop.IsZero() { + upj = l.now.Sub(cj.Start) + } else { + upj = cj.Stop.Sub(cj.Start) + } + return upi < upj +} + +// Stop +type byStop struct{ ConnInfos } + +func (l byStop) 
Less(i, j int) bool { + ciStop := l.ConnInfos[i].Stop + cjStop := l.ConnInfos[j].Stop + return ciStop.Before(*cjStop) +} + +// Reason +type byReason struct{ ConnInfos } + +func (l byReason) Less(i, j int) bool { + return l.ConnInfos[i].Reason < l.ConnInfos[j].Reason +} + +// IsValid determines if a sort option is valid +func (s SortOpt) IsValid() bool { + switch s { + case "", ByCid, ByStart, BySubs, ByPending, ByOutMsgs, ByInMsgs, ByOutBytes, ByInBytes, ByLast, ByIdle, ByUptime, ByStop, ByReason: + return true + default: + return false + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/opts.go b/vendor/github.com/nats-io/gnatsd/server/opts.go new file mode 100644 index 00000000000..05ed57a2325 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/opts.go @@ -0,0 +1,1312 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "flag" + "fmt" + "io/ioutil" + "net" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/nats-io/gnatsd/conf" + "github.com/nats-io/gnatsd/util" +) + +// ClusterOpts are options for clusters. 
+type ClusterOpts struct { + Host string `json:"addr,omitempty"` + Port int `json:"cluster_port,omitempty"` + Username string `json:"-"` + Password string `json:"-"` + AuthTimeout float64 `json:"auth_timeout,omitempty"` + Permissions *RoutePermissions `json:"-"` + TLSTimeout float64 `json:"-"` + TLSConfig *tls.Config `json:"-"` + ListenStr string `json:"-"` + Advertise string `json:"-"` + NoAdvertise bool `json:"-"` + ConnectRetries int `json:"-"` +} + +// Options block for gnatsd server. +type Options struct { + ConfigFile string `json:"-"` + Host string `json:"addr"` + Port int `json:"port"` + ClientAdvertise string `json:"-"` + Trace bool `json:"-"` + Debug bool `json:"-"` + NoLog bool `json:"-"` + NoSigs bool `json:"-"` + Logtime bool `json:"-"` + MaxConn int `json:"max_connections"` + MaxSubs int `json:"max_subscriptions,omitempty"` + Users []*User `json:"-"` + Username string `json:"-"` + Password string `json:"-"` + Authorization string `json:"-"` + PingInterval time.Duration `json:"ping_interval"` + MaxPingsOut int `json:"ping_max"` + HTTPHost string `json:"http_host"` + HTTPPort int `json:"http_port"` + HTTPSPort int `json:"https_port"` + AuthTimeout float64 `json:"auth_timeout"` + MaxControlLine int `json:"max_control_line"` + MaxPayload int `json:"max_payload"` + MaxPending int64 `json:"max_pending"` + Cluster ClusterOpts `json:"cluster,omitempty"` + ProfPort int `json:"-"` + PidFile string `json:"-"` + PortsFileDir string `json:"-"` + LogFile string `json:"-"` + Syslog bool `json:"-"` + RemoteSyslog string `json:"-"` + Routes []*url.URL `json:"-"` + RoutesStr string `json:"-"` + TLSTimeout float64 `json:"tls_timeout"` + TLS bool `json:"-"` + TLSVerify bool `json:"-"` + TLSCert string `json:"-"` + TLSKey string `json:"-"` + TLSCaCert string `json:"-"` + TLSConfig *tls.Config `json:"-"` + WriteDeadline time.Duration `json:"-"` + RQSubsSweep time.Duration `json:"-"` + MaxClosedClients int `json:"-"` + + CustomClientAuthentication Authentication `json:"-"` 
+ CustomRouterAuthentication Authentication `json:"-"` +} + +// Clone performs a deep copy of the Options struct, returning a new clone +// with all values copied. +func (o *Options) Clone() *Options { + if o == nil { + return nil + } + clone := &Options{} + *clone = *o + if o.Users != nil { + clone.Users = make([]*User, len(o.Users)) + for i, user := range o.Users { + clone.Users[i] = user.clone() + } + } + if o.Routes != nil { + clone.Routes = make([]*url.URL, len(o.Routes)) + for i, route := range o.Routes { + routeCopy := &url.URL{} + *routeCopy = *route + clone.Routes[i] = routeCopy + } + } + if o.TLSConfig != nil { + clone.TLSConfig = util.CloneTLSConfig(o.TLSConfig) + } + if o.Cluster.TLSConfig != nil { + clone.Cluster.TLSConfig = util.CloneTLSConfig(o.Cluster.TLSConfig) + } + return clone +} + +// Configuration file authorization section. +type authorization struct { + // Singles + user string + pass string + token string + // Multiple Users + users []*User + timeout float64 + defaultPermissions *Permissions +} + +// TLSConfigOpts holds the parsed tls config information, +// used with flag parsing +type TLSConfigOpts struct { + CertFile string + KeyFile string + CaFile string + Verify bool + Timeout float64 + Ciphers []uint16 + CurvePreferences []tls.CurveID +} + +var tlsUsage = ` +TLS configuration is specified in the tls section of a configuration file: + +e.g. + + tls { + cert_file: "./certs/server-cert.pem" + key_file: "./certs/server-key.pem" + ca_file: "./certs/ca.pem" + verify: true + + cipher_suites: [ + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" + ] + curve_preferences: [ + "CurveP256", + "CurveP384", + "CurveP521" + ] + } + +Available cipher suites include: +` + +// ProcessConfigFile processes a configuration file. 
+// FIXME(dlc): Hacky +func ProcessConfigFile(configFile string) (*Options, error) { + opts := &Options{} + if err := opts.ProcessConfigFile(configFile); err != nil { + return nil, err + } + return opts, nil +} + +// ProcessConfigFile updates the Options structure with options +// present in the given configuration file. +// This version is convenient if one wants to set some default +// options and then override them with what is in the config file. +// For instance, this version allows you to do something such as: +// +// opts := &Options{Debug: true} +// opts.ProcessConfigFile(myConfigFile) +// +// If the config file contains "debug: false", after this call, +// opts.Debug would really be false. It would be impossible to +// achieve that with the non receiver ProcessConfigFile() version, +// since one would not know after the call if "debug" was not present +// or was present but set to false. +func (o *Options) ProcessConfigFile(configFile string) error { + o.ConfigFile = configFile + if configFile == "" { + return nil + } + + m, err := conf.ParseFile(configFile) + if err != nil { + return err + } + + for k, v := range m { + switch strings.ToLower(k) { + case "listen": + hp, err := parseListen(v) + if err != nil { + return err + } + o.Host = hp.host + o.Port = hp.port + case "client_advertise": + o.ClientAdvertise = v.(string) + case "port": + o.Port = int(v.(int64)) + case "host", "net": + o.Host = v.(string) + case "debug": + o.Debug = v.(bool) + case "trace": + o.Trace = v.(bool) + case "logtime": + o.Logtime = v.(bool) + case "authorization": + am := v.(map[string]interface{}) + auth, err := parseAuthorization(am) + if err != nil { + return err + } + o.Username = auth.user + o.Password = auth.pass + o.Authorization = auth.token + if (auth.user != "" || auth.pass != "") && auth.token != "" { + return fmt.Errorf("Cannot have a user/pass and token") + } + o.AuthTimeout = auth.timeout + // Check for multiple users defined + if auth.users != nil { + if auth.user 
!= "" { + return fmt.Errorf("Can not have a single user/pass and a users array") + } + if auth.token != "" { + return fmt.Errorf("Can not have a token and a users array") + } + o.Users = auth.users + } + case "http": + hp, err := parseListen(v) + if err != nil { + return err + } + o.HTTPHost = hp.host + o.HTTPPort = hp.port + case "https": + hp, err := parseListen(v) + if err != nil { + return err + } + o.HTTPHost = hp.host + o.HTTPSPort = hp.port + case "http_port", "monitor_port": + o.HTTPPort = int(v.(int64)) + case "https_port": + o.HTTPSPort = int(v.(int64)) + case "cluster": + cm := v.(map[string]interface{}) + if err := parseCluster(cm, o); err != nil { + return err + } + case "logfile", "log_file": + o.LogFile = v.(string) + case "syslog": + o.Syslog = v.(bool) + case "remote_syslog": + o.RemoteSyslog = v.(string) + case "pidfile", "pid_file": + o.PidFile = v.(string) + case "ports_file_dir": + o.PortsFileDir = v.(string) + case "prof_port": + o.ProfPort = int(v.(int64)) + case "max_control_line": + o.MaxControlLine = int(v.(int64)) + case "max_payload": + o.MaxPayload = int(v.(int64)) + case "max_pending": + o.MaxPending = v.(int64) + case "max_connections", "max_conn": + o.MaxConn = int(v.(int64)) + case "max_subscriptions", "max_subs": + o.MaxSubs = int(v.(int64)) + case "ping_interval": + o.PingInterval = time.Duration(int(v.(int64))) * time.Second + case "ping_max": + o.MaxPingsOut = int(v.(int64)) + case "tls": + tlsm := v.(map[string]interface{}) + tc, err := parseTLS(tlsm) + if err != nil { + return err + } + if o.TLSConfig, err = GenTLSConfig(tc); err != nil { + return err + } + o.TLSTimeout = tc.Timeout + case "write_deadline": + wd, ok := v.(string) + if ok { + dur, err := time.ParseDuration(wd) + if err != nil { + return fmt.Errorf("error parsing write_deadline: %v", err) + } + o.WriteDeadline = dur + } else { + // Backward compatible with old type, assume this is the + // number of seconds. 
+ o.WriteDeadline = time.Duration(v.(int64)) * time.Second + fmt.Printf("WARNING: write_deadline should be converted to a duration\n") + } + } + } + return nil +} + +// hostPort is simple struct to hold parsed listen/addr strings. +type hostPort struct { + host string + port int +} + +// parseListen will parse listen option which is replacing host/net and port +func parseListen(v interface{}) (*hostPort, error) { + hp := &hostPort{} + switch v.(type) { + // Only a port + case int64: + hp.port = int(v.(int64)) + case string: + host, port, err := net.SplitHostPort(v.(string)) + if err != nil { + return nil, fmt.Errorf("Could not parse address string %q", v) + } + hp.port, err = strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("Could not parse port %q", port) + } + hp.host = host + } + return hp, nil +} + +// parseCluster will parse the cluster config. +func parseCluster(cm map[string]interface{}, opts *Options) error { + for mk, mv := range cm { + switch strings.ToLower(mk) { + case "listen": + hp, err := parseListen(mv) + if err != nil { + return err + } + opts.Cluster.Host = hp.host + opts.Cluster.Port = hp.port + case "port": + opts.Cluster.Port = int(mv.(int64)) + case "host", "net": + opts.Cluster.Host = mv.(string) + case "authorization": + am := mv.(map[string]interface{}) + auth, err := parseAuthorization(am) + if err != nil { + return err + } + if auth.users != nil { + return fmt.Errorf("Cluster authorization does not allow multiple users") + } + opts.Cluster.Username = auth.user + opts.Cluster.Password = auth.pass + opts.Cluster.AuthTimeout = auth.timeout + if auth.defaultPermissions != nil { + // Import is whether or not we will send a SUB for interest to the other side. + // Export is whether or not we will accept a SUB from the remote for a given subject. + // Both only effect interest registration. + // The parsing sets Import into Publish and Export into Subscribe, convert + // accordingly. 
+ opts.Cluster.Permissions = &RoutePermissions{ + Import: auth.defaultPermissions.Publish, + Export: auth.defaultPermissions.Subscribe, + } + } + case "routes": + ra := mv.([]interface{}) + opts.Routes = make([]*url.URL, 0, len(ra)) + for _, r := range ra { + routeURL := r.(string) + url, err := url.Parse(routeURL) + if err != nil { + return fmt.Errorf("error parsing route url [%q]", routeURL) + } + opts.Routes = append(opts.Routes, url) + } + case "tls": + tlsm := mv.(map[string]interface{}) + tc, err := parseTLS(tlsm) + if err != nil { + return err + } + if opts.Cluster.TLSConfig, err = GenTLSConfig(tc); err != nil { + return err + } + // For clusters, we will force strict verification. We also act + // as both client and server, so will mirror the rootCA to the + // clientCA pool. + opts.Cluster.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + opts.Cluster.TLSConfig.RootCAs = opts.Cluster.TLSConfig.ClientCAs + opts.Cluster.TLSTimeout = tc.Timeout + case "cluster_advertise", "advertise": + opts.Cluster.Advertise = mv.(string) + case "no_advertise": + opts.Cluster.NoAdvertise = mv.(bool) + case "connect_retries": + opts.Cluster.ConnectRetries = int(mv.(int64)) + } + } + return nil +} + +// Helper function to parse Authorization configs. 
+func parseAuthorization(am map[string]interface{}) (*authorization, error) { + auth := &authorization{} + for mk, mv := range am { + switch strings.ToLower(mk) { + case "user", "username": + auth.user = mv.(string) + case "pass", "password": + auth.pass = mv.(string) + case "token": + auth.token = mv.(string) + case "timeout": + at := float64(1) + switch mv.(type) { + case int64: + at = float64(mv.(int64)) + case float64: + at = mv.(float64) + } + auth.timeout = at + case "users": + users, err := parseUsers(mv) + if err != nil { + return nil, err + } + auth.users = users + case "default_permission", "default_permissions", "permissions": + pm, ok := mv.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Expected default permissions to be a map/struct, got %+v", mv) + } + permissions, err := parseUserPermissions(pm) + if err != nil { + return nil, err + } + auth.defaultPermissions = permissions + } + + // Now check for permission defaults with multiple users, etc. + if auth.users != nil && auth.defaultPermissions != nil { + for _, user := range auth.users { + if user.Permissions == nil { + user.Permissions = auth.defaultPermissions + } + } + } + + } + return auth, nil +} + +// Helper function to parse multiple users array with optional permissions. 
+func parseUsers(mv interface{}) ([]*User, error) { + // Make sure we have an array + uv, ok := mv.([]interface{}) + if !ok { + return nil, fmt.Errorf("Expected users field to be an array, got %v", mv) + } + users := []*User{} + for _, u := range uv { + // Check its a map/struct + um, ok := u.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Expected user entry to be a map/struct, got %v", u) + } + user := &User{} + for k, v := range um { + switch strings.ToLower(k) { + case "user", "username": + user.Username = v.(string) + case "pass", "password": + user.Password = v.(string) + case "permission", "permissions", "authorization": + pm, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Expected user permissions to be a map/struct, got %+v", v) + } + permissions, err := parseUserPermissions(pm) + if err != nil { + return nil, err + } + user.Permissions = permissions + } + } + // Check to make sure we have at least username and password + if user.Username == "" || user.Password == "" { + return nil, fmt.Errorf("User entry requires a user and a password") + } + users = append(users, user) + } + return users, nil +} + +// Helper function to parse user/account permissions +func parseUserPermissions(pm map[string]interface{}) (*Permissions, error) { + p := &Permissions{} + for k, v := range pm { + switch strings.ToLower(k) { + // For routes: + // Import is Publish + // Export is Subscribe + case "pub", "publish", "import": + perms, err := parseVariablePermissions(v) + if err != nil { + return nil, err + } + p.Publish = perms + case "sub", "subscribe", "export": + perms, err := parseVariablePermissions(v) + if err != nil { + return nil, err + } + p.Subscribe = perms + default: + return nil, fmt.Errorf("Unknown field %s parsing permissions", k) + } + } + return p, nil +} + +// Tope level parser for authorization configurations. 
+func parseVariablePermissions(v interface{}) (*SubjectPermission, error) { + switch v.(type) { + case map[string]interface{}: + // New style with allow and/or deny properties. + return parseSubjectPermission(v.(map[string]interface{})) + default: + // Old style + return parseOldPermissionStyle(v) + } +} + +// Helper function to parse subject singeltons and/or arrays +func parseSubjects(v interface{}) ([]string, error) { + var subjects []string + switch v.(type) { + case string: + subjects = append(subjects, v.(string)) + case []string: + subjects = v.([]string) + case []interface{}: + for _, i := range v.([]interface{}) { + subject, ok := i.(string) + if !ok { + return nil, fmt.Errorf("Subject in permissions array cannot be cast to string") + } + subjects = append(subjects, subject) + } + default: + return nil, fmt.Errorf("Expected subject permissions to be a subject, or array of subjects, got %T", v) + } + if err := checkSubjectArray(subjects); err != nil { + return nil, err + } + return subjects, nil +} + +// Helper function to parse old style authorization configs. +func parseOldPermissionStyle(v interface{}) (*SubjectPermission, error) { + subjects, err := parseSubjects(v) + if err != nil { + return nil, err + } + return &SubjectPermission{Allow: subjects}, nil +} + +// Helper function to parse new style authorization into a SubjectPermission with Allow and Deny. 
+func parseSubjectPermission(m map[string]interface{}) (*SubjectPermission, error) { + if len(m) == 0 { + return nil, nil + } + + p := &SubjectPermission{} + + for k, v := range m { + switch strings.ToLower(k) { + case "allow": + subjects, err := parseSubjects(v) + if err != nil { + return nil, err + } + p.Allow = subjects + case "deny": + subjects, err := parseSubjects(v) + if err != nil { + return nil, err + } + p.Deny = subjects + default: + return nil, fmt.Errorf("Unknown field name %q parsing subject permissions, only 'allow' or 'deny' are permitted", k) + } + } + return p, nil +} + +// Helper function to validate subjects, etc for account permissioning. +func checkSubjectArray(sa []string) error { + for _, s := range sa { + if !IsValidSubject(s) { + return fmt.Errorf("Subject %q is not a valid subject", s) + } + } + return nil +} + +// PrintTLSHelpAndDie prints TLS usage and exits. +func PrintTLSHelpAndDie() { + fmt.Printf("%s", tlsUsage) + for k := range cipherMap { + fmt.Printf(" %s\n", k) + } + fmt.Printf("\nAvailable curve preferences include:\n") + for k := range curvePreferenceMap { + fmt.Printf(" %s\n", k) + } + os.Exit(0) +} + +func parseCipher(cipherName string) (uint16, error) { + + cipher, exists := cipherMap[cipherName] + if !exists { + return 0, fmt.Errorf("Unrecognized cipher %s", cipherName) + } + + return cipher, nil +} + +func parseCurvePreferences(curveName string) (tls.CurveID, error) { + curve, exists := curvePreferenceMap[curveName] + if !exists { + return 0, fmt.Errorf("Unrecognized curve preference %s", curveName) + } + return curve, nil +} + +// Helper function to parse TLS configs. 
+func parseTLS(tlsm map[string]interface{}) (*TLSConfigOpts, error) { + tc := TLSConfigOpts{} + for mk, mv := range tlsm { + switch strings.ToLower(mk) { + case "cert_file": + certFile, ok := mv.(string) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'cert_file' to be filename") + } + tc.CertFile = certFile + case "key_file": + keyFile, ok := mv.(string) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'key_file' to be filename") + } + tc.KeyFile = keyFile + case "ca_file": + caFile, ok := mv.(string) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'ca_file' to be filename") + } + tc.CaFile = caFile + case "verify": + verify, ok := mv.(bool) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'verify' to be a boolean") + } + tc.Verify = verify + case "cipher_suites": + ra := mv.([]interface{}) + if len(ra) == 0 { + return nil, fmt.Errorf("error parsing tls config, 'cipher_suites' cannot be empty") + } + tc.Ciphers = make([]uint16, 0, len(ra)) + for _, r := range ra { + cipher, err := parseCipher(r.(string)) + if err != nil { + return nil, err + } + tc.Ciphers = append(tc.Ciphers, cipher) + } + case "curve_preferences": + ra := mv.([]interface{}) + if len(ra) == 0 { + return nil, fmt.Errorf("error parsing tls config, 'curve_preferences' cannot be empty") + } + tc.CurvePreferences = make([]tls.CurveID, 0, len(ra)) + for _, r := range ra { + cps, err := parseCurvePreferences(r.(string)) + if err != nil { + return nil, err + } + tc.CurvePreferences = append(tc.CurvePreferences, cps) + } + case "timeout": + at := float64(0) + switch mv.(type) { + case int64: + at = float64(mv.(int64)) + case float64: + at = mv.(float64) + } + tc.Timeout = at + default: + return nil, fmt.Errorf("error parsing tls config, unknown field [%q]", mk) + } + } + + // If cipher suites were not specified then use the defaults + if tc.Ciphers == nil { + tc.Ciphers = defaultCipherSuites() + } + + // If curve 
preferences were not specified, then use the defaults + if tc.CurvePreferences == nil { + tc.CurvePreferences = defaultCurvePreferences() + } + + return &tc, nil +} + +// GenTLSConfig loads TLS related configuration parameters. +func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) { + + // Now load in cert and private key + cert, err := tls.LoadX509KeyPair(tc.CertFile, tc.KeyFile) + if err != nil { + return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err) + } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Create TLSConfig + // We will determine the cipher suites that we prefer. + config := tls.Config{ + CurvePreferences: tc.CurvePreferences, + Certificates: []tls.Certificate{cert}, + PreferServerCipherSuites: true, + MinVersion: tls.VersionTLS12, + CipherSuites: tc.Ciphers, + } + + // Require client certificates as needed + if tc.Verify { + config.ClientAuth = tls.RequireAndVerifyClientCert + } + // Add in CAs if applicable. + if tc.CaFile != "" { + rootPEM, err := ioutil.ReadFile(tc.CaFile) + if err != nil || rootPEM == nil { + return nil, err + } + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(rootPEM) + if !ok { + return nil, fmt.Errorf("failed to parse root ca certificate") + } + config.ClientCAs = pool + } + + return &config, nil +} + +// MergeOptions will merge two options giving preference to the flagOpts +// if the item is present. 
+func MergeOptions(fileOpts, flagOpts *Options) *Options { + if fileOpts == nil { + return flagOpts + } + if flagOpts == nil { + return fileOpts + } + // Merge the two, flagOpts override + opts := *fileOpts + + if flagOpts.Port != 0 { + opts.Port = flagOpts.Port + } + if flagOpts.Host != "" { + opts.Host = flagOpts.Host + } + if flagOpts.ClientAdvertise != "" { + opts.ClientAdvertise = flagOpts.ClientAdvertise + } + if flagOpts.Username != "" { + opts.Username = flagOpts.Username + } + if flagOpts.Password != "" { + opts.Password = flagOpts.Password + } + if flagOpts.Authorization != "" { + opts.Authorization = flagOpts.Authorization + } + if flagOpts.HTTPPort != 0 { + opts.HTTPPort = flagOpts.HTTPPort + } + if flagOpts.Debug { + opts.Debug = true + } + if flagOpts.Trace { + opts.Trace = true + } + if flagOpts.Logtime { + opts.Logtime = true + } + if flagOpts.LogFile != "" { + opts.LogFile = flagOpts.LogFile + } + if flagOpts.PidFile != "" { + opts.PidFile = flagOpts.PidFile + } + if flagOpts.PortsFileDir != "" { + opts.PortsFileDir = flagOpts.PortsFileDir + } + if flagOpts.ProfPort != 0 { + opts.ProfPort = flagOpts.ProfPort + } + if flagOpts.Cluster.ListenStr != "" { + opts.Cluster.ListenStr = flagOpts.Cluster.ListenStr + } + if flagOpts.Cluster.NoAdvertise { + opts.Cluster.NoAdvertise = true + } + if flagOpts.Cluster.ConnectRetries != 0 { + opts.Cluster.ConnectRetries = flagOpts.Cluster.ConnectRetries + } + if flagOpts.Cluster.Advertise != "" { + opts.Cluster.Advertise = flagOpts.Cluster.Advertise + } + if flagOpts.RoutesStr != "" { + mergeRoutes(&opts, flagOpts) + } + return &opts +} + +// RoutesFromStr parses route URLs from a string +func RoutesFromStr(routesStr string) []*url.URL { + routes := strings.Split(routesStr, ",") + if len(routes) == 0 { + return nil + } + routeUrls := []*url.URL{} + for _, r := range routes { + r = strings.TrimSpace(r) + u, _ := url.Parse(r) + routeUrls = append(routeUrls, u) + } + return routeUrls +} + +// This will merge the flag 
routes and override anything that was present. +func mergeRoutes(opts, flagOpts *Options) { + routeUrls := RoutesFromStr(flagOpts.RoutesStr) + if routeUrls == nil { + return + } + opts.Routes = routeUrls + opts.RoutesStr = flagOpts.RoutesStr +} + +// RemoveSelfReference removes this server from an array of routes +func RemoveSelfReference(clusterPort int, routes []*url.URL) ([]*url.URL, error) { + var cleanRoutes []*url.URL + cport := strconv.Itoa(clusterPort) + + selfIPs, err := getInterfaceIPs() + if err != nil { + return nil, err + } + for _, r := range routes { + host, port, err := net.SplitHostPort(r.Host) + if err != nil { + return nil, err + } + + ipList, err := getURLIP(host) + if err != nil { + return nil, err + } + if cport == port && isIPInList(selfIPs, ipList) { + continue + } + cleanRoutes = append(cleanRoutes, r) + } + + return cleanRoutes, nil +} + +func isIPInList(list1 []net.IP, list2 []net.IP) bool { + for _, ip1 := range list1 { + for _, ip2 := range list2 { + if ip1.Equal(ip2) { + return true + } + } + } + return false +} + +func getURLIP(ipStr string) ([]net.IP, error) { + ipList := []net.IP{} + + ip := net.ParseIP(ipStr) + if ip != nil { + ipList = append(ipList, ip) + return ipList, nil + } + + hostAddr, err := net.LookupHost(ipStr) + if err != nil { + return nil, fmt.Errorf("Error looking up host with route hostname: %v", err) + } + for _, addr := range hostAddr { + ip = net.ParseIP(addr) + if ip != nil { + ipList = append(ipList, ip) + } + } + return ipList, nil +} + +func getInterfaceIPs() ([]net.IP, error) { + var localIPs []net.IP + + interfaceAddr, err := net.InterfaceAddrs() + if err != nil { + return nil, fmt.Errorf("Error getting self referencing address: %v", err) + } + + for i := 0; i < len(interfaceAddr); i++ { + interfaceIP, _, _ := net.ParseCIDR(interfaceAddr[i].String()) + if net.ParseIP(interfaceIP.String()) != nil { + localIPs = append(localIPs, interfaceIP) + } else { + return nil, fmt.Errorf("Error parsing self referencing 
address: %v", err) + } + } + return localIPs, nil +} + +func processOptions(opts *Options) { + // Setup non-standard Go defaults + if opts.Host == "" { + opts.Host = DEFAULT_HOST + } + if opts.HTTPHost == "" { + // Default to same bind from server if left undefined + opts.HTTPHost = opts.Host + } + if opts.Port == 0 { + opts.Port = DEFAULT_PORT + } else if opts.Port == RANDOM_PORT { + // Choose randomly inside of net.Listen + opts.Port = 0 + } + if opts.MaxConn == 0 { + opts.MaxConn = DEFAULT_MAX_CONNECTIONS + } + if opts.PingInterval == 0 { + opts.PingInterval = DEFAULT_PING_INTERVAL + } + if opts.MaxPingsOut == 0 { + opts.MaxPingsOut = DEFAULT_PING_MAX_OUT + } + if opts.TLSTimeout == 0 { + opts.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) + } + if opts.AuthTimeout == 0 { + opts.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second) + } + if opts.Cluster.Port != 0 { + if opts.Cluster.Host == "" { + opts.Cluster.Host = DEFAULT_HOST + } + if opts.Cluster.TLSTimeout == 0 { + opts.Cluster.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) + } + if opts.Cluster.AuthTimeout == 0 { + opts.Cluster.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second) + } + } + if opts.MaxControlLine == 0 { + opts.MaxControlLine = MAX_CONTROL_LINE_SIZE + } + if opts.MaxPayload == 0 { + opts.MaxPayload = MAX_PAYLOAD_SIZE + } + if opts.MaxPending == 0 { + opts.MaxPending = MAX_PENDING_SIZE + } + if opts.WriteDeadline == time.Duration(0) { + opts.WriteDeadline = DEFAULT_FLUSH_DEADLINE + } + if opts.RQSubsSweep == time.Duration(0) { + opts.RQSubsSweep = DEFAULT_REMOTE_QSUBS_SWEEPER + } + if opts.MaxClosedClients == 0 { + opts.MaxClosedClients = DEFAULT_MAX_CLOSED_CLIENTS + } +} + +// ConfigureOptions accepts a flag set and augment it with NATS Server +// specific flags. On success, an options structure is returned configured +// based on the selected flags and/or configuration file. +// The command line options take precedence to the ones in the configuration file. 
+func ConfigureOptions(fs *flag.FlagSet, args []string, printVersion, printHelp, printTLSHelp func()) (*Options, error) { + opts := &Options{} + var ( + showVersion bool + showHelp bool + showTLSHelp bool + signal string + configFile string + err error + ) + + fs.BoolVar(&showHelp, "h", false, "Show this message.") + fs.BoolVar(&showHelp, "help", false, "Show this message.") + fs.IntVar(&opts.Port, "port", 0, "Port to listen on.") + fs.IntVar(&opts.Port, "p", 0, "Port to listen on.") + fs.StringVar(&opts.Host, "addr", "", "Network host to listen on.") + fs.StringVar(&opts.Host, "a", "", "Network host to listen on.") + fs.StringVar(&opts.Host, "net", "", "Network host to listen on.") + fs.StringVar(&opts.ClientAdvertise, "client_advertise", "", "Client URL to advertise to other servers.") + fs.BoolVar(&opts.Debug, "D", false, "Enable Debug logging.") + fs.BoolVar(&opts.Debug, "debug", false, "Enable Debug logging.") + fs.BoolVar(&opts.Trace, "V", false, "Enable Trace logging.") + fs.BoolVar(&opts.Trace, "trace", false, "Enable Trace logging.") + fs.Bool("DV", false, "Enable Debug and Trace logging.") + fs.BoolVar(&opts.Logtime, "T", true, "Timestamp log entries.") + fs.BoolVar(&opts.Logtime, "logtime", true, "Timestamp log entries.") + fs.StringVar(&opts.Username, "user", "", "Username required for connection.") + fs.StringVar(&opts.Password, "pass", "", "Password required for connection.") + fs.StringVar(&opts.Authorization, "auth", "", "Authorization token required for connection.") + fs.IntVar(&opts.HTTPPort, "m", 0, "HTTP Port for /varz, /connz endpoints.") + fs.IntVar(&opts.HTTPPort, "http_port", 0, "HTTP Port for /varz, /connz endpoints.") + fs.IntVar(&opts.HTTPSPort, "ms", 0, "HTTPS Port for /varz, /connz endpoints.") + fs.IntVar(&opts.HTTPSPort, "https_port", 0, "HTTPS Port for /varz, /connz endpoints.") + fs.StringVar(&configFile, "c", "", "Configuration file.") + fs.StringVar(&configFile, "config", "", "Configuration file.") + fs.StringVar(&signal, "sl", 
"", "Send signal to gnatsd process (stop, quit, reopen, reload)") + fs.StringVar(&signal, "signal", "", "Send signal to gnatsd process (stop, quit, reopen, reload)") + fs.StringVar(&opts.PidFile, "P", "", "File to store process pid.") + fs.StringVar(&opts.PidFile, "pid", "", "File to store process pid.") + fs.StringVar(&opts.PortsFileDir, "ports_file_dir", "", "Creates a ports file in the specified directory (_.ports)") + fs.StringVar(&opts.LogFile, "l", "", "File to store logging output.") + fs.StringVar(&opts.LogFile, "log", "", "File to store logging output.") + fs.BoolVar(&opts.Syslog, "s", false, "Enable syslog as log method.") + fs.BoolVar(&opts.Syslog, "syslog", false, "Enable syslog as log method..") + fs.StringVar(&opts.RemoteSyslog, "r", "", "Syslog server addr (udp://127.0.0.1:514).") + fs.StringVar(&opts.RemoteSyslog, "remote_syslog", "", "Syslog server addr (udp://127.0.0.1:514).") + fs.BoolVar(&showVersion, "version", false, "Print version information.") + fs.BoolVar(&showVersion, "v", false, "Print version information.") + fs.IntVar(&opts.ProfPort, "profile", 0, "Profiling HTTP port") + fs.StringVar(&opts.RoutesStr, "routes", "", "Routes to actively solicit a connection.") + fs.StringVar(&opts.Cluster.ListenStr, "cluster", "", "Cluster url from which members can solicit routes.") + fs.StringVar(&opts.Cluster.ListenStr, "cluster_listen", "", "Cluster url from which members can solicit routes.") + fs.StringVar(&opts.Cluster.Advertise, "cluster_advertise", "", "Cluster URL to advertise to other servers.") + fs.BoolVar(&opts.Cluster.NoAdvertise, "no_advertise", false, "Advertise known cluster IPs to clients.") + fs.IntVar(&opts.Cluster.ConnectRetries, "connect_retries", 0, "For implicit routes, number of connect retries") + fs.BoolVar(&showTLSHelp, "help_tls", false, "TLS help.") + fs.BoolVar(&opts.TLS, "tls", false, "Enable TLS.") + fs.BoolVar(&opts.TLSVerify, "tlsverify", false, "Enable TLS with client verification.") + fs.StringVar(&opts.TLSCert, 
"tlscert", "", "Server certificate file.") + fs.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate.") + fs.StringVar(&opts.TLSCaCert, "tlscacert", "", "Client certificate CA for verification.") + + // The flags definition above set "default" values to some of the options. + // Calling Parse() here will override the default options with any value + // specified from the command line. This is ok. We will then update the + // options with the content of the configuration file (if present), and then, + // call Parse() again to override the default+config with command line values. + // Calling Parse() before processing config file is necessary since configFile + // itself is a command line argument, and also Parse() is required in order + // to know if user wants simply to show "help" or "version", etc... + if err := fs.Parse(args); err != nil { + return nil, err + } + + if showVersion { + printVersion() + return nil, nil + } + + if showHelp { + printHelp() + return nil, nil + } + + if showTLSHelp { + printTLSHelp() + return nil, nil + } + + // Process args looking for non-flag options, + // 'version' and 'help' only for now + showVersion, showHelp, err = ProcessCommandLineArgs(fs) + if err != nil { + return nil, err + } else if showVersion { + printVersion() + return nil, nil + } else if showHelp { + printHelp() + return nil, nil + } + + // Snapshot flag options. + FlagSnapshot = opts.Clone() + + // Process signal control. + if signal != "" { + if err := processSignal(signal); err != nil { + return nil, err + } + } + + // Parse config if given + if configFile != "" { + // This will update the options with values from the config file. + if err := opts.ProcessConfigFile(configFile); err != nil { + return nil, err + } + // Call this again to override config file options with options from command line. 
+ // Note: We don't need to check error here since if there was an error, it would + // have been caught the first time this function was called (after setting up the + // flags). + fs.Parse(args) + } + + // Special handling of some flags + var ( + flagErr error + tlsDisabled bool + tlsOverride bool + ) + fs.Visit(func(f *flag.Flag) { + // short-circuit if an error was encountered + if flagErr != nil { + return + } + if strings.HasPrefix(f.Name, "tls") { + if f.Name == "tls" { + if !opts.TLS { + // User has specified "-tls=false", we need to disable TLS + opts.TLSConfig = nil + tlsDisabled = true + tlsOverride = false + return + } + tlsOverride = true + } else if !tlsDisabled { + tlsOverride = true + } + } else { + switch f.Name { + case "DV": + // Check value to support -DV=false + boolValue, _ := strconv.ParseBool(f.Value.String()) + opts.Trace, opts.Debug = boolValue, boolValue + case "cluster", "cluster_listen": + // Override cluster config if explicitly set via flags. + flagErr = overrideCluster(opts) + case "routes": + // Keep in mind that the flag has updated opts.RoutesStr at this point. + if opts.RoutesStr == "" { + // Set routes array to nil since routes string is empty + opts.Routes = nil + return + } + routeUrls := RoutesFromStr(opts.RoutesStr) + opts.Routes = routeUrls + } + } + }) + if flagErr != nil { + return nil, flagErr + } + + // This will be true if some of the `-tls` params have been set and + // `-tls=false` has not been set. + if tlsOverride { + if err := overrideTLS(opts); err != nil { + return nil, err + } + } + + // If we don't have cluster defined in the configuration + // file and no cluster listen string override, but we do + // have a routes override, we need to report misconfiguration. + if opts.RoutesStr != "" && opts.Cluster.ListenStr == "" && opts.Cluster.Host == "" && opts.Cluster.Port == 0 { + return nil, errors.New("solicited routes require cluster capabilities, e.g. 
--cluster") + } + + return opts, nil +} + +// overrideTLS is called when at least "-tls=true" has been set. +func overrideTLS(opts *Options) error { + if opts.TLSCert == "" { + return errors.New("TLS Server certificate must be present and valid") + } + if opts.TLSKey == "" { + return errors.New("TLS Server private key must be present and valid") + } + + tc := TLSConfigOpts{} + tc.CertFile = opts.TLSCert + tc.KeyFile = opts.TLSKey + tc.CaFile = opts.TLSCaCert + tc.Verify = opts.TLSVerify + + var err error + opts.TLSConfig, err = GenTLSConfig(&tc) + return err +} + +// overrideCluster updates Options.Cluster if that flag "cluster" (or "cluster_listen") +// has explicitly be set in the command line. If it is set to empty string, it will +// clear the Cluster options. +func overrideCluster(opts *Options) error { + if opts.Cluster.ListenStr == "" { + // This one is enough to disable clustering. + opts.Cluster.Port = 0 + return nil + } + clusterURL, err := url.Parse(opts.Cluster.ListenStr) + if err != nil { + return err + } + h, p, err := net.SplitHostPort(clusterURL.Host) + if err != nil { + return err + } + opts.Cluster.Host = h + _, err = fmt.Sscan(p, &opts.Cluster.Port) + if err != nil { + return err + } + + if clusterURL.User != nil { + pass, hasPassword := clusterURL.User.Password() + if !hasPassword { + return errors.New("expected cluster password to be set") + } + opts.Cluster.Password = pass + + user := clusterURL.User.Username() + opts.Cluster.Username = user + } else { + // Since we override from flag and there is no user/pwd, make + // sure we clear what we may have gotten from config file. 
+ opts.Cluster.Username = "" + opts.Cluster.Password = "" + } + + return nil +} + +func processSignal(signal string) error { + var ( + pid string + commandAndPid = strings.Split(signal, "=") + ) + if l := len(commandAndPid); l == 2 { + pid = commandAndPid[1] + } else if l > 2 { + return fmt.Errorf("invalid signal parameters: %v", commandAndPid[2:]) + } + if err := ProcessSignal(Command(commandAndPid[0]), pid); err != nil { + return err + } + os.Exit(0) + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/parser.go b/vendor/github.com/nats-io/gnatsd/server/parser.go new file mode 100644 index 00000000000..088894fb78e --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/parser.go @@ -0,0 +1,749 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "fmt" +) + +type pubArg struct { + subject []byte + reply []byte + sid []byte + szb []byte + size int +} + +type parseState struct { + state int + as int + drop int + pa pubArg + argBuf []byte + msgBuf []byte + scratch [MAX_CONTROL_LINE_SIZE]byte +} + +// Parser constants +const ( + OP_START = iota + OP_PLUS + OP_PLUS_O + OP_PLUS_OK + OP_MINUS + OP_MINUS_E + OP_MINUS_ER + OP_MINUS_ERR + OP_MINUS_ERR_SPC + MINUS_ERR_ARG + OP_C + OP_CO + OP_CON + OP_CONN + OP_CONNE + OP_CONNEC + OP_CONNECT + CONNECT_ARG + OP_P + OP_PU + OP_PUB + OP_PUB_SPC + PUB_ARG + OP_PI + OP_PIN + OP_PING + OP_PO + OP_PON + OP_PONG + MSG_PAYLOAD + MSG_END + OP_S + OP_SU + OP_SUB + OP_SUB_SPC + SUB_ARG + OP_U + OP_UN + OP_UNS + OP_UNSU + OP_UNSUB + OP_UNSUB_SPC + UNSUB_ARG + OP_M + OP_MS + OP_MSG + OP_MSG_SPC + MSG_ARG + OP_I + OP_IN + OP_INF + OP_INFO + INFO_ARG +) + +func (c *client) parse(buf []byte) error { + var i int + var b byte + + mcl := MAX_CONTROL_LINE_SIZE + if c.srv != nil && c.srv.getOpts() != nil { + mcl = c.srv.getOpts().MaxControlLine + } + + // snapshot this, and reset when we receive a + // proper CONNECT if needed. 
+ authSet := c.isAuthTimerSet() + + // Move to loop instead of range syntax to allow jumping of i + for i = 0; i < len(buf); i++ { + b = buf[i] + + switch c.state { + case OP_START: + if b != 'C' && b != 'c' && authSet { + goto authErr + } + switch b { + case 'P', 'p': + c.state = OP_P + case 'S', 's': + c.state = OP_S + case 'U', 'u': + c.state = OP_U + case 'M', 'm': + if c.typ == CLIENT { + goto parseErr + } else { + c.state = OP_M + } + case 'C', 'c': + c.state = OP_C + case 'I', 'i': + c.state = OP_I + case '+': + c.state = OP_PLUS + case '-': + c.state = OP_MINUS + default: + goto parseErr + } + case OP_P: + switch b { + case 'U', 'u': + c.state = OP_PU + case 'I', 'i': + c.state = OP_PI + case 'O', 'o': + c.state = OP_PO + default: + goto parseErr + } + case OP_PU: + switch b { + case 'B', 'b': + c.state = OP_PUB + default: + goto parseErr + } + case OP_PUB: + switch b { + case ' ', '\t': + c.state = OP_PUB_SPC + default: + goto parseErr + } + case OP_PUB_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = PUB_ARG + c.as = i + } + case PUB_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processPub(arg); err != nil { + return err + } + c.drop, c.as, c.state = OP_START, i+1, MSG_PAYLOAD + // If we don't have a saved buffer then jump ahead with + // the index. If this overruns what is left we fall out + // and process split buffer. + if c.msgBuf == nil { + i = c.as + c.pa.size - LEN_CR_LF + } + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case MSG_PAYLOAD: + if c.msgBuf != nil { + // copy as much as we can to the buffer and skip ahead. + toCopy := c.pa.size - len(c.msgBuf) + avail := len(buf) - i + if avail < toCopy { + toCopy = avail + } + if toCopy > 0 { + start := len(c.msgBuf) + // This is needed for copy to work. 
+ c.msgBuf = c.msgBuf[:start+toCopy] + copy(c.msgBuf[start:], buf[i:i+toCopy]) + // Update our index + i = (i + toCopy) - 1 + } else { + // Fall back to append if needed. + c.msgBuf = append(c.msgBuf, b) + } + if len(c.msgBuf) >= c.pa.size { + c.state = MSG_END + } + } else if i-c.as >= c.pa.size { + c.state = MSG_END + } + case MSG_END: + switch b { + case '\n': + if c.msgBuf != nil { + c.msgBuf = append(c.msgBuf, b) + } else { + c.msgBuf = buf[c.as : i+1] + } + // strict check for proto + if len(c.msgBuf) != c.pa.size+LEN_CR_LF { + goto parseErr + } + c.processMsg(c.msgBuf) + c.argBuf, c.msgBuf = nil, nil + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.msgBuf != nil { + c.msgBuf = append(c.msgBuf, b) + } + continue + } + case OP_S: + switch b { + case 'U', 'u': + c.state = OP_SU + default: + goto parseErr + } + case OP_SU: + switch b { + case 'B', 'b': + c.state = OP_SUB + default: + goto parseErr + } + case OP_SUB: + switch b { + case ' ', '\t': + c.state = OP_SUB_SPC + default: + goto parseErr + } + case OP_SUB_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = SUB_ARG + c.as = i + } + case SUB_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processSub(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_U: + switch b { + case 'N', 'n': + c.state = OP_UN + default: + goto parseErr + } + case OP_UN: + switch b { + case 'S', 's': + c.state = OP_UNS + default: + goto parseErr + } + case OP_UNS: + switch b { + case 'U', 'u': + c.state = OP_UNSU + default: + goto parseErr + } + case OP_UNSU: + switch b { + case 'B', 'b': + c.state = OP_UNSUB + default: + goto parseErr + } + case OP_UNSUB: + switch b { + case ' ', '\t': + c.state = OP_UNSUB_SPC + default: + goto parseErr + } + case 
OP_UNSUB_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = UNSUB_ARG + c.as = i + } + case UNSUB_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processUnsub(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_PI: + switch b { + case 'N', 'n': + c.state = OP_PIN + default: + goto parseErr + } + case OP_PIN: + switch b { + case 'G', 'g': + c.state = OP_PING + default: + goto parseErr + } + case OP_PING: + switch b { + case '\n': + c.processPing() + c.drop, c.state = 0, OP_START + } + case OP_PO: + switch b { + case 'N', 'n': + c.state = OP_PON + default: + goto parseErr + } + case OP_PON: + switch b { + case 'G', 'g': + c.state = OP_PONG + default: + goto parseErr + } + case OP_PONG: + switch b { + case '\n': + c.processPong() + c.drop, c.state = 0, OP_START + } + case OP_C: + switch b { + case 'O', 'o': + c.state = OP_CO + default: + goto parseErr + } + case OP_CO: + switch b { + case 'N', 'n': + c.state = OP_CON + default: + goto parseErr + } + case OP_CON: + switch b { + case 'N', 'n': + c.state = OP_CONN + default: + goto parseErr + } + case OP_CONN: + switch b { + case 'E', 'e': + c.state = OP_CONNE + default: + goto parseErr + } + case OP_CONNE: + switch b { + case 'C', 'c': + c.state = OP_CONNEC + default: + goto parseErr + } + case OP_CONNEC: + switch b { + case 'T', 't': + c.state = OP_CONNECT + default: + goto parseErr + } + case OP_CONNECT: + switch b { + case ' ', '\t': + continue + default: + c.state = CONNECT_ARG + c.as = i + } + case CONNECT_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processConnect(arg); err != nil { + return err + } + 
c.drop, c.state = 0, OP_START + // Reset notion on authSet + authSet = c.isAuthTimerSet() + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_M: + switch b { + case 'S', 's': + c.state = OP_MS + default: + goto parseErr + } + case OP_MS: + switch b { + case 'G', 'g': + c.state = OP_MSG + default: + goto parseErr + } + case OP_MSG: + switch b { + case ' ', '\t': + c.state = OP_MSG_SPC + default: + goto parseErr + } + case OP_MSG_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = MSG_ARG + c.as = i + } + case MSG_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processMsgArgs(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, MSG_PAYLOAD + + // jump ahead with the index. If this overruns + // what is left we fall out and process split + // buffer. + i = c.as + c.pa.size - 1 + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_I: + switch b { + case 'N', 'n': + c.state = OP_IN + default: + goto parseErr + } + case OP_IN: + switch b { + case 'F', 'f': + c.state = OP_INF + default: + goto parseErr + } + case OP_INF: + switch b { + case 'O', 'o': + c.state = OP_INFO + default: + goto parseErr + } + case OP_INFO: + switch b { + case ' ', '\t': + continue + default: + c.state = INFO_ARG + c.as = i + } + case INFO_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processInfo(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_PLUS: + switch b { + case 'O', 'o': + c.state = OP_PLUS_O + default: + goto parseErr + } + case OP_PLUS_O: + switch b { + case 'K', 'k': + c.state = OP_PLUS_OK + default: + goto 
parseErr + } + case OP_PLUS_OK: + switch b { + case '\n': + c.drop, c.state = 0, OP_START + } + case OP_MINUS: + switch b { + case 'E', 'e': + c.state = OP_MINUS_E + default: + goto parseErr + } + case OP_MINUS_E: + switch b { + case 'R', 'r': + c.state = OP_MINUS_ER + default: + goto parseErr + } + case OP_MINUS_ER: + switch b { + case 'R', 'r': + c.state = OP_MINUS_ERR + default: + goto parseErr + } + case OP_MINUS_ERR: + switch b { + case ' ', '\t': + c.state = OP_MINUS_ERR_SPC + default: + goto parseErr + } + case OP_MINUS_ERR_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = MINUS_ERR_ARG + c.as = i + } + case MINUS_ERR_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + c.processErr(string(arg)) + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + default: + goto parseErr + } + } + + // Check for split buffer scenarios for any ARG state. + if c.state == SUB_ARG || c.state == UNSUB_ARG || c.state == PUB_ARG || + c.state == MSG_ARG || c.state == MINUS_ERR_ARG || + c.state == CONNECT_ARG || c.state == INFO_ARG { + // Setup a holder buffer to deal with split buffer scenario. + if c.argBuf == nil { + c.argBuf = c.scratch[:0] + c.argBuf = append(c.argBuf, buf[c.as:i-c.drop]...) + } + // Check for violations of control line length here. Note that this is not + // exact at all but the performance hit is too great to be precise, and + // catching here should prevent memory exhaustion attacks. + if len(c.argBuf) > mcl { + c.sendErr("Maximum Control Line Exceeded") + c.closeConnection(MaxControlLineExceeded) + return ErrMaxControlLine + } + } + + // Check for split msg + if (c.state == MSG_PAYLOAD || c.state == MSG_END) && c.msgBuf == nil { + // We need to clone the pubArg if it is still referencing the + // read buffer and we are not able to process the msg. 
+ if c.argBuf == nil { + // Works also for MSG_ARG, when message comes from ROUTE. + c.clonePubArg() + } + + // If we will overflow the scratch buffer, just create a + // new buffer to hold the split message. + if c.pa.size > cap(c.scratch)-len(c.argBuf) { + lrem := len(buf[c.as:]) + + // Consider it a protocol error when the remaining payload + // is larger than the reported size for PUB. It can happen + // when processing incomplete messages from rogue clients. + if lrem > c.pa.size+LEN_CR_LF { + goto parseErr + } + c.msgBuf = make([]byte, lrem, c.pa.size+LEN_CR_LF) + copy(c.msgBuf, buf[c.as:]) + } else { + c.msgBuf = c.scratch[len(c.argBuf):len(c.argBuf)] + c.msgBuf = append(c.msgBuf, (buf[c.as:])...) + } + } + + return nil + +authErr: + c.authViolation() + return ErrAuthorization + +parseErr: + c.sendErr("Unknown Protocol Operation") + snip := protoSnippet(i, buf) + err := fmt.Errorf("%s parser ERROR, state=%d, i=%d: proto='%s...'", + c.typeString(), c.state, i, snip) + return err +} + +func protoSnippet(start int, buf []byte) string { + stop := start + PROTO_SNIPPET_SIZE + bufSize := len(buf) + if start >= bufSize { + return `""` + } + if stop > bufSize { + stop = bufSize - 1 + } + return fmt.Sprintf("%q", buf[start:stop]) +} + +// clonePubArg is used when the split buffer scenario has the pubArg in the existing read buffer, but +// we need to hold onto it into the next read. +func (c *client) clonePubArg() { + c.argBuf = c.scratch[:0] + c.argBuf = append(c.argBuf, c.pa.subject...) + c.argBuf = append(c.argBuf, c.pa.reply...) + c.argBuf = append(c.argBuf, c.pa.sid...) + c.argBuf = append(c.argBuf, c.pa.szb...) 
+ + c.pa.subject = c.argBuf[:len(c.pa.subject)] + + if c.pa.reply != nil { + c.pa.reply = c.argBuf[len(c.pa.subject) : len(c.pa.subject)+len(c.pa.reply)] + } + + if c.pa.sid != nil { + c.pa.sid = c.argBuf[len(c.pa.subject)+len(c.pa.reply) : len(c.pa.subject)+len(c.pa.reply)+len(c.pa.sid)] + } + + c.pa.szb = c.argBuf[len(c.pa.subject)+len(c.pa.reply)+len(c.pa.sid):] +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_darwin.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_darwin.go new file mode 100644 index 00000000000..b00f1e00f63 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_darwin.go @@ -0,0 +1,34 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +import ( + "fmt" + "os" + "os/exec" +) + +// ProcUsage returns CPU usage +func ProcUsage(pcpu *float64, rss, vss *int64) error { + pidStr := fmt.Sprintf("%d", os.Getpid()) + out, err := exec.Command("ps", "o", "pcpu=,rss=,vsz=", "-p", pidStr).Output() + if err != nil { + *rss, *vss = -1, -1 + return fmt.Errorf("ps call failed:%v", err) + } + fmt.Sscanf(string(out), "%f %d %d", pcpu, rss, vss) + *rss *= 1024 // 1k blocks, want bytes. + *vss *= 1024 // 1k blocks, want bytes. 
+ return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_freebsd.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_freebsd.go new file mode 100644 index 00000000000..40c52847111 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_freebsd.go @@ -0,0 +1,83 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +/* +#include +#include +#include +#include +#include + +long pagetok(long size) +{ + int pageshift, pagesize; + + pagesize = getpagesize(); + pageshift = 0; + + while (pagesize > 1) { + pageshift++; + pagesize >>= 1; + } + + return (size << pageshift); +} + +int getusage(double *pcpu, unsigned int *rss, unsigned int *vss) +{ + int mib[4], ret; + size_t len; + struct kinfo_proc kp; + + len = 4; + sysctlnametomib("kern.proc.pid", mib, &len); + + mib[3] = getpid(); + len = sizeof(kp); + + ret = sysctl(mib, 4, &kp, &len, NULL, 0); + if (ret != 0) { + return (errno); + } + + *rss = pagetok(kp.ki_rssize); + *vss = kp.ki_size; + *pcpu = kp.ki_pctcpu; + + return 0; +} + +*/ +import "C" + +import ( + "syscall" +) + +// This is a placeholder for now. 
+func ProcUsage(pcpu *float64, rss, vss *int64) error { + var r, v C.uint + var c C.double + + if ret := C.getusage(&c, &r, &v); ret != 0 { + return syscall.Errno(ret) + } + + *pcpu = float64(c) + *rss = int64(r) + *vss = int64(v) + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_linux.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_linux.go new file mode 100644 index 00000000000..9fea3e07dcf --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_linux.go @@ -0,0 +1,126 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "sync/atomic" + "syscall" + "time" +) + +var ( + procStatFile string + ticks int64 + lastTotal int64 + lastSeconds int64 + ipcpu int64 +) + +const ( + utimePos = 13 + stimePos = 14 + startPos = 21 + vssPos = 22 + rssPos = 23 +) + +func init() { + // Avoiding to generate docker image without CGO + ticks = 100 // int64(C.sysconf(C._SC_CLK_TCK)) + procStatFile = fmt.Sprintf("/proc/%d/stat", os.Getpid()) + periodic() +} + +// Sampling function to keep pcpu relevant. 
+func periodic() { + contents, err := ioutil.ReadFile(procStatFile) + if err != nil { + return + } + fields := bytes.Fields(contents) + + // PCPU + pstart := parseInt64(fields[startPos]) + utime := parseInt64(fields[utimePos]) + stime := parseInt64(fields[stimePos]) + total := utime + stime + + var sysinfo syscall.Sysinfo_t + if err := syscall.Sysinfo(&sysinfo); err != nil { + return + } + + seconds := int64(sysinfo.Uptime) - (pstart / ticks) + + // Save off temps + lt := lastTotal + ls := lastSeconds + + // Update last sample + lastTotal = total + lastSeconds = seconds + + // Adjust to current time window + total -= lt + seconds -= ls + + if seconds > 0 { + atomic.StoreInt64(&ipcpu, (total*1000/ticks)/seconds) + } + + time.AfterFunc(1*time.Second, periodic) +} + +func ProcUsage(pcpu *float64, rss, vss *int64) error { + contents, err := ioutil.ReadFile(procStatFile) + if err != nil { + return err + } + fields := bytes.Fields(contents) + + // Memory + *rss = (parseInt64(fields[rssPos])) << 12 + *vss = parseInt64(fields[vssPos]) + + // PCPU + // We track this with periodic sampling, so just load and go. + *pcpu = float64(atomic.LoadInt64(&ipcpu)) / 10.0 + + return nil +} + +// Ascii numbers 0-9 +const ( + asciiZero = 48 + asciiNine = 57 +) + +// parseInt64 expects decimal positive numbers. 
We +// return -1 to signal error +func parseInt64(d []byte) (n int64) { + if len(d) == 0 { + return -1 + } + for _, dec := range d { + if dec < asciiZero || dec > asciiNine { + return -1 + } + n = n*10 + (int64(dec) - asciiZero) + } + return n +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_openbsd.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_openbsd.go new file mode 100644 index 00000000000..260f1a7ce65 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_openbsd.go @@ -0,0 +1,36 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copied from pse_darwin.go + +package pse + +import ( + "fmt" + "os" + "os/exec" +) + +// ProcUsage returns CPU usage +func ProcUsage(pcpu *float64, rss, vss *int64) error { + pidStr := fmt.Sprintf("%d", os.Getpid()) + out, err := exec.Command("ps", "o", "pcpu=,rss=,vsz=", "-p", pidStr).Output() + if err != nil { + *rss, *vss = -1, -1 + return fmt.Errorf("ps call failed:%v", err) + } + fmt.Sscanf(string(out), "%f %d %d", pcpu, rss, vss) + *rss *= 1024 // 1k blocks, want bytes. + *vss *= 1024 // 1k blocks, want bytes. 
+ return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_rumprun.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_rumprun.go new file mode 100644 index 00000000000..48e80fca2c3 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_rumprun.go @@ -0,0 +1,25 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build rumprun + +package pse + +// This is a placeholder for now. +func ProcUsage(pcpu *float64, rss, vss *int64) error { + *pcpu = 0.0 + *rss = 0 + *vss = 0 + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_solaris.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_solaris.go new file mode 100644 index 00000000000..8e40d2ed306 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_solaris.go @@ -0,0 +1,23 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pse + +// This is a placeholder for now. +func ProcUsage(pcpu *float64, rss, vss *int64) error { + *pcpu = 0.0 + *rss = 0 + *vss = 0 + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows.go new file mode 100644 index 00000000000..a8b110704fe --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows.go @@ -0,0 +1,280 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build windows + +package pse + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + "unsafe" +) + +var ( + pdh = syscall.NewLazyDLL("pdh.dll") + winPdhOpenQuery = pdh.NewProc("PdhOpenQuery") + winPdhAddCounter = pdh.NewProc("PdhAddCounterW") + winPdhCollectQueryData = pdh.NewProc("PdhCollectQueryData") + winPdhGetFormattedCounterValue = pdh.NewProc("PdhGetFormattedCounterValue") + winPdhGetFormattedCounterArray = pdh.NewProc("PdhGetFormattedCounterArrayW") +) + +// global performance counter query handle and counters +var ( + pcHandle PDH_HQUERY + pidCounter, cpuCounter, rssCounter, vssCounter PDH_HCOUNTER + prevCPU float64 + prevRss int64 + prevVss int64 + lastSampleTime time.Time + processPid int + pcQueryLock sync.Mutex + initialSample = true +) + +// maxQuerySize is the number of values to return from a query. 
+// It represents the maximum # of servers that can be queried +// simultaneously running on a machine. +const maxQuerySize = 512 + +// Keep static memory around to reuse; this works best for passing +// into the pdh API. +var counterResults [maxQuerySize]PDH_FMT_COUNTERVALUE_ITEM_DOUBLE + +// PDH Types +type ( + PDH_HQUERY syscall.Handle + PDH_HCOUNTER syscall.Handle +) + +// PDH constants used here +const ( + PDH_FMT_DOUBLE = 0x00000200 + PDH_INVALID_DATA = 0xC0000BC6 + PDH_MORE_DATA = 0x800007D2 +) + +// PDH_FMT_COUNTERVALUE_DOUBLE - double value +type PDH_FMT_COUNTERVALUE_DOUBLE struct { + CStatus uint32 + DoubleValue float64 +} + +// PDH_FMT_COUNTERVALUE_ITEM_DOUBLE is an array +// element of a double value +type PDH_FMT_COUNTERVALUE_ITEM_DOUBLE struct { + SzName *uint16 // pointer to a string + FmtValue PDH_FMT_COUNTERVALUE_DOUBLE +} + +func pdhAddCounter(hQuery PDH_HQUERY, szFullCounterPath string, dwUserData uintptr, phCounter *PDH_HCOUNTER) error { + ptxt, _ := syscall.UTF16PtrFromString(szFullCounterPath) + r0, _, _ := winPdhAddCounter.Call( + uintptr(hQuery), + uintptr(unsafe.Pointer(ptxt)), + dwUserData, + uintptr(unsafe.Pointer(phCounter))) + + if r0 != 0 { + return fmt.Errorf("pdhAddCounter failed. 
%d", r0) + } + return nil +} + +func pdhOpenQuery(datasrc *uint16, userdata uint32, query *PDH_HQUERY) error { + r0, _, _ := syscall.Syscall(winPdhOpenQuery.Addr(), 3, 0, uintptr(userdata), uintptr(unsafe.Pointer(query))) + if r0 != 0 { + return fmt.Errorf("pdhOpenQuery failed - %d", r0) + } + return nil +} + +func pdhCollectQueryData(hQuery PDH_HQUERY) error { + r0, _, _ := winPdhCollectQueryData.Call(uintptr(hQuery)) + if r0 != 0 { + return fmt.Errorf("pdhCollectQueryData failed - %d", r0) + } + return nil +} + +// pdhGetFormattedCounterArrayDouble returns the value of return code +// rather than error, to easily check return codes +func pdhGetFormattedCounterArrayDouble(hCounter PDH_HCOUNTER, lpdwBufferSize *uint32, lpdwBufferCount *uint32, itemBuffer *PDH_FMT_COUNTERVALUE_ITEM_DOUBLE) uint32 { + ret, _, _ := winPdhGetFormattedCounterArray.Call( + uintptr(hCounter), + uintptr(PDH_FMT_DOUBLE), + uintptr(unsafe.Pointer(lpdwBufferSize)), + uintptr(unsafe.Pointer(lpdwBufferCount)), + uintptr(unsafe.Pointer(itemBuffer))) + + return uint32(ret) +} + +func getCounterArrayData(counter PDH_HCOUNTER) ([]float64, error) { + var bufSize uint32 + var bufCount uint32 + + // Retrieving array data requires two calls, the first which + // requires an addressable empty buffer, and sets size fields. + // The second call returns the data. + initialBuf := make([]PDH_FMT_COUNTERVALUE_ITEM_DOUBLE, 1) + ret := pdhGetFormattedCounterArrayDouble(counter, &bufSize, &bufCount, &initialBuf[0]) + if ret == PDH_MORE_DATA { + // we'll likely never get here, but be safe. 
+ if bufCount > maxQuerySize { + bufCount = maxQuerySize + } + ret = pdhGetFormattedCounterArrayDouble(counter, &bufSize, &bufCount, &counterResults[0]) + if ret == 0 { + rv := make([]float64, bufCount) + for i := 0; i < int(bufCount); i++ { + rv[i] = counterResults[i].FmtValue.DoubleValue + } + return rv, nil + } + } + if ret != 0 { + return nil, fmt.Errorf("getCounterArrayData failed - %d", ret) + } + + return nil, nil +} + +// getProcessImageName returns the name of the process image, as expected by +// the performance counter API. +func getProcessImageName() (name string) { + name = filepath.Base(os.Args[0]) + name = strings.TrimRight(name, ".exe") + return +} + +// initialize our counters +func initCounters() (err error) { + + processPid = os.Getpid() + // require an addressible nil pointer + var source uint16 + if err := pdhOpenQuery(&source, 0, &pcHandle); err != nil { + return err + } + + // setup the performance counters, search for all server instances + name := fmt.Sprintf("%s*", getProcessImageName()) + pidQuery := fmt.Sprintf("\\Process(%s)\\ID Process", name) + cpuQuery := fmt.Sprintf("\\Process(%s)\\%% Processor Time", name) + rssQuery := fmt.Sprintf("\\Process(%s)\\Working Set - Private", name) + vssQuery := fmt.Sprintf("\\Process(%s)\\Virtual Bytes", name) + + if err = pdhAddCounter(pcHandle, pidQuery, 0, &pidCounter); err != nil { + return err + } + if err = pdhAddCounter(pcHandle, cpuQuery, 0, &cpuCounter); err != nil { + return err + } + if err = pdhAddCounter(pcHandle, rssQuery, 0, &rssCounter); err != nil { + return err + } + if err = pdhAddCounter(pcHandle, vssQuery, 0, &vssCounter); err != nil { + return err + } + + // prime the counters by collecting once, and sleep to get somewhat + // useful information the first request. Counters for the CPU require + // at least two collect calls. 
+ if err = pdhCollectQueryData(pcHandle); err != nil { + return err + } + time.Sleep(50) + + return nil +} + +// ProcUsage returns process CPU and memory statistics +func ProcUsage(pcpu *float64, rss, vss *int64) error { + var err error + + // For simplicity, protect the entire call. + // Most simultaneous requests will immediately return + // with cached values. + pcQueryLock.Lock() + defer pcQueryLock.Unlock() + + // First time through, initialize counters. + if initialSample { + if err = initCounters(); err != nil { + return err + } + initialSample = false + } else if time.Since(lastSampleTime) < (2 * time.Second) { + // only refresh every two seconds as to minimize impact + // on the server. + *pcpu = prevCPU + *rss = prevRss + *vss = prevVss + return nil + } + + // always save the sample time, even on errors. + defer func() { + lastSampleTime = time.Now() + }() + + // refresh the performance counter data + if err = pdhCollectQueryData(pcHandle); err != nil { + return err + } + + // retrieve the data + var pidAry, cpuAry, rssAry, vssAry []float64 + if pidAry, err = getCounterArrayData(pidCounter); err != nil { + return err + } + if cpuAry, err = getCounterArrayData(cpuCounter); err != nil { + return err + } + if rssAry, err = getCounterArrayData(rssCounter); err != nil { + return err + } + if vssAry, err = getCounterArrayData(vssCounter); err != nil { + return err + } + // find the index of the entry for this process + idx := int(-1) + for i := range pidAry { + if int(pidAry[i]) == processPid { + idx = i + break + } + } + // no pid found... 
+ if idx < 0 { + return fmt.Errorf("could not find pid in performance counter results") + } + // assign values from the performance counters + *pcpu = cpuAry[idx] + *rss = int64(rssAry[idx]) + *vss = int64(vssAry[idx]) + + // save off cache values + prevCPU = *pcpu + prevRss = *rss + prevVss = *vss + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/reload.go b/vendor/github.com/nats-io/gnatsd/server/reload.go new file mode 100644 index 00000000000..aa938ff0b22 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/reload.go @@ -0,0 +1,723 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" + "errors" + "fmt" + "net/url" + "reflect" + "strings" + "sync/atomic" + "time" +) + +// FlagSnapshot captures the server options as specified by CLI flags at +// startup. This should not be modified once the server has started. +var FlagSnapshot *Options + +// option is a hot-swappable configuration setting. +type option interface { + // Apply the server option. + Apply(server *Server) + + // IsLoggingChange indicates if this option requires reloading the logger. + IsLoggingChange() bool + + // IsAuthChange indicates if this option requires reloading authorization. + IsAuthChange() bool +} + +// loggingOption is a base struct that provides default option behaviors for +// logging-related options. 
+type loggingOption struct{} + +func (l loggingOption) IsLoggingChange() bool { + return true +} + +func (l loggingOption) IsAuthChange() bool { + return false +} + +// traceOption implements the option interface for the `trace` setting. +type traceOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (t *traceOption) Apply(server *Server) { + server.Noticef("Reloaded: trace = %v", t.newValue) +} + +// debugOption implements the option interface for the `debug` setting. +type debugOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (d *debugOption) Apply(server *Server) { + server.Noticef("Reloaded: debug = %v", d.newValue) +} + +// logtimeOption implements the option interface for the `logtime` setting. +type logtimeOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (l *logtimeOption) Apply(server *Server) { + server.Noticef("Reloaded: logtime = %v", l.newValue) +} + +// logfileOption implements the option interface for the `log_file` setting. +type logfileOption struct { + loggingOption + newValue string +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (l *logfileOption) Apply(server *Server) { + server.Noticef("Reloaded: log_file = %v", l.newValue) +} + +// syslogOption implements the option interface for the `syslog` setting. +type syslogOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (s *syslogOption) Apply(server *Server) { + server.Noticef("Reloaded: syslog = %v", s.newValue) +} + +// remoteSyslogOption implements the option interface for the `remote_syslog` +// setting. 
+type remoteSyslogOption struct { + loggingOption + newValue string +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (r *remoteSyslogOption) Apply(server *Server) { + server.Noticef("Reloaded: remote_syslog = %v", r.newValue) +} + +// noopOption is a base struct that provides default no-op behaviors. +type noopOption struct{} + +func (n noopOption) IsLoggingChange() bool { + return false +} + +func (n noopOption) IsAuthChange() bool { + return false +} + +// tlsOption implements the option interface for the `tls` setting. +type tlsOption struct { + noopOption + newValue *tls.Config +} + +// Apply the tls change. +func (t *tlsOption) Apply(server *Server) { + server.mu.Lock() + tlsRequired := t.newValue != nil + server.info.TLSRequired = tlsRequired + message := "disabled" + if tlsRequired { + server.info.TLSVerify = (t.newValue.ClientAuth == tls.RequireAndVerifyClientCert) + message = "enabled" + } + server.mu.Unlock() + server.Noticef("Reloaded: tls = %s", message) +} + +// tlsTimeoutOption implements the option interface for the tls `timeout` +// setting. +type tlsTimeoutOption struct { + noopOption + newValue float64 +} + +// Apply is a no-op because the timeout will be reloaded after options are +// applied. +func (t *tlsTimeoutOption) Apply(server *Server) { + server.Noticef("Reloaded: tls timeout = %v", t.newValue) +} + +// authOption is a base struct that provides default option behaviors. +type authOption struct{} + +func (o authOption) IsLoggingChange() bool { + return false +} + +func (o authOption) IsAuthChange() bool { + return true +} + +// usernameOption implements the option interface for the `username` setting. +type usernameOption struct { + authOption +} + +// Apply is a no-op because authorization will be reloaded after options are +// applied. 
+func (u *usernameOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization username") +} + +// passwordOption implements the option interface for the `password` setting. +type passwordOption struct { + authOption +} + +// Apply is a no-op because authorization will be reloaded after options are +// applied. +func (p *passwordOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization password") +} + +// authorizationOption implements the option interface for the `token` +// authorization setting. +type authorizationOption struct { + authOption +} + +// Apply is a no-op because authorization will be reloaded after options are +// applied. +func (a *authorizationOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization token") +} + +// authTimeoutOption implements the option interface for the authorization +// `timeout` setting. +type authTimeoutOption struct { + noopOption // Not authOption because this is a no-op; will be reloaded with options. + newValue float64 +} + +// Apply is a no-op because the timeout will be reloaded after options are +// applied. +func (a *authTimeoutOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization timeout = %v", a.newValue) +} + +// usersOption implements the option interface for the authorization `users` +// setting. +type usersOption struct { + authOption + newValue []*User +} + +func (u *usersOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization users") +} + +// clusterOption implements the option interface for the `cluster` setting. +type clusterOption struct { + authOption + newValue ClusterOpts +} + +// Apply the cluster change. +func (c *clusterOption) Apply(server *Server) { + // TODO: support enabling/disabling clustering. 
+ server.mu.Lock() + tlsRequired := c.newValue.TLSConfig != nil + server.routeInfo.TLSRequired = tlsRequired + server.routeInfo.TLSVerify = tlsRequired + server.routeInfo.AuthRequired = c.newValue.Username != "" + if c.newValue.NoAdvertise { + server.routeInfo.ClientConnectURLs = nil + } else { + server.routeInfo.ClientConnectURLs = server.clientConnectURLs + } + server.setRouteInfoHostPortAndIP() + server.mu.Unlock() + server.Noticef("Reloaded: cluster") +} + +// routesOption implements the option interface for the cluster `routes` +// setting. +type routesOption struct { + noopOption + add []*url.URL + remove []*url.URL +} + +// Apply the route changes by adding and removing the necessary routes. +func (r *routesOption) Apply(server *Server) { + server.mu.Lock() + routes := make([]*client, len(server.routes)) + i := 0 + for _, client := range server.routes { + routes[i] = client + i++ + } + server.mu.Unlock() + + // Remove routes. + for _, remove := range r.remove { + for _, client := range routes { + var url *url.URL + client.mu.Lock() + if client.route != nil { + url = client.route.url + } + client.mu.Unlock() + if url != nil && urlsAreEqual(url, remove) { + // Do not attempt to reconnect when route is removed. + client.setRouteNoReconnectOnClose() + client.closeConnection(RouteRemoved) + server.Noticef("Removed route %v", remove) + } + } + } + + // Add routes. + server.solicitRoutes(r.add) + + server.Noticef("Reloaded: cluster routes") +} + +// maxConnOption implements the option interface for the `max_connections` +// setting. +type maxConnOption struct { + noopOption + newValue int +} + +// Apply the max connections change by closing random connections til we are +// below the limit if necessary. +func (m *maxConnOption) Apply(server *Server) { + server.mu.Lock() + var ( + clients = make([]*client, len(server.clients)) + i = 0 + ) + // Map iteration is random, which allows us to close random connections. 
+ for _, client := range server.clients { + clients[i] = client + i++ + } + server.mu.Unlock() + + if m.newValue > 0 && len(clients) > m.newValue { + // Close connections til we are within the limit. + var ( + numClose = len(clients) - m.newValue + closed = 0 + ) + for _, client := range clients { + client.maxConnExceeded() + closed++ + if closed >= numClose { + break + } + } + server.Noticef("Closed %d connections to fall within max_connections", closed) + } + server.Noticef("Reloaded: max_connections = %v", m.newValue) +} + +// pidFileOption implements the option interface for the `pid_file` setting. +type pidFileOption struct { + noopOption + newValue string +} + +// Apply the setting by logging the pid to the new file. +func (p *pidFileOption) Apply(server *Server) { + if p.newValue == "" { + return + } + if err := server.logPid(); err != nil { + server.Errorf("Failed to write pidfile: %v", err) + } + server.Noticef("Reloaded: pid_file = %v", p.newValue) +} + +// portsFileDirOption implements the option interface for the `portFileDir` setting. +type portsFileDirOption struct { + noopOption + oldValue string + newValue string +} + +func (p *portsFileDirOption) Apply(server *Server) { + server.deletePortsFile(p.oldValue) + server.logPorts() + server.Noticef("Reloaded: ports_file_dir = %v", p.newValue) +} + +// maxControlLineOption implements the option interface for the +// `max_control_line` setting. +type maxControlLineOption struct { + noopOption + newValue int +} + +// Apply is a no-op because the max control line will be reloaded after options +// are applied +func (m *maxControlLineOption) Apply(server *Server) { + server.Noticef("Reloaded: max_control_line = %d", m.newValue) +} + +// maxPayloadOption implements the option interface for the `max_payload` +// setting. +type maxPayloadOption struct { + noopOption + newValue int +} + +// Apply the setting by updating the server info and each client. 
+func (m *maxPayloadOption) Apply(server *Server) { + server.mu.Lock() + server.info.MaxPayload = m.newValue + for _, client := range server.clients { + atomic.StoreInt64(&client.mpay, int64(m.newValue)) + } + server.mu.Unlock() + server.Noticef("Reloaded: max_payload = %d", m.newValue) +} + +// pingIntervalOption implements the option interface for the `ping_interval` +// setting. +type pingIntervalOption struct { + noopOption + newValue time.Duration +} + +// Apply is a no-op because the ping interval will be reloaded after options +// are applied. +func (p *pingIntervalOption) Apply(server *Server) { + server.Noticef("Reloaded: ping_interval = %s", p.newValue) +} + +// maxPingsOutOption implements the option interface for the `ping_max` +// setting. +type maxPingsOutOption struct { + noopOption + newValue int +} + +// Apply is a no-op because the ping interval will be reloaded after options +// are applied. +func (m *maxPingsOutOption) Apply(server *Server) { + server.Noticef("Reloaded: ping_max = %d", m.newValue) +} + +// writeDeadlineOption implements the option interface for the `write_deadline` +// setting. +type writeDeadlineOption struct { + noopOption + newValue time.Duration +} + +// Apply is a no-op because the write deadline will be reloaded after options +// are applied. +func (w *writeDeadlineOption) Apply(server *Server) { + server.Noticef("Reloaded: write_deadline = %s", w.newValue) +} + +// clientAdvertiseOption implements the option interface for the `client_advertise` setting. +type clientAdvertiseOption struct { + noopOption + newValue string +} + +// Apply the setting by updating the server info and regenerate the infoJSON byte array. +func (c *clientAdvertiseOption) Apply(server *Server) { + server.mu.Lock() + server.setInfoHostPortAndGenerateJSON() + server.mu.Unlock() + server.Noticef("Reload: client_advertise = %s", c.newValue) +} + +// Reload reads the current configuration file and applies any supported +// changes. 
This returns an error if the server was not started with a config +// file or an option which doesn't support hot-swapping was changed. +func (s *Server) Reload() error { + s.mu.Lock() + if s.configFile == "" { + s.mu.Unlock() + return errors.New("Can only reload config when a file is provided using -c or --config") + } + newOpts, err := ProcessConfigFile(s.configFile) + if err != nil { + s.mu.Unlock() + // TODO: Dump previous good config to a .bak file? + return err + } + clientOrgPort := s.clientActualPort + clusterOrgPort := s.clusterActualPort + s.mu.Unlock() + + // Apply flags over config file settings. + newOpts = MergeOptions(newOpts, FlagSnapshot) + processOptions(newOpts) + + // processOptions sets Port to 0 if set to -1 (RANDOM port) + // If that's the case, set it to the saved value when the accept loop was + // created. + if newOpts.Port == 0 { + newOpts.Port = clientOrgPort + } + // We don't do that for cluster, so check against -1. + if newOpts.Cluster.Port == -1 { + newOpts.Cluster.Port = clusterOrgPort + } + + if err := s.reloadOptions(newOpts); err != nil { + return err + } + s.mu.Lock() + s.configTime = time.Now() + s.mu.Unlock() + return nil +} + +// reloadOptions reloads the server config with the provided options. If an +// option that doesn't support hot-swapping is changed, this returns an error. +func (s *Server) reloadOptions(newOpts *Options) error { + changed, err := s.diffOptions(newOpts) + if err != nil { + return err + } + s.setOpts(newOpts) + s.applyOptions(changed) + return nil +} + +// diffOptions returns a slice containing options which have been changed. If +// an option that doesn't support hot-swapping is changed, this returns an +// error. 
+func (s *Server) diffOptions(newOpts *Options) ([]option, error) { + var ( + oldConfig = reflect.ValueOf(s.getOpts()).Elem() + newConfig = reflect.ValueOf(newOpts).Elem() + diffOpts = []option{} + ) + + for i := 0; i < oldConfig.NumField(); i++ { + var ( + field = oldConfig.Type().Field(i) + oldValue = oldConfig.Field(i).Interface() + newValue = newConfig.Field(i).Interface() + changed = !reflect.DeepEqual(oldValue, newValue) + ) + if !changed { + continue + } + switch strings.ToLower(field.Name) { + case "trace": + diffOpts = append(diffOpts, &traceOption{newValue: newValue.(bool)}) + case "debug": + diffOpts = append(diffOpts, &debugOption{newValue: newValue.(bool)}) + case "logtime": + diffOpts = append(diffOpts, &logtimeOption{newValue: newValue.(bool)}) + case "logfile": + diffOpts = append(diffOpts, &logfileOption{newValue: newValue.(string)}) + case "syslog": + diffOpts = append(diffOpts, &syslogOption{newValue: newValue.(bool)}) + case "remotesyslog": + diffOpts = append(diffOpts, &remoteSyslogOption{newValue: newValue.(string)}) + case "tlsconfig": + diffOpts = append(diffOpts, &tlsOption{newValue: newValue.(*tls.Config)}) + case "tlstimeout": + diffOpts = append(diffOpts, &tlsTimeoutOption{newValue: newValue.(float64)}) + case "username": + diffOpts = append(diffOpts, &usernameOption{}) + case "password": + diffOpts = append(diffOpts, &passwordOption{}) + case "authorization": + diffOpts = append(diffOpts, &authorizationOption{}) + case "authtimeout": + diffOpts = append(diffOpts, &authTimeoutOption{newValue: newValue.(float64)}) + case "users": + diffOpts = append(diffOpts, &usersOption{newValue: newValue.([]*User)}) + case "cluster": + newClusterOpts := newValue.(ClusterOpts) + if err := validateClusterOpts(oldValue.(ClusterOpts), newClusterOpts); err != nil { + return nil, err + } + diffOpts = append(diffOpts, &clusterOption{newValue: newClusterOpts}) + case "routes": + add, remove := diffRoutes(oldValue.([]*url.URL), newValue.([]*url.URL)) + diffOpts 
= append(diffOpts, &routesOption{add: add, remove: remove}) + case "maxconn": + diffOpts = append(diffOpts, &maxConnOption{newValue: newValue.(int)}) + case "pidfile": + diffOpts = append(diffOpts, &pidFileOption{newValue: newValue.(string)}) + case "portsfiledir": + diffOpts = append(diffOpts, &portsFileDirOption{newValue: newValue.(string), oldValue: oldValue.(string)}) + case "maxcontrolline": + diffOpts = append(diffOpts, &maxControlLineOption{newValue: newValue.(int)}) + case "maxpayload": + diffOpts = append(diffOpts, &maxPayloadOption{newValue: newValue.(int)}) + case "pinginterval": + diffOpts = append(diffOpts, &pingIntervalOption{newValue: newValue.(time.Duration)}) + case "maxpingsout": + diffOpts = append(diffOpts, &maxPingsOutOption{newValue: newValue.(int)}) + case "writedeadline": + diffOpts = append(diffOpts, &writeDeadlineOption{newValue: newValue.(time.Duration)}) + case "clientadvertise": + cliAdv := newValue.(string) + if cliAdv != "" { + // Validate ClientAdvertise syntax + if _, _, err := parseHostPort(cliAdv, 0); err != nil { + return nil, fmt.Errorf("invalid ClientAdvertise value of %s, err=%v", cliAdv, err) + } + } + diffOpts = append(diffOpts, &clientAdvertiseOption{newValue: cliAdv}) + case "nolog", "nosigs": + // Ignore NoLog and NoSigs options since they are not parsed and only used in + // testing. + continue + case "port": + // check to see if newValue == 0 and continue if so. + if newValue == 0 { + // ignore RANDOM_PORT + continue + } + fallthrough + default: + // Bail out if attempting to reload any unsupported options. 
+ return nil, fmt.Errorf("Config reload not supported for %s: old=%v, new=%v", + field.Name, oldValue, newValue) + } + } + + return diffOpts, nil +} + +func (s *Server) applyOptions(opts []option) { + var ( + reloadLogging = false + reloadAuth = false + ) + for _, opt := range opts { + opt.Apply(s) + if opt.IsLoggingChange() { + reloadLogging = true + } + if opt.IsAuthChange() { + reloadAuth = true + } + } + + if reloadLogging { + s.ConfigureLogger() + } + if reloadAuth { + s.reloadAuthorization() + } + + s.Noticef("Reloaded server configuration") +} + +// reloadAuthorization reconfigures the server authorization settings, +// disconnects any clients who are no longer authorized, and removes any +// unauthorized subscriptions. +func (s *Server) reloadAuthorization() { + s.mu.Lock() + s.configureAuthorization() + clients := make(map[uint64]*client, len(s.clients)) + for i, client := range s.clients { + clients[i] = client + } + routes := make(map[uint64]*client, len(s.routes)) + for i, route := range s.routes { + routes[i] = route + } + s.mu.Unlock() + + for _, client := range clients { + // Disconnect any unauthorized clients. + if !s.isClientAuthorized(client) { + client.authViolation() + continue + } + + // Remove any unauthorized subscriptions. + s.removeUnauthorizedSubs(client) + } + + for _, route := range routes { + // Disconnect any unauthorized routes. + // Do this only for route that were accepted, not initiated + // because in the later case, we don't have the user name/password + // of the remote server. + if !route.isSolicitedRoute() && !s.isRouterAuthorized(route) { + route.setRouteNoReconnectOnClose() + route.authViolation() + } + } +} + +// validateClusterOpts ensures the new ClusterOpts does not change host or +// port, which do not support reload. 
+func validateClusterOpts(old, new ClusterOpts) error { + if old.Host != new.Host { + return fmt.Errorf("Config reload not supported for cluster host: old=%s, new=%s", + old.Host, new.Host) + } + if old.Port != new.Port { + return fmt.Errorf("Config reload not supported for cluster port: old=%d, new=%d", + old.Port, new.Port) + } + // Validate Cluster.Advertise syntax + if new.Advertise != "" { + if _, _, err := parseHostPort(new.Advertise, 0); err != nil { + return fmt.Errorf("invalid Cluster.Advertise value of %s, err=%v", new.Advertise, err) + } + } + return nil +} + +// diffRoutes diffs the old routes and the new routes and returns the ones that +// should be added and removed from the server. +func diffRoutes(old, new []*url.URL) (add, remove []*url.URL) { + // Find routes to remove. +removeLoop: + for _, oldRoute := range old { + for _, newRoute := range new { + if urlsAreEqual(oldRoute, newRoute) { + continue removeLoop + } + } + remove = append(remove, oldRoute) + } + + // Find routes to add. +addLoop: + for _, newRoute := range new { + for _, oldRoute := range old { + if urlsAreEqual(oldRoute, newRoute) { + continue addLoop + } + } + add = append(add, newRoute) + } + + return add, remove +} diff --git a/vendor/github.com/nats-io/gnatsd/server/ring.go b/vendor/github.com/nats-io/gnatsd/server/ring.go new file mode 100644 index 00000000000..b9232ca9955 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/ring.go @@ -0,0 +1,75 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +// We wrap to hold onto optional items for /connz. +type closedClient struct { + ConnInfo + subs []string + user string +} + +// Fixed sized ringbuffer for closed connections. +type closedRingBuffer struct { + total uint64 + conns []*closedClient +} + +// Create a new ring buffer with at most max items. +func newClosedRingBuffer(max int) *closedRingBuffer { + rb := &closedRingBuffer{} + rb.conns = make([]*closedClient, max) + return rb +} + +// Adds in a new closed connection. If there is no more room, +// remove the oldest. +func (rb *closedRingBuffer) append(cc *closedClient) { + rb.conns[rb.next()] = cc + rb.total++ +} + +func (rb *closedRingBuffer) next() int { + return int(rb.total % uint64(cap(rb.conns))) +} + +func (rb *closedRingBuffer) len() int { + if rb.total > uint64(cap(rb.conns)) { + return cap(rb.conns) + } + return int(rb.total) +} + +func (rb *closedRingBuffer) totalConns() uint64 { + return rb.total +} + +// This will not be sorted. Will return a copy of the list +// which recipient can modify. If the contents of the client +// itself need to be modified, meaning swapping in any optional items, +// a copy should be made. We could introduce a new lock and hold that +// but since we return this list inside monitor which allows programatic +// access, we do not know when it would be done. 
+func (rb *closedRingBuffer) closedClients() []*closedClient { + dup := make([]*closedClient, rb.len()) + if rb.total <= uint64(cap(rb.conns)) { + copy(dup, rb.conns[:rb.len()]) + } else { + first := rb.next() + next := cap(rb.conns) - first + copy(dup, rb.conns[first:]) + copy(dup[next:], rb.conns[:next]) + } + return dup +} diff --git a/vendor/github.com/nats-io/gnatsd/server/route.go b/vendor/github.com/nats-io/gnatsd/server/route.go new file mode 100644 index 00000000000..e7afa97d05f --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/route.go @@ -0,0 +1,1103 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "math/rand" + "net" + "net/url" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/nats-io/gnatsd/util" +) + +// RouteType designates the router type +type RouteType int + +// Type of Route +const ( + // This route we learned from speaking to other routes. + Implicit RouteType = iota + // This route was explicitly configured. 
+ Explicit +) + +type route struct { + remoteID string + didSolicit bool + retry bool + routeType RouteType + url *url.URL + authRequired bool + tlsRequired bool + closed bool + connectURLs []string +} + +type connectInfo struct { + Echo bool `json:"echo"` + Verbose bool `json:"verbose"` + Pedantic bool `json:"pedantic"` + User string `json:"user,omitempty"` + Pass string `json:"pass,omitempty"` + TLS bool `json:"tls_required"` + Name string `json:"name"` +} + +// Used to hold onto mappings for unsubscribed +// routed queue subscribers. +type rqsub struct { + group []byte + atime time.Time +} + +// Route protocol constants +const ( + ConProto = "CONNECT %s" + _CRLF_ + InfoProto = "INFO %s" + _CRLF_ +) + +// Clear up the timer and any map held for remote qsubs. +func (s *Server) clearRemoteQSubs() { + s.rqsMu.Lock() + defer s.rqsMu.Unlock() + if s.rqsubsTimer != nil { + s.rqsubsTimer.Stop() + s.rqsubsTimer = nil + } + s.rqsubs = nil +} + +// Check to see if we can remove any of the remote qsubs mappings +func (s *Server) purgeRemoteQSubs() { + ri := s.getOpts().RQSubsSweep + s.rqsMu.Lock() + exp := time.Now().Add(-ri) + for k, rqsub := range s.rqsubs { + if exp.After(rqsub.atime) { + delete(s.rqsubs, k) + } + } + if s.rqsubsTimer != nil { + // Reset timer. + s.rqsubsTimer = time.AfterFunc(ri, s.purgeRemoteQSubs) + } + s.rqsMu.Unlock() +} + +// Lookup a remote queue group sid. +func (s *Server) lookupRemoteQGroup(sid string) []byte { + s.rqsMu.RLock() + rqsub := s.rqsubs[sid] + s.rqsMu.RUnlock() + return rqsub.group +} + +// This will hold onto a remote queue subscriber to allow +// for mapping and handling if we get a message after the +// subscription goes away. +func (s *Server) holdRemoteQSub(sub *subscription) { + // Should not happen, but protect anyway. + if len(sub.queue) == 0 { + return + } + // Add the entry + s.rqsMu.Lock() + // Start timer if needed. 
+ if s.rqsubsTimer == nil { + ri := s.getOpts().RQSubsSweep + s.rqsubsTimer = time.AfterFunc(ri, s.purgeRemoteQSubs) + } + // Create map if needed. + if s.rqsubs == nil { + s.rqsubs = make(map[string]rqsub) + } + group := make([]byte, len(sub.queue)) + copy(group, sub.queue) + rqsub := rqsub{group: group, atime: time.Now()} + s.rqsubs[routeSid(sub)] = rqsub + s.rqsMu.Unlock() +} + +// This is for when we receive a directed message for a queue subscriber +// that has gone away. We reroute like a new message but scope to only +// the queue subscribers that it was originally intended for. We will +// prefer local clients, but will bounce to another route if needed. +func (c *client) reRouteQMsg(r *SublistResult, msgh, msg, group []byte) { + c.Debugf("Attempting redelivery of message for absent queue subscriber on group '%q'", group) + + // We only care about qsubs here. Data structure not setup for optimized + // lookup for our specific group however. + + var qsubs []*subscription + for _, qs := range r.qsubs { + if len(qs) != 0 && bytes.Equal(group, qs[0].queue) { + qsubs = qs + break + } + } + + // If no match return. + if qsubs == nil { + c.Debugf("Redelivery failed, no queue subscribers for message on group '%q'", group) + return + } + + // We have a matched group of queue subscribers. + // We prefer a local subscriber since that was the original target. + + // Spin prand if needed. + if c.in.prand == nil { + c.in.prand = rand.New(rand.NewSource(time.Now().UnixNano())) + } + + // Hold onto a remote if we come across it to utilize in case no locals exist. 
+ var rsub *subscription + + startIndex := c.in.prand.Intn(len(qsubs)) + for i := 0; i < len(qsubs); i++ { + index := (startIndex + i) % len(qsubs) + sub := qsubs[index] + if sub == nil { + continue + } + if rsub == nil && bytes.HasPrefix(sub.sid, []byte(QRSID)) { + rsub = sub + continue + } + mh := c.msgHeader(msgh[:], sub) + if c.deliverMsg(sub, mh, msg) { + c.Debugf("Redelivery succeeded for message on group '%q'", group) + return + } + } + // If we are here we failed to find a local, see if we snapshotted a + // remote sub, and if so deliver to that. + if rsub != nil { + mh := c.msgHeader(msgh[:], rsub) + if c.deliverMsg(rsub, mh, msg) { + c.Debugf("Re-routing message on group '%q' to remote server", group) + return + } + } + c.Debugf("Redelivery failed, no queue subscribers for message on group '%q'", group) +} + +// processRoutedMsg processes messages inbound from a route. +func (c *client) processRoutedMsg(r *SublistResult, msg []byte) { + // Snapshot server. + srv := c.srv + + msgh := c.prepMsgHeader() + si := len(msgh) + + // If we have a queue subscription, deliver direct + // since they are sent direct via L2 semantics over routes. + // If the match is a queue subscription, we will return from + // here regardless if we find a sub. + isq, sub, err := srv.routeSidQueueSubscriber(c.pa.sid) + if isq { + if err != nil { + // We got an invalid QRSID, so stop here + c.Errorf("Unable to deliver routed queue message: %v", err) + return + } + didDeliver := false + if sub != nil { + mh := c.msgHeader(msgh[:si], sub) + didDeliver = c.deliverMsg(sub, mh, msg) + } + if !didDeliver && c.srv != nil { + group := c.srv.lookupRemoteQGroup(string(c.pa.sid)) + c.reRouteQMsg(r, msgh, msg, group) + } + return + } + // Normal pub/sub message here + // Loop over all normal subscriptions that match. + for _, sub := range r.psubs { + // Check if this is a send to a ROUTER, if so we ignore to + // enforce 1-hop semantics. 
+ if sub.client.typ == ROUTER { + continue + } + sub.client.mu.Lock() + if sub.client.nc == nil { + sub.client.mu.Unlock() + continue + } + sub.client.mu.Unlock() + + // Normal delivery + mh := c.msgHeader(msgh[:si], sub) + c.deliverMsg(sub, mh, msg) + } +} + +// Lock should be held entering here. +func (c *client) sendConnect(tlsRequired bool) { + var user, pass string + if userInfo := c.route.url.User; userInfo != nil { + user = userInfo.Username() + pass, _ = userInfo.Password() + } + cinfo := connectInfo{ + Echo: true, + Verbose: false, + Pedantic: false, + User: user, + Pass: pass, + TLS: tlsRequired, + Name: c.srv.info.ID, + } + b, err := json.Marshal(cinfo) + if err != nil { + c.Errorf("Error marshaling CONNECT to route: %v\n", err) + c.closeConnection(ProtocolViolation) + return + } + c.sendProto([]byte(fmt.Sprintf(ConProto, b)), true) +} + +// Process the info message if we are a route. +func (c *client) processRouteInfo(info *Info) { + c.mu.Lock() + // Connection can be closed at any time (by auth timeout, etc). + // Does not make sense to continue here if connection is gone. + if c.route == nil || c.nc == nil { + c.mu.Unlock() + return + } + + s := c.srv + remoteID := c.route.remoteID + + // We receive an INFO from a server that informs us about another server, + // so the info.ID in the INFO protocol does not match the ID of this route. + if remoteID != "" && remoteID != info.ID { + c.mu.Unlock() + + // Process this implicit route. We will check that it is not an explicit + // route and/or that it has not been connected already. + s.processImplicitRoute(info) + return + } + + // Need to set this for the detection of the route to self to work + // in closeConnection(). + c.route.remoteID = info.ID + + // Detect route to self. + if c.route.remoteID == s.info.ID { + c.mu.Unlock() + c.closeConnection(DuplicateRoute) + return + } + + // Copy over important information. 
+ c.route.authRequired = info.AuthRequired + c.route.tlsRequired = info.TLSRequired + + // If we do not know this route's URL, construct one on the fly + // from the information provided. + if c.route.url == nil { + // Add in the URL from host and port + hp := net.JoinHostPort(info.Host, strconv.Itoa(info.Port)) + url, err := url.Parse(fmt.Sprintf("nats-route://%s/", hp)) + if err != nil { + c.Errorf("Error parsing URL from INFO: %v\n", err) + c.mu.Unlock() + c.closeConnection(ParseError) + return + } + c.route.url = url + } + + // Check to see if we have this remote already registered. + // This can happen when both servers have routes to each other. + c.mu.Unlock() + + if added, sendInfo := s.addRoute(c, info); added { + c.Debugf("Registering remote route %q", info.ID) + // Send our local subscriptions to this route. + s.sendLocalSubsToRoute(c) + // sendInfo will be false if the route that we just accepted + // is the only route there is. + if sendInfo { + // The incoming INFO from the route will have IP set + // if it has Cluster.Advertise. In that case, use that + // otherwise contruct it from the remote TCP address. + if info.IP == "" { + // Need to get the remote IP address. + c.mu.Lock() + switch conn := c.nc.(type) { + case *net.TCPConn, *tls.Conn: + addr := conn.RemoteAddr().(*net.TCPAddr) + info.IP = fmt.Sprintf("nats-route://%s/", net.JoinHostPort(addr.IP.String(), + strconv.Itoa(info.Port))) + default: + info.IP = c.route.url.String() + } + c.mu.Unlock() + } + // Now let the known servers know about this new route + s.forwardNewRouteInfoToKnownServers(info) + } + // Unless disabled, possibly update the server's INFO protocol + // and send to clients that know how to handle async INFOs. 
+ if !s.getOpts().Cluster.NoAdvertise { + s.addClientConnectURLsAndSendINFOToClients(info.ClientConnectURLs) + } + } else { + c.Debugf("Detected duplicate remote route %q", info.ID) + c.closeConnection(DuplicateRoute) + } +} + +// sendAsyncInfoToClients sends an INFO protocol to all +// connected clients that accept async INFO updates. +// The server lock is held on entry. +func (s *Server) sendAsyncInfoToClients() { + // If there are no clients supporting async INFO protocols, we are done. + // Also don't send if we are shutting down... + if s.cproto == 0 || s.shutdown { + return + } + + for _, c := range s.clients { + c.mu.Lock() + // Here, we are going to send only to the clients that are fully + // registered (server has received CONNECT and first PING). For + // clients that are not at this stage, this will happen in the + // processing of the first PING (see client.processPing) + if c.opts.Protocol >= ClientProtoInfo && c.flags.isSet(firstPongSent) { + // sendInfo takes care of checking if the connection is still + // valid or not, so don't duplicate tests here. + c.sendInfo(c.generateClientInfoJSON(s.copyInfo())) + } + c.mu.Unlock() + } +} + +// This will process implicit route information received from another server. +// We will check to see if we have configured or are already connected, +// and if so we will ignore. Otherwise we will attempt to connect. +func (s *Server) processImplicitRoute(info *Info) { + remoteID := info.ID + + s.mu.Lock() + defer s.mu.Unlock() + + // Don't connect to ourself + if remoteID == s.info.ID { + return + } + // Check if this route already exists + if _, exists := s.remotes[remoteID]; exists { + return + } + // Check if we have this route as a configured route + if s.hasThisRouteConfigured(info) { + return + } + + // Initiate the connection, using info.IP instead of info.URL here... 
+ r, err := url.Parse(info.IP) + if err != nil { + s.Errorf("Error parsing URL from INFO: %v\n", err) + return + } + + // Snapshot server options. + opts := s.getOpts() + + if info.AuthRequired { + r.User = url.UserPassword(opts.Cluster.Username, opts.Cluster.Password) + } + s.startGoRoutine(func() { s.connectToRoute(r, false) }) +} + +// hasThisRouteConfigured returns true if info.Host:info.Port is present +// in the server's opts.Routes, false otherwise. +// Server lock is assumed to be held by caller. +func (s *Server) hasThisRouteConfigured(info *Info) bool { + urlToCheckExplicit := strings.ToLower(net.JoinHostPort(info.Host, strconv.Itoa(info.Port))) + for _, ri := range s.getOpts().Routes { + if strings.ToLower(ri.Host) == urlToCheckExplicit { + return true + } + } + return false +} + +// forwardNewRouteInfoToKnownServers sends the INFO protocol of the new route +// to all routes known by this server. In turn, each server will contact this +// new route. +func (s *Server) forwardNewRouteInfoToKnownServers(info *Info) { + s.mu.Lock() + defer s.mu.Unlock() + + b, _ := json.Marshal(info) + infoJSON := []byte(fmt.Sprintf(InfoProto, b)) + + for _, r := range s.routes { + r.mu.Lock() + if r.route.remoteID != info.ID { + r.sendInfo(infoJSON) + } + r.mu.Unlock() + } +} + +// canImport is whether or not we will send a SUB for interest to the other side. +// This is for ROUTER connections only. +// Lock is held on entry. +func (c *client) canImport(subject []byte) bool { + // Use pubAllowed() since this checks Publish permissions which + // is what Import maps to. + return c.pubAllowed(subject) +} + +// canExport is whether or not we will accept a SUB from the remote for a given subject. +// This is for ROUTER connections only. +// Lock is held on entry +func (c *client) canExport(subject []byte) bool { + // Use canSubscribe() since this checks Subscribe permissions which + // is what Export maps to. 
+ return c.canSubscribe(subject) +} + +// Initialize or reset cluster's permissions. +// This is for ROUTER connections only. +// Client lock is held on entry +func (c *client) setRoutePermissions(perms *RoutePermissions) { + // Reset if some were set + if perms == nil { + c.perms = nil + return + } + // Convert route permissions to user permissions. + // The Import permission is mapped to Publish + // and Export permission is mapped to Subscribe. + // For meaning of Import/Export, see canImport and canExport. + p := &Permissions{ + Publish: perms.Import, + Subscribe: perms.Export, + } + c.setPermissions(p) +} + +// This will send local subscription state to a new route connection. +// FIXME(dlc) - This could be a DOS or perf issue with many clients +// and large subscription space. Plus buffering in place not a good idea. +func (s *Server) sendLocalSubsToRoute(route *client) { + var raw [4096]*subscription + subs := raw[:0] + + s.sl.localSubs(&subs) + + route.mu.Lock() + for _, sub := range subs { + // Send SUB interest only if subject has a match in import permissions + if !route.canImport(sub.subject) { + continue + } + proto := fmt.Sprintf(subProto, sub.subject, sub.queue, routeSid(sub)) + route.queueOutbound([]byte(proto)) + if route.out.pb > int64(route.out.sz*2) { + route.flushSignal() + } + } + route.flushSignal() + route.mu.Unlock() + + route.Debugf("Sent local subscriptions to route") +} + +func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { + // Snapshot server options. 
+ opts := s.getOpts() + + didSolicit := rURL != nil + r := &route{didSolicit: didSolicit} + for _, route := range opts.Routes { + if rURL != nil && (strings.ToLower(rURL.Host) == strings.ToLower(route.Host)) { + r.routeType = Explicit + } + } + + c := &client{srv: s, nc: conn, opts: clientOpts{}, typ: ROUTER, route: r} + + // Grab server variables + s.mu.Lock() + infoJSON := s.routeInfoJSON + authRequired := s.routeInfo.AuthRequired + tlsRequired := s.routeInfo.TLSRequired + s.mu.Unlock() + + // Grab lock + c.mu.Lock() + + // Initialize + c.initClient() + + if didSolicit { + // Do this before the TLS code, otherwise, in case of failure + // and if route is explicit, it would try to reconnect to 'nil'... + r.url = rURL + + // Set permissions associated with the route user (if applicable). + // No lock needed since we are already under client lock. + c.setRoutePermissions(opts.Cluster.Permissions) + } + + // Check for TLS + if tlsRequired { + // Copy off the config to add in ServerName if we + tlsConfig := util.CloneTLSConfig(opts.Cluster.TLSConfig) + + // If we solicited, we will act like the client, otherwise the server. + if didSolicit { + c.Debugf("Starting TLS route client handshake") + // Specify the ServerName we are expecting. 
+ host, _, _ := net.SplitHostPort(rURL.Host) + tlsConfig.ServerName = host + c.nc = tls.Client(c.nc, tlsConfig) + } else { + c.Debugf("Starting TLS route server handshake") + c.nc = tls.Server(c.nc, tlsConfig) + } + + conn := c.nc.(*tls.Conn) + + // Setup the timeout + ttl := secondsToDuration(opts.Cluster.TLSTimeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Errorf("TLS route handshake error: %v", err) + c.sendErr("Secure Connection - TLS Required") + c.closeConnection(TLSHandshakeError) + return nil + } + // Reset the read deadline + conn.SetReadDeadline(time.Time{}) + + // Re-Grab lock + c.mu.Lock() + + // Verify that the connection did not go away while we released the lock. + if c.nc == nil { + c.mu.Unlock() + return nil + } + } + + // Do final client initialization + + // Set the Ping timer + c.setPingTimer() + + // For routes, the "client" is added to s.routes only when processing + // the INFO protocol, that is much later. + // In the meantime, if the server shutsdown, there would be no reference + // to the client (connection) to be closed, leaving this readLoop + // uinterrupted, causing the Shutdown() to wait indefinitively. + // We need to store the client in a special map, under a special lock. + s.grMu.Lock() + running := s.grRunning + if running { + s.grTmpClients[c.cid] = c + } + s.grMu.Unlock() + if !running { + c.mu.Unlock() + c.setRouteNoReconnectOnClose() + c.closeConnection(ServerShutdown) + return nil + } + + // Check for Auth required state for incoming connections. + // Make sure to do this before spinning up readLoop. + if authRequired && !didSolicit { + ttl := secondsToDuration(opts.Cluster.AuthTimeout) + c.setAuthTimer(ttl) + } + + // Spin up the read loop. + s.startGoRoutine(func() { c.readLoop() }) + + // Spin up the write loop. 
+ s.startGoRoutine(c.writeLoop) + + if tlsRequired { + c.Debugf("TLS handshake complete") + cs := c.nc.(*tls.Conn).ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + } + + // Queue Connect proto if we solicited the connection. + if didSolicit { + c.Debugf("Route connect msg sent") + c.sendConnect(tlsRequired) + } + + // Send our info to the other side. + c.sendInfo(infoJSON) + + c.mu.Unlock() + + c.Noticef("Route connection created") + return c +} + +const ( + _CRLF_ = "\r\n" + _EMPTY_ = "" +) + +const ( + subProto = "SUB %s %s %s" + _CRLF_ + unsubProto = "UNSUB %s" + _CRLF_ +) + +// FIXME(dlc) - Make these reserved and reject if they come in as a sid +// from a client connection. +// Route constants +const ( + RSID = "RSID" + QRSID = "QRSID" + + QRSID_LEN = len(QRSID) +) + +// Parse the given rsid. If the protocol does not start with QRSID, +// returns false and no subscription nor error. +// If it does start with QRSID, returns true and possibly a subscription +// or an error if the QRSID protocol is malformed. +func (s *Server) routeSidQueueSubscriber(rsid []byte) (bool, *subscription, error) { + if !bytes.HasPrefix(rsid, []byte(QRSID)) { + return false, nil, nil + } + cid, sid, err := parseRouteQueueSid(rsid) + if err != nil { + return true, nil, err + } + + s.mu.Lock() + client := s.clients[cid] + s.mu.Unlock() + + if client == nil { + return true, nil, nil + } + + client.mu.Lock() + sub, ok := client.subs[string(sid)] + client.mu.Unlock() + if ok { + return true, sub, nil + } + return true, nil, nil +} + +// Creates a routable sid that can be used +// to reach remote subscriptions. +func routeSid(sub *subscription) string { + var qi string + if len(sub.queue) > 0 { + qi = "Q" + } + return fmt.Sprintf("%s%s:%d:%s", qi, RSID, sub.client.cid, sub.sid) +} + +// Parse the given `rsid` knowing that it starts with `QRSID`. +// Returns the cid and sid or an error not a valid QRSID. 
+func parseRouteQueueSid(rsid []byte) (uint64, []byte, error) { + var ( + cid uint64 + sid []byte + cidFound bool + sidFound bool + ) + // A valid QRSID needs to be at least QRSID:x:y + // First character here should be `:` + if len(rsid) >= QRSID_LEN+4 { + if rsid[QRSID_LEN] == ':' { + for i, count := QRSID_LEN+1, len(rsid); i < count; i++ { + switch rsid[i] { + case ':': + cid = uint64(parseInt64(rsid[QRSID_LEN+1 : i])) + cidFound = true + sid = rsid[i+1:] + } + } + if cidFound { + // We can't assume the content of sid, so as long + // as it is not len 0, we have to say it is a valid one. + if len(rsid) > 0 { + sidFound = true + } + } + } + } + if cidFound && sidFound { + return cid, sid, nil + } + return 0, nil, fmt.Errorf("invalid QRSID: %s", rsid) +} + +func (s *Server) addRoute(c *client, info *Info) (bool, bool) { + id := c.route.remoteID + sendInfo := false + + s.mu.Lock() + if !s.running { + s.mu.Unlock() + return false, false + } + remote, exists := s.remotes[id] + if !exists { + s.routes[c.cid] = c + s.remotes[id] = c + c.mu.Lock() + c.route.connectURLs = info.ClientConnectURLs + cid := c.cid + c.mu.Unlock() + + // Remove from the temporary map + s.grMu.Lock() + delete(s.grTmpClients, cid) + s.grMu.Unlock() + + // we don't need to send if the only route is the one we just accepted. + sendInfo = len(s.routes) > 1 + } + s.mu.Unlock() + + if exists { + var r *route + + c.mu.Lock() + // upgrade to solicited? + if c.route.didSolicit { + // Make a copy + rs := *c.route + r = &rs + } + c.mu.Unlock() + + remote.mu.Lock() + // r will be not nil if c.route.didSolicit was true + if r != nil { + // If we upgrade to solicited, we still want to keep the remote's + // connectURLs. So transfer those. + r.connectURLs = remote.route.connectURLs + remote.route = r + } + // This is to mitigate the issue where both sides add the route + // on the opposite connection, and therefore end-up with both + // connections being dropped. 
+ remote.route.retry = true + remote.mu.Unlock() + } + + return !exists, sendInfo +} + +func (s *Server) broadcastInterestToRoutes(sub *subscription, proto string) { + var arg []byte + if atomic.LoadInt32(&s.logging.trace) == 1 { + arg = []byte(proto[:len(proto)-LEN_CR_LF]) + } + protoAsBytes := []byte(proto) + s.mu.Lock() + for _, route := range s.routes { + // FIXME(dlc) - Make same logic as deliverMsg + route.mu.Lock() + // The permission of this cluster applies to all routes, and each + // route will have the same `perms`, so check with the first route + // and send SUB interest only if subject has a match in import permissions. + // If there is no match, we stop here. + if !route.canImport(sub.subject) { + route.mu.Unlock() + break + } + route.sendProto(protoAsBytes, true) + route.mu.Unlock() + route.traceOutOp("", arg) + } + s.mu.Unlock() +} + +// broadcastSubscribe will forward a client subscription +// to all active routes. +func (s *Server) broadcastSubscribe(sub *subscription) { + if s.numRoutes() == 0 { + return + } + rsid := routeSid(sub) + proto := fmt.Sprintf(subProto, sub.subject, sub.queue, rsid) + s.broadcastInterestToRoutes(sub, proto) +} + +// broadcastUnSubscribe will forward a client unsubscribe +// action to all active routes. +func (s *Server) broadcastUnSubscribe(sub *subscription) { + if s.numRoutes() == 0 { + return + } + sub.client.mu.Lock() + // Max has no meaning on the other side of a route, so do not send. + hasMax := sub.max > 0 && sub.nm < sub.max + sub.client.mu.Unlock() + if hasMax { + return + } + rsid := routeSid(sub) + proto := fmt.Sprintf(unsubProto, rsid) + s.broadcastInterestToRoutes(sub, proto) +} + +func (s *Server) routeAcceptLoop(ch chan struct{}) { + defer func() { + if ch != nil { + close(ch) + } + }() + + // Snapshot server options. + opts := s.getOpts() + + // Snapshot server options. 
+ port := opts.Cluster.Port + + if port == -1 { + port = 0 + } + + hp := net.JoinHostPort(opts.Cluster.Host, strconv.Itoa(port)) + l, e := net.Listen("tcp", hp) + if e != nil { + s.Fatalf("Error listening on router port: %d - %v", opts.Cluster.Port, e) + return + } + s.Noticef("Listening for route connections on %s", + net.JoinHostPort(opts.Cluster.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) + + s.mu.Lock() + // Check for TLSConfig + tlsReq := opts.Cluster.TLSConfig != nil + info := Info{ + ID: s.info.ID, + Version: s.info.Version, + AuthRequired: false, + TLSRequired: tlsReq, + TLSVerify: tlsReq, + MaxPayload: s.info.MaxPayload, + } + // Set this if only if advertise is not disabled + if !opts.Cluster.NoAdvertise { + info.ClientConnectURLs = s.clientConnectURLs + } + // If we have selected a random port... + if port == 0 { + // Write resolved port back to options. + opts.Cluster.Port = l.Addr().(*net.TCPAddr).Port + } + // Keep track of actual listen port. This will be needed in case of + // config reload. 
+ s.clusterActualPort = opts.Cluster.Port + // Check for Auth items + if opts.Cluster.Username != "" { + info.AuthRequired = true + } + s.routeInfo = info + // Possibly override Host/Port and set IP based on Cluster.Advertise + if err := s.setRouteInfoHostPortAndIP(); err != nil { + s.Fatalf("Error setting route INFO with Cluster.Advertise value of %s, err=%v", s.opts.Cluster.Advertise, err) + l.Close() + s.mu.Unlock() + return + } + // Setup state that can enable shutdown + s.routeListener = l + s.mu.Unlock() + + // Let them know we are up + close(ch) + ch = nil + + tmpDelay := ACCEPT_MIN_SLEEP + + for s.isRunning() { + conn, err := l.Accept() + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Temporary() { + s.Debugf("Temporary Route Accept Errorf(%v), sleeping %dms", + ne, tmpDelay/time.Millisecond) + time.Sleep(tmpDelay) + tmpDelay *= 2 + if tmpDelay > ACCEPT_MAX_SLEEP { + tmpDelay = ACCEPT_MAX_SLEEP + } + } else if s.isRunning() { + s.Noticef("Accept error: %v", err) + } + continue + } + tmpDelay = ACCEPT_MIN_SLEEP + s.startGoRoutine(func() { + s.createRoute(conn, nil) + s.grWG.Done() + }) + } + s.Debugf("Router accept loop exiting..") + s.done <- true +} + +// Similar to setInfoHostPortAndGenerateJSON, but for routeInfo. +func (s *Server) setRouteInfoHostPortAndIP() error { + if s.opts.Cluster.Advertise != "" { + advHost, advPort, err := parseHostPort(s.opts.Cluster.Advertise, s.opts.Cluster.Port) + if err != nil { + return err + } + s.routeInfo.Host = advHost + s.routeInfo.Port = advPort + s.routeInfo.IP = fmt.Sprintf("nats-route://%s/", net.JoinHostPort(advHost, strconv.Itoa(advPort))) + } else { + s.routeInfo.Host = s.opts.Cluster.Host + s.routeInfo.Port = s.opts.Cluster.Port + s.routeInfo.IP = "" + } + // (re)generate the routeInfoJSON byte array + s.generateRouteInfoJSON() + return nil +} + +// StartRouting will start the accept loop on the cluster host:port +// and will actively try to connect to listed routes. 
+func (s *Server) StartRouting(clientListenReady chan struct{}) { + defer s.grWG.Done() + + // Wait for the client listen port to be opened, and + // the possible ephemeral port to be selected. + <-clientListenReady + + // Spin up the accept loop + ch := make(chan struct{}) + go s.routeAcceptLoop(ch) + <-ch + + // Solicit Routes if needed. + s.solicitRoutes(s.getOpts().Routes) +} + +func (s *Server) reConnectToRoute(rURL *url.URL, rtype RouteType) { + tryForEver := rtype == Explicit + // If A connects to B, and B to A (regardless if explicit or + // implicit - due to auto-discovery), and if each server first + // registers the route on the opposite TCP connection, the + // two connections will end-up being closed. + // Add some random delay to reduce risk of repeated failures. + delay := time.Duration(rand.Intn(100)) * time.Millisecond + if tryForEver { + delay += DEFAULT_ROUTE_RECONNECT + } + time.Sleep(delay) + s.connectToRoute(rURL, tryForEver) +} + +// Checks to make sure the route is still valid. +func (s *Server) routeStillValid(rURL *url.URL) bool { + for _, ri := range s.getOpts().Routes { + if urlsAreEqual(ri, rURL) { + return true + } + } + return false +} + +func (s *Server) connectToRoute(rURL *url.URL, tryForEver bool) { + // Snapshot server options. 
+ opts := s.getOpts() + + defer s.grWG.Done() + + attempts := 0 + for s.isRunning() && rURL != nil { + if tryForEver && !s.routeStillValid(rURL) { + return + } + s.Debugf("Trying to connect to route on %s", rURL.Host) + conn, err := net.DialTimeout("tcp", rURL.Host, DEFAULT_ROUTE_DIAL) + if err != nil { + s.Errorf("Error trying to connect to route: %v", err) + if !tryForEver { + if opts.Cluster.ConnectRetries <= 0 { + return + } + attempts++ + if attempts > opts.Cluster.ConnectRetries { + return + } + } + select { + case <-s.quitCh: + return + case <-time.After(DEFAULT_ROUTE_CONNECT): + continue + } + } + + if tryForEver && !s.routeStillValid(rURL) { + conn.Close() + return + } + + // We have a route connection here. + // Go ahead and create it and exit this func. + s.createRoute(conn, rURL) + return + } +} + +func (c *client) isSolicitedRoute() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.typ == ROUTER && c.route != nil && c.route.didSolicit +} + +func (s *Server) solicitRoutes(routes []*url.URL) { + for _, r := range routes { + route := r + s.startGoRoutine(func() { s.connectToRoute(route, true) }) + } +} + +func (s *Server) numRoutes() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.routes) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/server.go b/vendor/github.com/nats-io/gnatsd/server/server.go new file mode 100644 index 00000000000..bddb9de3290 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/server.go @@ -0,0 +1,1420 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + // Allow dynamic profiling. + _ "net/http/pprof" + + "github.com/nats-io/gnatsd/util" +) + +// Info is the information sent to clients to help them understand information +// about this server. +type Info struct { + ID string `json:"server_id"` + Version string `json:"version"` + Proto int `json:"proto"` + GitCommit string `json:"git_commit,omitempty"` + GoVersion string `json:"go"` + Host string `json:"host"` + Port int `json:"port"` + AuthRequired bool `json:"auth_required,omitempty"` + TLSRequired bool `json:"tls_required,omitempty"` + TLSVerify bool `json:"tls_verify,omitempty"` + MaxPayload int `json:"max_payload"` + IP string `json:"ip,omitempty"` + CID uint64 `json:"client_id,omitempty"` + ClientConnectURLs []string `json:"connect_urls,omitempty"` // Contains URLs a client can connect to. +} + +// Server is our main struct. +type Server struct { + gcid uint64 + stats + mu sync.Mutex + info Info + sl *Sublist + configFile string + optsMu sync.RWMutex + opts *Options + running bool + shutdown bool + listener net.Listener + clients map[uint64]*client + routes map[uint64]*client + remotes map[string]*client + users map[string]*User + totalClients uint64 + closed *closedRingBuffer + done chan bool + start time.Time + http net.Listener + httpHandler http.Handler + profiler net.Listener + httpReqStats map[string]uint64 + routeListener net.Listener + routeInfo Info + routeInfoJSON []byte + quitCh chan struct{} + + // Tracking for remote QRSID tags. 
+ rqsMu sync.RWMutex + rqsubs map[string]rqsub + rqsubsTimer *time.Timer + + // Tracking Go routines + grMu sync.Mutex + grTmpClients map[uint64]*client + grRunning bool + grWG sync.WaitGroup // to wait on various go routines + + cproto int64 // number of clients supporting async INFO + configTime time.Time // last time config was loaded + + logging struct { + sync.RWMutex + logger Logger + trace int32 + debug int32 + } + + clientConnectURLs []string + + // Used internally for quick look-ups. + clientConnectURLsMap map[string]struct{} + + lastCURLsUpdate int64 + + // These store the real client/cluster listen ports. They are + // required during config reload to reset the Options (after + // reload) to the actual listen port values. + clientActualPort int + clusterActualPort int + + // Used by tests to check that http.Servers do + // not set any timeout. + monitoringServer *http.Server + profilingServer *http.Server +} + +// Make sure all are 64bits for atomic use +type stats struct { + inMsgs int64 + outMsgs int64 + inBytes int64 + outBytes int64 + slowConsumers int64 +} + +// New will setup a new server struct after parsing the options. +func New(opts *Options) *Server { + processOptions(opts) + + // Process TLS options, including whether we require client certificates. 
+ tlsReq := opts.TLSConfig != nil + verify := (tlsReq && opts.TLSConfig.ClientAuth == tls.RequireAndVerifyClientCert) + + info := Info{ + ID: genID(), + Version: VERSION, + Proto: PROTO, + GitCommit: gitCommit, + GoVersion: runtime.Version(), + Host: opts.Host, + Port: opts.Port, + AuthRequired: false, + TLSRequired: tlsReq, + TLSVerify: verify, + MaxPayload: opts.MaxPayload, + } + + now := time.Now() + s := &Server{ + configFile: opts.ConfigFile, + info: info, + sl: NewSublist(), + opts: opts, + done: make(chan bool, 1), + start: now, + configTime: now, + } + + s.mu.Lock() + defer s.mu.Unlock() + + // This is normally done in the AcceptLoop, once the + // listener has been created (possibly with random port), + // but since some tests may expect the INFO to be properly + // set after New(), let's do it now. + s.setInfoHostPortAndGenerateJSON() + + // Used internally for quick look-ups. + s.clientConnectURLsMap = make(map[string]struct{}) + + // For tracking clients + s.clients = make(map[uint64]*client) + + // For tracking closed clients. + s.closed = newClosedRingBuffer(opts.MaxClosedClients) + + // For tracking connections that are not yet registered + // in s.routes, but for which readLoop has started. + s.grTmpClients = make(map[uint64]*client) + + // For tracking routes and their remote ids + s.routes = make(map[uint64]*client) + s.remotes = make(map[string]*client) + + // Used to kick out all go routines possibly waiting on server + // to shutdown. + s.quitCh = make(chan struct{}) + + // Used to setup Authorization. 
+ s.configureAuthorization() + + // Start signal handler + s.handleSignals() + + return s +} + +func (s *Server) getOpts() *Options { + s.optsMu.RLock() + opts := s.opts + s.optsMu.RUnlock() + return opts +} + +func (s *Server) setOpts(opts *Options) { + s.optsMu.Lock() + s.opts = opts + s.optsMu.Unlock() +} + +func (s *Server) generateRouteInfoJSON() { + b, _ := json.Marshal(s.routeInfo) + pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)} + s.routeInfoJSON = bytes.Join(pcs, []byte(" ")) +} + +// PrintAndDie is exported for access in other packages. +func PrintAndDie(msg string) { + fmt.Fprintf(os.Stderr, "%s\n", msg) + os.Exit(1) +} + +// PrintServerAndExit will print our version and exit. +func PrintServerAndExit() { + fmt.Printf("nats-server version %s\n", VERSION) + os.Exit(0) +} + +// ProcessCommandLineArgs takes the command line arguments +// validating and setting flags for handling in case any +// sub command was present. +func ProcessCommandLineArgs(cmd *flag.FlagSet) (showVersion bool, showHelp bool, err error) { + if len(cmd.Args()) > 0 { + arg := cmd.Args()[0] + switch strings.ToLower(arg) { + case "version": + return true, false, nil + case "help": + return false, true, nil + default: + return false, false, fmt.Errorf("unrecognized command: %q", arg) + } + } + + return false, false, nil +} + +// Protected check on running state +func (s *Server) isRunning() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.running +} + +func (s *Server) logPid() error { + pidStr := strconv.Itoa(os.Getpid()) + return ioutil.WriteFile(s.getOpts().PidFile, []byte(pidStr), 0660) +} + +// Start up the server, this will block. +// Start via a Go routine if needed. 
+func (s *Server) Start() { + s.Noticef("Starting nats-server version %s", VERSION) + s.Debugf("Go build version %s", s.info.GoVersion) + gc := gitCommit + if gc == "" { + gc = "not set" + } + s.Noticef("Git commit [%s]", gc) + + // Avoid RACE between Start() and Shutdown() + s.mu.Lock() + s.running = true + s.mu.Unlock() + + s.grMu.Lock() + s.grRunning = true + s.grMu.Unlock() + + // Snapshot server options. + opts := s.getOpts() + + // Log the pid to a file + if opts.PidFile != _EMPTY_ { + if err := s.logPid(); err != nil { + PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err)) + } + } + + // Start monitoring if needed + if err := s.StartMonitoring(); err != nil { + s.Fatalf("Can't start monitoring: %v", err) + return + } + + // The Routing routine needs to wait for the client listen + // port to be opened and potential ephemeral port selected. + clientListenReady := make(chan struct{}) + + // Start up routing as well if needed. + if opts.Cluster.Port != 0 { + s.startGoRoutine(func() { + s.StartRouting(clientListenReady) + }) + } + + // Pprof http endpoint for the profiler. + if opts.ProfPort != 0 { + s.StartProfiler() + } + + if opts.PortsFileDir != _EMPTY_ { + s.logPorts() + } + + // Wait for clients. + s.AcceptLoop(clientListenReady) +} + +// Shutdown will shutdown the server instance by kicking out the AcceptLoop +// and closing all associated clients. +func (s *Server) Shutdown() { + s.mu.Lock() + // Prevent issues with multiple calls. 
+ if s.shutdown { + s.mu.Unlock() + return + } + + opts := s.getOpts() + + s.shutdown = true + s.running = false + s.grMu.Lock() + s.grRunning = false + s.grMu.Unlock() + + conns := make(map[uint64]*client) + + // Copy off the clients + for i, c := range s.clients { + conns[i] = c + } + // Copy off the connections that are not yet registered + // in s.routes, but for which the readLoop has started + s.grMu.Lock() + for i, c := range s.grTmpClients { + conns[i] = c + } + s.grMu.Unlock() + // Copy off the routes + for i, r := range s.routes { + r.setRouteNoReconnectOnClose() + conns[i] = r + } + + // Number of done channel responses we expect. + doneExpected := 0 + + // Kick client AcceptLoop() + if s.listener != nil { + doneExpected++ + s.listener.Close() + s.listener = nil + } + + // Kick route AcceptLoop() + if s.routeListener != nil { + doneExpected++ + s.routeListener.Close() + s.routeListener = nil + } + + // Kick HTTP monitoring if its running + if s.http != nil { + doneExpected++ + s.http.Close() + s.http = nil + } + + // Kick Profiling if its running + if s.profiler != nil { + doneExpected++ + s.profiler.Close() + } + + // Clear any remote qsub mappings + s.clearRemoteQSubs() + s.mu.Unlock() + + // Release go routines that wait on that channel + close(s.quitCh) + + // Close client and route connections + for _, c := range conns { + c.closeConnection(ServerShutdown) + } + + // Block until the accept loops exit + for doneExpected > 0 { + <-s.done + doneExpected-- + } + + // Wait for go routines to be done. + s.grWG.Wait() + + if opts.PortsFileDir != _EMPTY_ { + s.deletePortsFile(opts.PortsFileDir) + } +} + +// AcceptLoop is exported for easier testing. +func (s *Server) AcceptLoop(clr chan struct{}) { + // If we were to exit before the listener is setup properly, + // make sure we close the channel. + defer func() { + if clr != nil { + close(clr) + } + }() + + // Snapshot server options. 
+ opts := s.getOpts() + + hp := net.JoinHostPort(opts.Host, strconv.Itoa(opts.Port)) + l, e := net.Listen("tcp", hp) + if e != nil { + s.Fatalf("Error listening on port: %s, %q", hp, e) + return + } + s.Noticef("Listening for client connections on %s", + net.JoinHostPort(opts.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) + + // Alert of TLS enabled. + if opts.TLSConfig != nil { + s.Noticef("TLS required for client connections") + } + + s.Debugf("Server id is %s", s.info.ID) + s.Noticef("Server is ready") + + // Setup state that can enable shutdown + s.mu.Lock() + s.listener = l + + // If server was started with RANDOM_PORT (-1), opts.Port would be equal + // to 0 at the beginning this function. So we need to get the actual port + if opts.Port == 0 { + // Write resolved port back to options. + opts.Port = l.Addr().(*net.TCPAddr).Port + } + // Keep track of actual listen port. This will be needed in case of + // config reload. + s.clientActualPort = opts.Port + + // Now that port has been set (if it was set to RANDOM), set the + // server's info Host/Port with either values from Options or + // ClientAdvertise. Also generate the JSON byte array. + if err := s.setInfoHostPortAndGenerateJSON(); err != nil { + s.Fatalf("Error setting server INFO with ClientAdvertise value of %s, err=%v", s.opts.ClientAdvertise, err) + s.mu.Unlock() + return + } + // Keep track of client connect URLs. We may need them later. 
+ s.clientConnectURLs = s.getClientConnectURLs() + s.mu.Unlock() + + // Let the caller know that we are ready + close(clr) + clr = nil + + tmpDelay := ACCEPT_MIN_SLEEP + + for s.isRunning() { + conn, err := l.Accept() + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Temporary() { + s.Errorf("Temporary Client Accept Error (%v), sleeping %dms", + ne, tmpDelay/time.Millisecond) + time.Sleep(tmpDelay) + tmpDelay *= 2 + if tmpDelay > ACCEPT_MAX_SLEEP { + tmpDelay = ACCEPT_MAX_SLEEP + } + } else if s.isRunning() { + s.Errorf("Client Accept Error: %v", err) + } + continue + } + tmpDelay = ACCEPT_MIN_SLEEP + s.startGoRoutine(func() { + s.createClient(conn) + s.grWG.Done() + }) + } + s.Noticef("Server Exiting..") + s.done <- true +} + +// This function sets the server's info Host/Port based on server Options. +// Note that this function may be called during config reload, this is why +// Host/Port may be reset to original Options if the ClientAdvertise option +// is not set (since it may have previously been). +// The function then generates the server infoJSON. +func (s *Server) setInfoHostPortAndGenerateJSON() error { + // When this function is called, opts.Port is set to the actual listen + // port (if option was originally set to RANDOM), even during a config + // reload. So use of s.opts.Port is safe. + if s.opts.ClientAdvertise != "" { + h, p, err := parseHostPort(s.opts.ClientAdvertise, s.opts.Port) + if err != nil { + return err + } + s.info.Host = h + s.info.Port = p + } else { + s.info.Host = s.opts.Host + s.info.Port = s.opts.Port + } + return nil +} + +// StartProfiler is called to enable dynamic profiling. +func (s *Server) StartProfiler() { + // Snapshot server options. 
+ opts := s.getOpts() + + port := opts.ProfPort + + // Check for Random Port + if port == -1 { + port = 0 + } + + hp := net.JoinHostPort(opts.Host, strconv.Itoa(port)) + + l, err := net.Listen("tcp", hp) + s.Noticef("profiling port: %d", l.Addr().(*net.TCPAddr).Port) + + if err != nil { + s.Fatalf("error starting profiler: %s", err) + } + + srv := &http.Server{ + Addr: hp, + Handler: http.DefaultServeMux, + MaxHeaderBytes: 1 << 20, + } + + s.mu.Lock() + s.profiler = l + s.profilingServer = srv + s.mu.Unlock() + + go func() { + // if this errors out, it's probably because the server is being shutdown + err := srv.Serve(l) + if err != nil { + s.mu.Lock() + shutdown := s.shutdown + s.mu.Unlock() + if !shutdown { + s.Fatalf("error starting profiler: %s", err) + } + } + s.done <- true + }() +} + +// StartHTTPMonitoring will enable the HTTP monitoring port. +// DEPRECATED: Should use StartMonitoring. +func (s *Server) StartHTTPMonitoring() { + s.startMonitoring(false) +} + +// StartHTTPSMonitoring will enable the HTTPS monitoring port. +// DEPRECATED: Should use StartMonitoring. +func (s *Server) StartHTTPSMonitoring() { + s.startMonitoring(true) +} + +// StartMonitoring starts the HTTP or HTTPs server if needed. +func (s *Server) StartMonitoring() error { + // Snapshot server options. 
+ opts := s.getOpts() + + // Specifying both HTTP and HTTPS ports is a misconfiguration + if opts.HTTPPort != 0 && opts.HTTPSPort != 0 { + return fmt.Errorf("can't specify both HTTP (%v) and HTTPs (%v) ports", opts.HTTPPort, opts.HTTPSPort) + } + var err error + if opts.HTTPPort != 0 { + err = s.startMonitoring(false) + } else if opts.HTTPSPort != 0 { + if opts.TLSConfig == nil { + return fmt.Errorf("TLS cert and key required for HTTPS") + } + err = s.startMonitoring(true) + } + return err +} + +// HTTP endpoints +const ( + RootPath = "/" + VarzPath = "/varz" + ConnzPath = "/connz" + RoutezPath = "/routez" + SubszPath = "/subsz" + StackszPath = "/stacksz" +) + +// Start the monitoring server +func (s *Server) startMonitoring(secure bool) error { + // Snapshot server options. + opts := s.getOpts() + + // Used to track HTTP requests + s.httpReqStats = map[string]uint64{ + RootPath: 0, + VarzPath: 0, + ConnzPath: 0, + RoutezPath: 0, + SubszPath: 0, + } + + var ( + hp string + err error + httpListener net.Listener + port int + ) + + monitorProtocol := "http" + + if secure { + monitorProtocol += "s" + port = opts.HTTPSPort + if port == -1 { + port = 0 + } + hp = net.JoinHostPort(opts.HTTPHost, strconv.Itoa(port)) + config := util.CloneTLSConfig(opts.TLSConfig) + config.ClientAuth = tls.NoClientCert + httpListener, err = tls.Listen("tcp", hp, config) + + } else { + port = opts.HTTPPort + if port == -1 { + port = 0 + } + hp = net.JoinHostPort(opts.HTTPHost, strconv.Itoa(port)) + httpListener, err = net.Listen("tcp", hp) + } + + if err != nil { + return fmt.Errorf("can't listen to the monitor port: %v", err) + } + + s.Noticef("Starting %s monitor on %s", monitorProtocol, + net.JoinHostPort(opts.HTTPHost, strconv.Itoa(httpListener.Addr().(*net.TCPAddr).Port))) + + mux := http.NewServeMux() + + // Root + mux.HandleFunc(RootPath, s.HandleRoot) + // Varz + mux.HandleFunc(VarzPath, s.HandleVarz) + // Connz + mux.HandleFunc(ConnzPath, s.HandleConnz) + // Routez + 
mux.HandleFunc(RoutezPath, s.HandleRoutez) + // Subz + mux.HandleFunc(SubszPath, s.HandleSubsz) + // Subz alias for backwards compatibility + mux.HandleFunc("/subscriptionsz", s.HandleSubsz) + // Stacksz + mux.HandleFunc(StackszPath, s.HandleStacksz) + + // Do not set a WriteTimeout because it could cause cURL/browser + // to return empty response or unable to display page if the + // server needs more time to build the response. + srv := &http.Server{ + Addr: hp, + Handler: mux, + MaxHeaderBytes: 1 << 20, + } + s.mu.Lock() + s.http = httpListener + s.httpHandler = mux + s.monitoringServer = srv + s.mu.Unlock() + + go func() { + srv.Serve(httpListener) + srv.Handler = nil + s.mu.Lock() + s.httpHandler = nil + s.mu.Unlock() + s.done <- true + }() + + return nil +} + +// HTTPHandler returns the http.Handler object used to handle monitoring +// endpoints. It will return nil if the server is not configured for +// monitoring, or if the server has not been started yet (Server.Start()). +func (s *Server) HTTPHandler() http.Handler { + s.mu.Lock() + defer s.mu.Unlock() + return s.httpHandler +} + +// Perform a conditional deep copy due to reference nature of ClientConnectURLs. +// If updates are made to Info, this function should be consulted and updated. +// Assume lock is held. +func (s *Server) copyInfo() Info { + info := s.info + if info.ClientConnectURLs != nil { + info.ClientConnectURLs = make([]string, len(s.info.ClientConnectURLs)) + copy(info.ClientConnectURLs, s.info.ClientConnectURLs) + } + return info +} + +func (s *Server) createClient(conn net.Conn) *client { + // Snapshot server options. 
+ opts := s.getOpts() + + max_pay := int64(opts.MaxPayload) + max_subs := opts.MaxSubs + now := time.Now() + + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: max_pay, msubs: max_subs, start: now, last: now} + + // Grab JSON info string + s.mu.Lock() + info := s.copyInfo() + s.totalClients++ + s.mu.Unlock() + + // Grab lock + c.mu.Lock() + + // Initialize + c.initClient() + + c.Debugf("Client connection created") + + // Send our information. + c.sendInfo(c.generateClientInfoJSON(info)) + + // Unlock to register + c.mu.Unlock() + + // Register with the server. + s.mu.Lock() + // If server is not running, Shutdown() may have already gathered the + // list of connections to close. It won't contain this one, so we need + // to bail out now otherwise the readLoop started down there would not + // be interrupted. + if !s.running { + s.mu.Unlock() + return c + } + + // If there is a max connections specified, check that adding + // this new client would not push us over the max + if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { + s.mu.Unlock() + c.maxConnExceeded() + return nil + } + s.clients[c.cid] = c + s.mu.Unlock() + + // Re-Grab lock + c.mu.Lock() + + // Check for TLS + if info.TLSRequired { + c.Debugf("Starting TLS client connection handshake") + c.nc = tls.Server(c.nc, opts.TLSConfig) + conn := c.nc.(*tls.Conn) + + // Setup the timeout + ttl := secondsToDuration(opts.TLSTimeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + // Force handshake + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Errorf("TLS handshake error: %v", err) + c.closeConnection(TLSHandshakeError) + return nil + } + // Reset the read deadline + conn.SetReadDeadline(time.Time{}) + + // Re-Grab lock + c.mu.Lock() + + // Indicate that handshake is complete (used in monitoring) + c.flags.set(handshakeComplete) + } + + // The connection may have been closed + if c.nc == nil { + c.mu.Unlock() + return c + } + + // 
Check for Auth. We schedule this timer after the TLS handshake to avoid + // the race where the timer fires during the handshake and causes the + // server to write bad data to the socket. See issue #432. + if info.AuthRequired { + c.setAuthTimer(secondsToDuration(opts.AuthTimeout)) + } + + // Do final client initialization + + // Set the Ping timer + c.setPingTimer() + + // Spin up the read loop. + s.startGoRoutine(c.readLoop) + + // Spin up the write loop. + s.startGoRoutine(c.writeLoop) + + if info.TLSRequired { + c.Debugf("TLS handshake complete") + cs := c.nc.(*tls.Conn).ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + } + + c.mu.Unlock() + + return c +} + +// This will save off a closed client in a ring buffer such that +// /connz can inspect. Useful for debugging, etc. +func (s *Server) saveClosedClient(c *client, nc net.Conn, reason ClosedState) { + now := time.Now() + + c.mu.Lock() + + cc := &closedClient{} + cc.fill(c, nc, now) + cc.Stop = &now + cc.Reason = reason.String() + + // Do subs, do not place by default in main ConnInfo + if len(c.subs) > 0 { + cc.subs = make([]string, 0, len(c.subs)) + for _, sub := range c.subs { + cc.subs = append(cc.subs, string(sub.subject)) + } + } + // Hold user as well. + cc.user = c.opts.Username + c.mu.Unlock() + + // Place in the ring buffer + s.mu.Lock() + s.closed.append(cc) + s.mu.Unlock() +} + +// Adds the given array of urls to the server's INFO.ClientConnectURLs +// array. The server INFO JSON is regenerated. +// Note that a check is made to ensure that given URLs are not +// already present. So the INFO JSON is regenerated only if new ULRs +// were added. +// If there was a change, an INFO protocol is sent to registered clients +// that support async INFO protocols. 
+func (s *Server) addClientConnectURLsAndSendINFOToClients(urls []string) { + s.updateServerINFOAndSendINFOToClients(urls, true) +} + +// Removes the given array of urls from the server's INFO.ClientConnectURLs +// array. The server INFO JSON is regenerated if needed. +// If there was a change, an INFO protocol is sent to registered clients +// that support async INFO protocols. +func (s *Server) removeClientConnectURLsAndSendINFOToClients(urls []string) { + s.updateServerINFOAndSendINFOToClients(urls, false) +} + +// Updates the server's Info object with the given array of URLs and re-generate +// the infoJSON byte array, then send an (async) INFO protocol to clients that +// support it. +func (s *Server) updateServerINFOAndSendINFOToClients(urls []string, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + + // Will be set to true if we alter the server's Info object. + wasUpdated := false + remove := !add + for _, url := range urls { + _, present := s.clientConnectURLsMap[url] + if add && !present { + s.clientConnectURLsMap[url] = struct{}{} + wasUpdated = true + } else if remove && present { + delete(s.clientConnectURLsMap, url) + wasUpdated = true + } + } + if wasUpdated { + // Recreate the info.ClientConnectURL array from the map + s.info.ClientConnectURLs = s.info.ClientConnectURLs[:0] + // Add this server client connect ULRs first... + s.info.ClientConnectURLs = append(s.info.ClientConnectURLs, s.clientConnectURLs...) + for url := range s.clientConnectURLsMap { + s.info.ClientConnectURLs = append(s.info.ClientConnectURLs, url) + } + // Update the time of this update + s.lastCURLsUpdate = time.Now().UnixNano() + // Send to all registered clients that support async INFO protocols. + s.sendAsyncInfoToClients() + } +} + +// Handle closing down a connection when the handshake has timedout. 
+func tlsTimeout(c *client, conn *tls.Conn) { + c.mu.Lock() + nc := c.nc + c.mu.Unlock() + // Check if already closed + if nc == nil { + return + } + cs := conn.ConnectionState() + if !cs.HandshakeComplete { + c.Errorf("TLS handshake timeout") + c.sendErr("Secure Connection - TLS Required") + c.closeConnection(TLSHandshakeError) + } +} + +// Seems silly we have to write these +func tlsVersion(ver uint16) string { + switch ver { + case tls.VersionTLS10: + return "1.0" + case tls.VersionTLS11: + return "1.1" + case tls.VersionTLS12: + return "1.2" + } + return fmt.Sprintf("Unknown [%x]", ver) +} + +// We use hex here so we don't need multiple versions +func tlsCipher(cs uint16) string { + name, present := cipherMapByID[cs] + if present { + return name + } + return fmt.Sprintf("Unknown [%x]", cs) +} + +// Remove a client or route from our internal accounting. +func (s *Server) removeClient(c *client) { + var rID string + c.mu.Lock() + cid := c.cid + typ := c.typ + r := c.route + if r != nil { + rID = r.remoteID + } + updateProtoInfoCount := false + if typ == CLIENT && c.opts.Protocol >= ClientProtoInfo { + updateProtoInfoCount = true + } + c.mu.Unlock() + + s.mu.Lock() + switch typ { + case CLIENT: + delete(s.clients, cid) + if updateProtoInfoCount { + s.cproto-- + } + case ROUTER: + delete(s.routes, cid) + if r != nil { + rc, ok := s.remotes[rID] + // Only delete it if it is us.. + if ok && c == rc { + delete(s.remotes, rID) + } + } + // Remove from temporary map in case it is there. + s.grMu.Lock() + delete(s.grTmpClients, cid) + s.grMu.Unlock() + } + s.mu.Unlock() +} + +///////////////////////////////////////////////////////////////// +// These are some helpers for accounting in functional tests. +///////////////////////////////////////////////////////////////// + +// NumRoutes will report the number of registered routes. 
+func (s *Server) NumRoutes() int { + s.mu.Lock() + nr := len(s.routes) + s.mu.Unlock() + return nr +} + +// NumRemotes will report number of registered remotes. +func (s *Server) NumRemotes() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.remotes) +} + +// NumClients will report the number of registered clients. +func (s *Server) NumClients() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.clients) +} + +// getClient will return the client associated with cid. +func (s *Server) getClient(cid uint64) *client { + s.mu.Lock() + defer s.mu.Unlock() + return s.clients[cid] +} + +// NumSubscriptions will report how many subscriptions are active. +func (s *Server) NumSubscriptions() uint32 { + s.mu.Lock() + subs := s.sl.Count() + s.mu.Unlock() + return subs +} + +// NumSlowConsumers will report the number of slow consumers. +func (s *Server) NumSlowConsumers() int64 { + return atomic.LoadInt64(&s.slowConsumers) +} + +// ConfigTime will report the last time the server configuration was loaded. +func (s *Server) ConfigTime() time.Time { + s.mu.Lock() + defer s.mu.Unlock() + return s.configTime +} + +// Addr will return the net.Addr object for the current listener. +func (s *Server) Addr() net.Addr { + s.mu.Lock() + defer s.mu.Unlock() + if s.listener == nil { + return nil + } + return s.listener.Addr() +} + +// MonitorAddr will return the net.Addr object for the monitoring listener. +func (s *Server) MonitorAddr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.http == nil { + return nil + } + return s.http.Addr().(*net.TCPAddr) +} + +// ClusterAddr returns the net.Addr object for the route listener. +func (s *Server) ClusterAddr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.routeListener == nil { + return nil + } + return s.routeListener.Addr().(*net.TCPAddr) +} + +// ProfilerAddr returns the net.Addr object for the route listener. 
+func (s *Server) ProfilerAddr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.profiler == nil { + return nil + } + return s.profiler.Addr().(*net.TCPAddr) +} + +// ReadyForConnections returns `true` if the server is ready to accept client +// and, if routing is enabled, route connections. If after the duration +// `dur` the server is still not ready, returns `false`. +func (s *Server) ReadyForConnections(dur time.Duration) bool { + // Snapshot server options. + opts := s.getOpts() + + end := time.Now().Add(dur) + for time.Now().Before(end) { + s.mu.Lock() + ok := s.listener != nil && (opts.Cluster.Port == 0 || s.routeListener != nil) + s.mu.Unlock() + if ok { + return true + } + time.Sleep(25 * time.Millisecond) + } + return false +} + +// ID returns the server's ID +func (s *Server) ID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.info.ID +} + +func (s *Server) startGoRoutine(f func()) { + s.grMu.Lock() + if s.grRunning { + s.grWG.Add(1) + go f() + } + s.grMu.Unlock() +} + +func (s *Server) numClosedConns() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed.len() +} + +func (s *Server) totalClosedConns() uint64 { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed.totalConns() +} + +func (s *Server) closedClients() []*closedClient { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed.closedClients() +} + +// getClientConnectURLs returns suitable URLs for clients to connect to the listen +// port based on the server options' Host and Port. If the Host corresponds to +// "any" interfaces, this call returns the list of resolved IP addresses. +// If ClientAdvertise is set, returns the client advertise host and port. +// The server lock is assumed held on entry. +func (s *Server) getClientConnectURLs() []string { + // Snapshot server options. + opts := s.getOpts() + + urls := make([]string, 0, 1) + + // short circuit if client advertise is set + if opts.ClientAdvertise != "" { + // just use the info host/port. 
This is updated in s.New() + urls = append(urls, net.JoinHostPort(s.info.Host, strconv.Itoa(s.info.Port))) + } else { + sPort := strconv.Itoa(opts.Port) + ipAddr, err := net.ResolveIPAddr("ip", opts.Host) + // If the host is "any" (0.0.0.0 or ::), get specific IPs from available + // interfaces. + if err == nil && ipAddr.IP.IsUnspecified() { + var ip net.IP + ifaces, _ := net.Interfaces() + for _, i := range ifaces { + addrs, _ := i.Addrs() + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + // Skip non global unicast addresses + if !ip.IsGlobalUnicast() || ip.IsUnspecified() { + ip = nil + continue + } + urls = append(urls, net.JoinHostPort(ip.String(), sPort)) + } + } + } + if err != nil || len(urls) == 0 { + // We are here if s.opts.Host is not "0.0.0.0" nor "::", or if for some + // reason we could not add any URL in the loop above. + // We had a case where a Windows VM was hosed and would have err == nil + // and not add any address in the array in the loop above, and we + // ended-up returning 0.0.0.0, which is problematic for Windows clients. + // Check for 0.0.0.0 or :: specifically, and ignore if that's the case. 
+ if opts.Host == "0.0.0.0" || opts.Host == "::" { + s.Errorf("Address %q can not be resolved properly", opts.Host) + } else { + urls = append(urls, net.JoinHostPort(opts.Host, sPort)) + } + } + } + + return urls +} + +// if the ip is not specified, attempt to resolve it +func resolveHostPorts(addr net.Listener) []string { + hostPorts := make([]string, 0) + hp := addr.Addr().(*net.TCPAddr) + port := strconv.Itoa(hp.Port) + if hp.IP.IsUnspecified() { + var ip net.IP + ifaces, _ := net.Interfaces() + for _, i := range ifaces { + addrs, _ := i.Addrs() + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + hostPorts = append(hostPorts, net.JoinHostPort(ip.String(), port)) + case *net.IPAddr: + ip = v.IP + hostPorts = append(hostPorts, net.JoinHostPort(ip.String(), port)) + default: + continue + } + } + } + } else { + hostPorts = append(hostPorts, net.JoinHostPort(hp.IP.String(), port)) + } + return hostPorts +} + +// format the address of a net.Listener with a protocol +func formatURL(protocol string, addr net.Listener) []string { + hostports := resolveHostPorts(addr) + for i, hp := range hostports { + hostports[i] = fmt.Sprintf("%s://%s", protocol, hp) + } + return hostports +} + +// Ports describes URLs that the server can be contacted in +type Ports struct { + Nats []string `json:"nats,omitempty"` + Monitoring []string `json:"monitoring,omitempty"` + Cluster []string `json:"cluster,omitempty"` + Profile []string `json:"profile,omitempty"` +} + +// Attempts to resolve all the ports. If after maxWait the ports are not +// resolved, it returns nil. 
Otherwise it returns a Ports struct +// describing ports where the server can be contacted +func (s *Server) PortsInfo(maxWait time.Duration) *Ports { + if s.readyForListeners(maxWait) { + opts := s.getOpts() + + s.mu.Lock() + info := s.copyInfo() + listener := s.listener + httpListener := s.http + clusterListener := s.routeListener + profileListener := s.profiler + s.mu.Unlock() + + ports := Ports{} + + if listener != nil { + natsProto := "nats" + if info.TLSRequired { + natsProto = "tls" + } + ports.Nats = formatURL(natsProto, listener) + } + + if httpListener != nil { + monProto := "http" + if opts.HTTPSPort != 0 { + monProto = "https" + } + ports.Monitoring = formatURL(monProto, httpListener) + } + + if clusterListener != nil { + clusterProto := "nats" + if opts.Cluster.TLSConfig != nil { + clusterProto = "tls" + } + ports.Cluster = formatURL(clusterProto, clusterListener) + } + + if profileListener != nil { + ports.Profile = formatURL("http", profileListener) + } + + return &ports + } + + return nil +} + +// Returns the portsFile. If a non-empty dirHint is provided, the dirHint +// path is used instead of the server option value +func (s *Server) portFile(dirHint string) string { + dirname := s.getOpts().PortsFileDir + if dirHint != "" { + dirname = dirHint + } + if dirname == _EMPTY_ { + return _EMPTY_ + } + return path.Join(dirname, fmt.Sprintf("%s_%d.ports", path.Base(os.Args[0]), os.Getpid())) +} + +// Delete the ports file. If a non-empty dirHint is provided, the dirHint +// path is used instead of the server option value +func (s *Server) deletePortsFile(hintDir string) { + portsFile := s.portFile(hintDir) + if portsFile != "" { + if err := os.Remove(portsFile); err != nil { + s.Errorf("Error cleaning up ports file %s: %v", portsFile, err) + } + } +} + +// Writes a file with a serialized Ports to the specified ports_file_dir. +// The name of the file is `exename_pid.ports`, typically gnatsd_pid.ports. 
+// if ports file is not set, this function has no effect +func (s *Server) logPorts() { + opts := s.getOpts() + portsFile := s.portFile(opts.PortsFileDir) + if portsFile != _EMPTY_ { + go func() { + info := s.PortsInfo(5 * time.Second) + if info == nil { + s.Errorf("Unable to resolve the ports in the specified time") + return + } + data, err := json.Marshal(info) + if err != nil { + s.Errorf("Error marshaling ports file: %v", err) + return + } + if err := ioutil.WriteFile(portsFile, data, 0666); err != nil { + s.Errorf("Error writing ports file (%s): %v", portsFile, err) + return + } + + }() + } +} + +// waits until a calculated list of listeners is resolved or a timeout +func (s *Server) readyForListeners(dur time.Duration) bool { + end := time.Now().Add(dur) + for time.Now().Before(end) { + s.mu.Lock() + listeners := s.serviceListeners() + s.mu.Unlock() + if len(listeners) == 0 { + return false + } + + ok := true + for _, l := range listeners { + if l == nil { + ok = false + break + } + } + if ok { + return true + } + select { + case <-s.quitCh: + return false + case <-time.After(25 * time.Millisecond): + // continue - unable to select from quit - we are still running + } + } + return false +} + +// returns a list of listeners that are intended for the process +// if the entry is nil, the interface is yet to be resolved +func (s *Server) serviceListeners() []net.Listener { + listeners := make([]net.Listener, 0) + opts := s.getOpts() + listeners = append(listeners, s.listener) + if opts.Cluster.Port != 0 { + listeners = append(listeners, s.routeListener) + } + if opts.HTTPPort != 0 || opts.HTTPSPort != 0 { + listeners = append(listeners, s.http) + } + if opts.ProfPort != 0 { + listeners = append(listeners, s.profiler) + } + return listeners +} diff --git a/vendor/github.com/nats-io/gnatsd/server/service.go b/vendor/github.com/nats-io/gnatsd/server/service.go new file mode 100644 index 00000000000..a44cbac3348 --- /dev/null +++ 
b/vendor/github.com/nats-io/gnatsd/server/service.go @@ -0,0 +1,28 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +// Run starts the NATS server. This wrapper function allows Windows to add a +// hook for running NATS as a service. +func Run(server *Server) error { + server.Start() + return nil +} + +// isWindowsService indicates if NATS is running as a Windows service. +func isWindowsService() bool { + return false +} diff --git a/vendor/github.com/nats-io/gnatsd/server/service_windows.go b/vendor/github.com/nats-io/gnatsd/server/service_windows.go new file mode 100644 index 00000000000..0b9fa949684 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/service_windows.go @@ -0,0 +1,127 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "os" + "time" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" +) + +const ( + reopenLogCode = 128 + reopenLogCmd = svc.Cmd(reopenLogCode) + acceptReopenLog = svc.Accepted(reopenLogCode) +) + +var serviceName = "gnatsd" + +// SetServiceName allows setting a different service name +func SetServiceName(name string) { + serviceName = name +} + +// winServiceWrapper implements the svc.Handler interface for implementing +// gnatsd as a Windows service. +type winServiceWrapper struct { + server *Server +} + +var dockerized = false + +func init() { + if v, exists := os.LookupEnv("NATS_DOCKERIZED"); exists && v == "1" { + dockerized = true + } +} + +// Execute will be called by the package code at the start of +// the service, and the service will exit once Execute completes. +// Inside Execute you must read service change requests from r and +// act accordingly. You must keep service control manager up to date +// about state of your service by writing into s as required. +// args contains service name followed by argument strings passed +// to the service. +// You can provide service exit code in exitCode return parameter, +// with 0 being "no error". You can also indicate if exit code, +// if any, is service specific or not by using svcSpecificEC +// parameter. +func (w *winServiceWrapper) Execute(args []string, changes <-chan svc.ChangeRequest, + status chan<- svc.Status) (bool, uint32) { + + status <- svc.Status{State: svc.StartPending} + go w.server.Start() + + // Wait for accept loop(s) to be started + if !w.server.ReadyForConnections(10 * time.Second) { + // Failed to start. 
+ return false, 1 + } + + status <- svc.Status{ + State: svc.Running, + Accepts: svc.AcceptStop | svc.AcceptShutdown | svc.AcceptParamChange | acceptReopenLog, + } + +loop: + for change := range changes { + switch change.Cmd { + case svc.Interrogate: + status <- change.CurrentStatus + case svc.Stop, svc.Shutdown: + w.server.Shutdown() + break loop + case reopenLogCmd: + // File log re-open for rotating file logs. + w.server.ReOpenLogFile() + case svc.ParamChange: + if err := w.server.Reload(); err != nil { + w.server.Errorf("Failed to reload server configuration: %s", err) + } + default: + w.server.Debugf("Unexpected control request: %v", change.Cmd) + } + } + + status <- svc.Status{State: svc.StopPending} + return false, 0 +} + +// Run starts the NATS server as a Windows service. +func Run(server *Server) error { + if dockerized { + server.Start() + return nil + } + run := svc.Run + isInteractive, err := svc.IsAnInteractiveSession() + if err != nil { + return err + } + if isInteractive { + run = debug.Run + } + return run(serviceName, &winServiceWrapper{server}) +} + +// isWindowsService indicates if NATS is running as a Windows service. +func isWindowsService() bool { + if dockerized { + return false + } + isInteractive, _ := svc.IsAnInteractiveSession() + return !isInteractive +} diff --git a/vendor/github.com/nats-io/gnatsd/server/signal.go b/vendor/github.com/nats-io/gnatsd/server/signal.go new file mode 100644 index 00000000000..8402214905c --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/signal.go @@ -0,0 +1,163 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/signal" + "strconv" + "strings" + "syscall" +) + +var processName = "gnatsd" + +// SetProcessName allows to change the expected name of the process. +func SetProcessName(name string) { + processName = name +} + +// Signal Handling +func (s *Server) handleSignals() { + if s.getOpts().NoSigs { + return + } + c := make(chan os.Signal, 1) + + signal.Notify(c, syscall.SIGINT, syscall.SIGUSR1, syscall.SIGHUP) + + s.grWG.Add(1) + go func() { + defer s.grWG.Done() + for { + select { + case sig := <-c: + s.Debugf("Trapped %q signal", sig) + switch sig { + case syscall.SIGINT: + s.Noticef("Server Exiting..") + os.Exit(0) + case syscall.SIGUSR1: + // File log re-open for rotating file logs. + s.ReOpenLogFile() + case syscall.SIGHUP: + // Config reload. + if err := s.Reload(); err != nil { + s.Errorf("Failed to reload server configuration: %s", err) + } + } + case <-s.quitCh: + return + } + } + }() +} + +// ProcessSignal sends the given signal command to the given process. If pidStr +// is empty, this will send the signal to the single running instance of +// gnatsd. If multiple instances are running, it returns an error. This returns +// an error if the given process is not running or the command is invalid. 
+func ProcessSignal(command Command, pidStr string) error { + var pid int + if pidStr == "" { + pids, err := resolvePids() + if err != nil { + return err + } + if len(pids) == 0 { + return fmt.Errorf("no %s processes running", processName) + } + if len(pids) > 1 { + errStr := fmt.Sprintf("multiple %s processes running:\n", processName) + prefix := "" + for _, p := range pids { + errStr += fmt.Sprintf("%s%d", prefix, p) + prefix = "\n" + } + return errors.New(errStr) + } + pid = pids[0] + } else { + p, err := strconv.Atoi(pidStr) + if err != nil { + return fmt.Errorf("invalid pid: %s", pidStr) + } + pid = p + } + + var err error + switch command { + case CommandStop: + err = kill(pid, syscall.SIGKILL) + case CommandQuit: + err = kill(pid, syscall.SIGINT) + case CommandReopen: + err = kill(pid, syscall.SIGUSR1) + case CommandReload: + err = kill(pid, syscall.SIGHUP) + default: + err = fmt.Errorf("unknown signal %q", command) + } + return err +} + +// resolvePids returns the pids for all running gnatsd processes. +func resolvePids() ([]int, error) { + // If pgrep isn't available, this will just bail out and the user will be + // required to specify a pid. + output, err := pgrep() + if err != nil { + switch err.(type) { + case *exec.ExitError: + // ExitError indicates non-zero exit code, meaning no processes + // found. + break + default: + return nil, errors.New("unable to resolve pid, try providing one") + } + } + var ( + myPid = os.Getpid() + pidStrs = strings.Split(string(output), "\n") + pids = make([]int, 0, len(pidStrs)) + ) + for _, pidStr := range pidStrs { + if pidStr == "" { + continue + } + pid, err := strconv.Atoi(pidStr) + if err != nil { + return nil, errors.New("unable to resolve pid, try providing one") + } + // Ignore the current process. 
+ if pid == myPid { + continue + } + pids = append(pids, pid) + } + return pids, nil +} + +var kill = func(pid int, signal syscall.Signal) error { + return syscall.Kill(pid, signal) +} + +var pgrep = func() ([]byte, error) { + return exec.Command("pgrep", processName).Output() +} diff --git a/vendor/github.com/nats-io/gnatsd/server/signal_windows.go b/vendor/github.com/nats-io/gnatsd/server/signal_windows.go new file mode 100644 index 00000000000..368077dd59a --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/signal_windows.go @@ -0,0 +1,101 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "os" + "os/signal" + "time" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" +) + +// Signal Handling +func (s *Server) handleSignals() { + if s.getOpts().NoSigs { + return + } + c := make(chan os.Signal, 1) + + signal.Notify(c, os.Interrupt) + + go func() { + for sig := range c { + s.Debugf("Trapped %q signal", sig) + s.Noticef("Server Exiting..") + os.Exit(0) + } + }() +} + +// ProcessSignal sends the given signal command to the running gnatsd service. +// If service is empty, this signals the "gnatsd" service. This returns an +// error is the given service is not running or the command is invalid. 
+func ProcessSignal(command Command, service string) error { + if service == "" { + service = serviceName + } + + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + + s, err := m.OpenService(service) + if err != nil { + return fmt.Errorf("could not access service: %v", err) + } + defer s.Close() + + var ( + cmd svc.Cmd + to svc.State + ) + + switch command { + case CommandStop, CommandQuit: + cmd = svc.Stop + to = svc.Stopped + case CommandReopen: + cmd = reopenLogCmd + to = svc.Running + case CommandReload: + cmd = svc.ParamChange + to = svc.Running + default: + return fmt.Errorf("unknown signal %q", command) + } + + status, err := s.Control(cmd) + if err != nil { + return fmt.Errorf("could not send control=%d: %v", cmd, err) + } + + timeout := time.Now().Add(10 * time.Second) + for status.State != to { + if timeout.Before(time.Now()) { + return fmt.Errorf("timeout waiting for service to go to state=%d", to) + } + time.Sleep(300 * time.Millisecond) + status, err = s.Query() + if err != nil { + return fmt.Errorf("could not retrieve service status: %v", err) + } + } + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/sublist.go b/vendor/github.com/nats-io/gnatsd/server/sublist.go new file mode 100644 index 00000000000..6b4ef47cf08 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/sublist.go @@ -0,0 +1,836 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package sublist is a routing mechanism to handle subject distribution +// and provides a facility to match subjects from published messages to +// interested subscribers. Subscribers can have wildcard subjects to match +// multiple published subjects. +package server + +import ( + "bytes" + "errors" + "strings" + "sync" + "sync/atomic" +) + +// Common byte variables for wildcards and token separator. +const ( + pwc = '*' + fwc = '>' + tsep = "." + btsep = '.' +) + +// Sublist related errors +var ( + ErrInvalidSubject = errors.New("sublist: Invalid Subject") + ErrNotFound = errors.New("sublist: No Matches Found") +) + +const ( + // cacheMax is used to bound limit the frontend cache + slCacheMax = 1024 + // If we run a sweeper we will drain to this count. + slCacheSweep = 512 + // plistMin is our lower bounds to create a fast plist for Match. + plistMin = 256 +) + +// A result structure better optimized for queue subs. +type SublistResult struct { + psubs []*subscription + qsubs [][]*subscription // don't make this a map, too expensive to iterate +} + +// A Sublist stores and efficiently retrieves subscriptions. +type Sublist struct { + sync.RWMutex + genid uint64 + matches uint64 + cacheHits uint64 + inserts uint64 + removes uint64 + root *level + cache sync.Map + cacheNum int32 + ccSweep int32 + count uint32 +} + +// A node contains subscriptions and a pointer to the next level. +type node struct { + next *level + psubs map[*subscription]*subscription + qsubs map[string](map[*subscription]*subscription) + plist []*subscription +} + +// A level represents a group of nodes and special pointers to +// wildcard nodes. +type level struct { + nodes map[string]*node + pwc, fwc *node +} + +// Create a new default node. +func newNode() *node { + return &node{psubs: make(map[*subscription]*subscription)} +} + +// Create a new default level. 
+func newLevel() *level { + return &level{nodes: make(map[string]*node)} +} + +// New will create a default sublist +func NewSublist() *Sublist { + return &Sublist{root: newLevel()} +} + +// Insert adds a subscription into the sublist +func (s *Sublist) Insert(sub *subscription) error { + // copy the subject since we hold this and this might be part of a large byte slice. + subject := string(sub.subject) + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + s.Lock() + + sfwc := false + l := s.root + var n *node + + for _, t := range tokens { + lt := len(t) + if lt == 0 || sfwc { + s.Unlock() + return ErrInvalidSubject + } + + if lt > 1 { + n = l.nodes[t] + } else { + switch t[0] { + case pwc: + n = l.pwc + case fwc: + n = l.fwc + sfwc = true + default: + n = l.nodes[t] + } + } + if n == nil { + n = newNode() + if lt > 1 { + l.nodes[t] = n + } else { + switch t[0] { + case pwc: + l.pwc = n + case fwc: + l.fwc = n + default: + l.nodes[t] = n + } + } + } + if n.next == nil { + n.next = newLevel() + } + l = n.next + } + if sub.queue == nil { + n.psubs[sub] = sub + if n.plist != nil { + n.plist = append(n.plist, sub) + } else if len(n.psubs) > plistMin { + n.plist = make([]*subscription, 0, len(n.psubs)) + // Populate + for _, psub := range n.psubs { + n.plist = append(n.plist, psub) + } + } + } else { + if n.qsubs == nil { + n.qsubs = make(map[string]map[*subscription]*subscription) + } + qname := string(sub.queue) + // This is a queue subscription + subs, ok := n.qsubs[qname] + if !ok { + subs = make(map[*subscription]*subscription) + n.qsubs[qname] = subs + } + subs[sub] = sub + } + + s.count++ + s.inserts++ + + s.addToCache(subject, sub) + atomic.AddUint64(&s.genid, 1) + + s.Unlock() + return nil +} + +// Deep copy +func copyResult(r *SublistResult) *SublistResult { + nr := 
&SublistResult{} + nr.psubs = append([]*subscription(nil), r.psubs...) + for _, qr := range r.qsubs { + nqr := append([]*subscription(nil), qr...) + nr.qsubs = append(nr.qsubs, nqr) + } + return nr +} + +// Adds a new sub to an existing result. +func (r *SublistResult) addSubToResult(sub *subscription) *SublistResult { + // Copy since others may have a reference. + nr := copyResult(r) + if sub.queue == nil { + nr.psubs = append(nr.psubs, sub) + } else { + if i := findQSliceForSub(sub, nr.qsubs); i >= 0 { + nr.qsubs[i] = append(nr.qsubs[i], sub) + } else { + nr.qsubs = append(nr.qsubs, []*subscription{sub}) + } + } + return nr +} + +// addToCache will add the new entry to the existing cache +// entries if needed. Assumes write lock is held. +func (s *Sublist) addToCache(subject string, sub *subscription) { + // If literal we can direct match. + if subjectIsLiteral(subject) { + if v, ok := s.cache.Load(subject); ok { + r := v.(*SublistResult) + s.cache.Store(subject, r.addSubToResult(sub)) + } + return + } + s.cache.Range(func(k, v interface{}) bool { + key := k.(string) + r := v.(*SublistResult) + if matchLiteral(key, subject) { + s.cache.Store(key, r.addSubToResult(sub)) + } + return true + }) +} + +// removeFromCache will remove the sub from any active cache entries. +// Assumes write lock is held. +func (s *Sublist) removeFromCache(subject string, sub *subscription) { + // If literal we can direct match. + if subjectIsLiteral(subject) { + // Load for accounting + if _, ok := s.cache.Load(subject); ok { + s.cache.Delete(subject) + atomic.AddInt32(&s.cacheNum, -1) + } + return + } + s.cache.Range(func(k, v interface{}) bool { + key := k.(string) + if matchLiteral(key, subject) { + // Since someone else may be referecing, can't modify the list + // safely, just let it re-populate. + s.cache.Delete(key) + atomic.AddInt32(&s.cacheNum, -1) + } + return true + }) +} + +// Match will match all entries to the literal subject. 
+// It will return a set of results for both normal and queue subscribers. +func (s *Sublist) Match(subject string) *SublistResult { + atomic.AddUint64(&s.matches, 1) + + // Check cache first. + if r, ok := s.cache.Load(subject); ok { + atomic.AddUint64(&s.cacheHits, 1) + return r.(*SublistResult) + } + + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + // FIXME(dlc) - Make shared pool between sublist and client readLoop? + result := &SublistResult{} + + // Get result from the main structure and place into the shared cache. + // Hold the read lock to avoid race between match and store. + s.RLock() + matchLevel(s.root, tokens, result) + s.cache.Store(subject, result) + n := atomic.AddInt32(&s.cacheNum, 1) + s.RUnlock() + + // Reduce the cache count if we have exceeded our set maximum. + if n > slCacheMax && atomic.CompareAndSwapInt32(&s.ccSweep, 0, 1) { + go s.reduceCacheCount() + } + + return result +} + +// Remove entries in the cache until we are under the maximum. +// TODO(dlc) this could be smarter now that its not inline. +func (s *Sublist) reduceCacheCount() { + defer atomic.StoreInt32(&s.ccSweep, 0) + // If we are over the cache limit randomly drop until under the limit. + s.cache.Range(func(k, v interface{}) bool { + s.cache.Delete(k.(string)) + n := atomic.AddInt32(&s.cacheNum, -1) + if n < slCacheSweep { + return false + } + return true + }) +} + +// This will add in a node's results to the total results. +func addNodeToResults(n *node, results *SublistResult) { + // Normal subscriptions + if n.plist != nil { + results.psubs = append(results.psubs, n.plist...) 
+ } else { + for _, psub := range n.psubs { + results.psubs = append(results.psubs, psub) + } + } + // Queue subscriptions + for qname, qr := range n.qsubs { + if len(qr) == 0 { + continue + } + tsub := &subscription{subject: nil, queue: []byte(qname)} + // Need to find matching list in results + if i := findQSliceForSub(tsub, results.qsubs); i >= 0 { + for _, sub := range qr { + results.qsubs[i] = append(results.qsubs[i], sub) + } + } else { + var nqsub []*subscription + for _, sub := range qr { + nqsub = append(nqsub, sub) + } + results.qsubs = append(results.qsubs, nqsub) + } + } +} + +// We do not use a map here since we want iteration to be past when +// processing publishes in L1 on client. So we need to walk sequentially +// for now. Keep an eye on this in case we start getting large number of +// different queue subscribers for the same subject. +func findQSliceForSub(sub *subscription, qsl [][]*subscription) int { + if sub.queue == nil { + return -1 + } + for i, qr := range qsl { + if len(qr) > 0 && bytes.Equal(sub.queue, qr[0].queue) { + return i + } + } + return -1 +} + +// matchLevel is used to recursively descend into the trie. +func matchLevel(l *level, toks []string, results *SublistResult) { + var pwc, n *node + for i, t := range toks { + if l == nil { + return + } + if l.fwc != nil { + addNodeToResults(l.fwc, results) + } + if pwc = l.pwc; pwc != nil { + matchLevel(pwc.next, toks[i+1:], results) + } + n = l.nodes[t] + if n != nil { + l = n.next + } else { + l = nil + } + } + if n != nil { + addNodeToResults(n, results) + } + if pwc != nil { + addNodeToResults(pwc, results) + } +} + +// lnt is used to track descent into levels for a removal for pruning. +type lnt struct { + l *level + n *node + t string +} + +// Raw low level remove, can do batches with lock held outside. 
+func (s *Sublist) remove(sub *subscription, shouldLock bool) error { + subject := string(sub.subject) + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + if shouldLock { + s.Lock() + defer s.Unlock() + } + + sfwc := false + l := s.root + var n *node + + // Track levels for pruning + var lnts [32]lnt + levels := lnts[:0] + + for _, t := range tokens { + lt := len(t) + if lt == 0 || sfwc { + return ErrInvalidSubject + } + if l == nil { + return ErrNotFound + } + if lt > 1 { + n = l.nodes[t] + } else { + switch t[0] { + case pwc: + n = l.pwc + case fwc: + n = l.fwc + sfwc = true + default: + n = l.nodes[t] + } + } + if n != nil { + levels = append(levels, lnt{l, n, t}) + l = n.next + } else { + l = nil + } + } + if !s.removeFromNode(n, sub) { + return ErrNotFound + } + + s.count-- + s.removes++ + + for i := len(levels) - 1; i >= 0; i-- { + l, n, t := levels[i].l, levels[i].n, levels[i].t + if n.isEmpty() { + l.pruneNode(n, t) + } + } + s.removeFromCache(subject, sub) + atomic.AddUint64(&s.genid, 1) + + return nil +} + +// Remove will remove a subscription. +func (s *Sublist) Remove(sub *subscription) error { + return s.remove(sub, true) +} + +// RemoveBatch will remove a list of subscriptions. +func (s *Sublist) RemoveBatch(subs []*subscription) error { + s.Lock() + defer s.Unlock() + + for _, sub := range subs { + if err := s.remove(sub, false); err != nil { + return err + } + } + return nil +} + +// pruneNode is used to prune an empty node from the tree. +func (l *level) pruneNode(n *node, t string) { + if n == nil { + return + } + if n == l.fwc { + l.fwc = nil + } else if n == l.pwc { + l.pwc = nil + } else { + delete(l.nodes, t) + } +} + +// isEmpty will test if the node has any entries. Used +// in pruning. 
+func (n *node) isEmpty() bool { + if len(n.psubs) == 0 && len(n.qsubs) == 0 { + if n.next == nil || n.next.numNodes() == 0 { + return true + } + } + return false +} + +// Return the number of nodes for the given level. +func (l *level) numNodes() int { + num := len(l.nodes) + if l.pwc != nil { + num++ + } + if l.fwc != nil { + num++ + } + return num +} + +// Remove the sub for the given node. +func (s *Sublist) removeFromNode(n *node, sub *subscription) (found bool) { + if n == nil { + return false + } + if sub.queue == nil { + _, found = n.psubs[sub] + delete(n.psubs, sub) + if found && n.plist != nil { + // This will brute force remove the plist to perform + // correct behavior. Will get repopulated on a call + // to Match as needed. + n.plist = nil + } + return found + } + + // We have a queue group subscription here + qname := string(sub.queue) + qsub := n.qsubs[qname] + _, found = qsub[sub] + delete(qsub, sub) + if len(qsub) == 0 { + delete(n.qsubs, qname) + } + return found +} + +// Count returns the number of subscriptions. +func (s *Sublist) Count() uint32 { + s.RLock() + defer s.RUnlock() + return s.count +} + +// CacheCount returns the number of result sets in the cache. +func (s *Sublist) CacheCount() int { + return int(atomic.LoadInt32(&s.cacheNum)) +} + +// Public stats for the sublist +type SublistStats struct { + NumSubs uint32 `json:"num_subscriptions"` + NumCache uint32 `json:"num_cache"` + NumInserts uint64 `json:"num_inserts"` + NumRemoves uint64 `json:"num_removes"` + NumMatches uint64 `json:"num_matches"` + CacheHitRate float64 `json:"cache_hit_rate"` + MaxFanout uint32 `json:"max_fanout"` + AvgFanout float64 `json:"avg_fanout"` +} + +// Stats will return a stats structure for the current state. 
+func (s *Sublist) Stats() *SublistStats { + s.Lock() + defer s.Unlock() + + st := &SublistStats{} + st.NumSubs = s.count + st.NumCache = uint32(atomic.LoadInt32(&s.cacheNum)) + st.NumInserts = s.inserts + st.NumRemoves = s.removes + st.NumMatches = atomic.LoadUint64(&s.matches) + if st.NumMatches > 0 { + st.CacheHitRate = float64(atomic.LoadUint64(&s.cacheHits)) / float64(st.NumMatches) + } + + // whip through cache for fanout stats, this can be off if cache is full and doing evictions. + tot, max := 0, 0 + clen := 0 + s.cache.Range(func(k, v interface{}) bool { + clen += 1 + r := v.(*SublistResult) + l := len(r.psubs) + len(r.qsubs) + tot += l + if l > max { + max = l + } + return true + }) + st.MaxFanout = uint32(max) + if tot > 0 { + st.AvgFanout = float64(tot) / float64(clen) + } + return st +} + +// numLevels will return the maximum number of levels +// contained in the Sublist tree. +func (s *Sublist) numLevels() int { + return visitLevel(s.root, 0) +} + +// visitLevel is used to descend the Sublist tree structure +// recursively. +func visitLevel(l *level, depth int) int { + if l == nil || l.numNodes() == 0 { + return depth + } + + depth++ + maxDepth := depth + + for _, n := range l.nodes { + if n == nil { + continue + } + newDepth := visitLevel(n.next, depth) + if newDepth > maxDepth { + maxDepth = newDepth + } + } + if l.pwc != nil { + pwcDepth := visitLevel(l.pwc.next, depth) + if pwcDepth > maxDepth { + maxDepth = pwcDepth + } + } + if l.fwc != nil { + fwcDepth := visitLevel(l.fwc.next, depth) + if fwcDepth > maxDepth { + maxDepth = fwcDepth + } + } + return maxDepth +} + +// Determine if the subject has any wildcards. Fast version, does not check for +// valid subject. Used in caching layer. 
+func subjectIsLiteral(subject string) bool { + for i, c := range subject { + if c == pwc || c == fwc { + if (i == 0 || subject[i-1] == btsep) && + (i+1 == len(subject) || subject[i+1] == btsep) { + return false + } + } + } + return true +} + +// IsValidSubject returns true if a subject is valid, false otherwise +func IsValidSubject(subject string) bool { + if subject == "" { + return false + } + sfwc := false + tokens := strings.Split(subject, tsep) + for _, t := range tokens { + if len(t) == 0 || sfwc { + return false + } + if len(t) > 1 { + continue + } + switch t[0] { + case fwc: + sfwc = true + } + } + return true +} + +// IsValidLiteralSubject returns true if a subject is valid and literal (no wildcards), false otherwise +func IsValidLiteralSubject(subject string) bool { + tokens := strings.Split(subject, tsep) + for _, t := range tokens { + if len(t) == 0 { + return false + } + if len(t) > 1 { + continue + } + switch t[0] { + case pwc, fwc: + return false + } + } + return true +} + +// matchLiteral is used to test literal subjects, those that do not have any +// wildcards, with a target subject. This is used in the cache layer. +func matchLiteral(literal, subject string) bool { + li := 0 + ll := len(literal) + ls := len(subject) + for i := 0; i < ls; i++ { + if li >= ll { + return false + } + // This function has been optimized for speed. + // For instance, do not set b:=subject[i] here since + // we may bump `i` in this loop to avoid `continue` or + // skiping common test in a particular test. + // Run Benchmark_SublistMatchLiteral before making any change. + switch subject[i] { + case pwc: + // NOTE: This is not testing validity of a subject, instead ensures + // that wildcards are treated as such if they follow some basic rules, + // namely that they are a token on their own. + if i == 0 || subject[i-1] == btsep { + if i == ls-1 { + // There is no more token in the subject after this wildcard. + // Skip token in literal and expect to not find a separator. 
+ for { + // End of literal, this is a match. + if li >= ll { + return true + } + // Presence of separator, this can't be a match. + if literal[li] == btsep { + return false + } + li++ + } + } else if subject[i+1] == btsep { + // There is another token in the subject after this wildcard. + // Skip token in literal and expect to get a separator. + for { + // We found the end of the literal before finding a separator, + // this can't be a match. + if li >= ll { + return false + } + if literal[li] == btsep { + break + } + li++ + } + // Bump `i` since we know there is a `.` following, we are + // safe. The common test below is going to check `.` with `.` + // which is good. A `continue` here is too costly. + i++ + } + } + case fwc: + // For `>` to be a wildcard, it means being the only or last character + // in the string preceded by a `.` + if (i == 0 || subject[i-1] == btsep) && i == ls-1 { + return true + } + } + if subject[i] != literal[li] { + return false + } + li++ + } + // Make sure we have processed all of the literal's chars.. 
+ return li >= ll +} + +func addLocalSub(sub *subscription, subs *[]*subscription) { + if sub != nil && sub.client != nil && sub.client.typ == CLIENT { + *subs = append(*subs, sub) + } +} + +func (s *Sublist) addNodeToSubs(n *node, subs *[]*subscription) { + // Normal subscriptions + if n.plist != nil { + for _, sub := range n.plist { + addLocalSub(sub, subs) + } + } else { + for _, sub := range n.psubs { + addLocalSub(sub, subs) + } + } + // Queue subscriptions + for _, qr := range n.qsubs { + for _, sub := range qr { + addLocalSub(sub, subs) + } + } +} + +func (s *Sublist) collectLocalSubs(l *level, subs *[]*subscription) { + if len(l.nodes) > 0 { + for _, n := range l.nodes { + s.addNodeToSubs(n, subs) + s.collectLocalSubs(n.next, subs) + } + } + if l.pwc != nil { + s.addNodeToSubs(l.pwc, subs) + s.collectLocalSubs(l.pwc.next, subs) + } + if l.fwc != nil { + s.addNodeToSubs(l.fwc, subs) + s.collectLocalSubs(l.fwc.next, subs) + } +} + +// Return all local client subscriptions. Use the supplied slice. +func (s *Sublist) localSubs(subs *[]*subscription) { + s.RLock() + s.collectLocalSubs(s.root, subs) + s.RUnlock() +} diff --git a/vendor/github.com/nats-io/gnatsd/server/util.go b/vendor/github.com/nats-io/gnatsd/server/util.go new file mode 100644 index 00000000000..9fcfac6d6e6 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/util.go @@ -0,0 +1,119 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "errors" + "fmt" + "net" + "net/url" + "reflect" + "strconv" + "strings" + "time" + + "github.com/nats-io/nuid" +) + +// Use nuid. +func genID() string { + return nuid.Next() +} + +// Ascii numbers 0-9 +const ( + asciiZero = 48 + asciiNine = 57 +) + +// parseSize expects decimal positive numbers. We +// return -1 to signal error. +func parseSize(d []byte) (n int) { + l := len(d) + if l == 0 { + return -1 + } + var ( + i int + dec byte + ) + + // Note: Use `goto` here to avoid for loop in order + // to have the function be inlined. + // See: https://github.com/golang/go/issues/14768 +loop: + dec = d[i] + if dec < asciiZero || dec > asciiNine { + return -1 + } + n = n*10 + (int(dec) - asciiZero) + + i++ + if i < l { + goto loop + } + return n +} + +// parseInt64 expects decimal positive numbers. We +// return -1 to signal error +func parseInt64(d []byte) (n int64) { + if len(d) == 0 { + return -1 + } + for _, dec := range d { + if dec < asciiZero || dec > asciiNine { + return -1 + } + n = n*10 + (int64(dec) - asciiZero) + } + return n +} + +// Helper to move from float seconds to time.Duration +func secondsToDuration(seconds float64) time.Duration { + ttl := seconds * float64(time.Second) + return time.Duration(ttl) +} + +// Parse a host/port string with a default port to use +// if none (or 0 or -1) is specified in `hostPort` string. 
+func parseHostPort(hostPort string, defaultPort int) (host string, port int, err error) { + if hostPort != "" { + host, sPort, err := net.SplitHostPort(hostPort) + switch err.(type) { + case *net.AddrError: + // try appending the current port + host, sPort, err = net.SplitHostPort(fmt.Sprintf("%s:%d", hostPort, defaultPort)) + } + if err != nil { + return "", -1, err + } + port, err = strconv.Atoi(strings.TrimSpace(sPort)) + if err != nil { + return "", -1, err + } + if port == 0 || port == -1 { + port = defaultPort + } + return strings.TrimSpace(host), port, nil + } + return "", -1, errors.New("No hostport specified") +} + +// Returns true if URL u1 represents the same URL than u2, +// false otherwise. +func urlsAreEqual(u1, u2 *url.URL) bool { + return reflect.DeepEqual(u1, u2) +} diff --git a/vendor/github.com/nats-io/gnatsd/util/mkpasswd.go b/vendor/github.com/nats-io/gnatsd/util/mkpasswd.go new file mode 100644 index 00000000000..02cdcb51875 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/mkpasswd.go @@ -0,0 +1,87 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build ignore + +package main + +import ( + "bytes" + "crypto/rand" + "flag" + "fmt" + "log" + "math/big" + + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/ssh/terminal" +) + +func usage() { + log.Fatalf("Usage: mkpasswd [-p ] [-c COST] \n") +} + +const ( + // Make sure the password is reasonably long to generate enough entropy. 
+ PasswordLength = 22 + // Common advice from the past couple of years suggests that 10 should be sufficient. + // Up that a little, to 11. Feel free to raise this higher if this value from 2015 is + // no longer appropriate. Min is 4, Max is 31. + DefaultCost = 11 +) + +func main() { + var pw = flag.Bool("p", false, "Input password via stdin") + var cost = flag.Int("c", DefaultCost, "The cost weight, range of 4-31 (11)") + + log.SetFlags(0) + flag.Usage = usage + flag.Parse() + + var password string + + if *pw { + fmt.Printf("Enter Password: ") + bytePassword, _ := terminal.ReadPassword(0) + fmt.Printf("\nReenter Password: ") + bytePassword2, _ := terminal.ReadPassword(0) + if !bytes.Equal(bytePassword, bytePassword2) { + log.Fatalf("Error, passwords do not match\n") + } + password = string(bytePassword) + fmt.Printf("\n") + } else { + password = genPassword() + fmt.Printf("pass: %s\n", password) + } + + cb, err := bcrypt.GenerateFromPassword([]byte(password), *cost) + if err != nil { + log.Fatalf("Error producing bcrypt hash: %v\n", err) + } + fmt.Printf("bcrypt hash: %s\n", cb) +} + +func genPassword() string { + var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()") + b := make([]byte, PasswordLength) + max := big.NewInt(int64(len(ch))) + for i := range b { + ri, err := rand.Int(rand.Reader, max) + if err != nil { + log.Fatalf("Error producing random integer: %v\n", err) + } + b[i] = ch[int(ri.Int64())] + } + return string(b) +} diff --git a/vendor/github.com/nats-io/gnatsd/util/tls.go b/vendor/github.com/nats-io/gnatsd/util/tls.go new file mode 100644 index 00000000000..87907eeb606 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/tls.go @@ -0,0 +1,25 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. +func CloneTLSConfig(c *tls.Config) *tls.Config { + return c.Clone() +} diff --git a/vendor/github.com/nats-io/gnatsd/util/tls_pre17.go b/vendor/github.com/nats-io/gnatsd/util/tls_pre17.go new file mode 100644 index 00000000000..99ea32b410d --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/tls_pre17.go @@ -0,0 +1,47 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.5,!go1.7 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. Only the exported fields are copied. +// This is temporary, until this is provided by the language. 
+// https://go-review.googlesource.com/#/c/28075/ +func CloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + } +} diff --git a/vendor/github.com/nats-io/gnatsd/util/tls_pre18.go b/vendor/github.com/nats-io/gnatsd/util/tls_pre18.go new file mode 100644 index 00000000000..7df472617cf --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/tls_pre18.go @@ -0,0 +1,49 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.7,!go1.8 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. Only the exported fields are copied. +// This is temporary, until this is provided by the language. 
+// https://go-review.googlesource.com/#/c/28075/ +func CloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + } +} diff --git a/vendor/github.com/nats-io/go-nats-streaming/LICENSE b/vendor/github.com/nats-io/go-nats-streaming/LICENSE new file mode 100644 index 00000000000..f49a4e16e68 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats-streaming/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/vendor/github.com/nats-io/go-nats-streaming/pb/protocol.pb.go b/vendor/github.com/nats-io/go-nats-streaming/pb/protocol.pb.go new file mode 100644 index 00000000000..2182dc9333c --- /dev/null +++ b/vendor/github.com/nats-io/go-nats-streaming/pb/protocol.pb.go @@ -0,0 +1,3383 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: protocol.proto + +/* + Package pb is a generated protocol buffer package. + + It is generated from these files: + protocol.proto + + It has these top-level messages: + PubMsg + PubAck + MsgProto + Ack + ConnectRequest + ConnectResponse + Ping + PingResponse + SubscriptionRequest + SubscriptionResponse + UnsubscribeRequest + CloseRequest + CloseResponse +*/ +package pb + +import proto "github.com/gogo/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "github.com/gogo/protobuf/gogoproto" + +import io "io" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +// Enum for start position type. 
+type StartPosition int32 + +const ( + StartPosition_NewOnly StartPosition = 0 + StartPosition_LastReceived StartPosition = 1 + StartPosition_TimeDeltaStart StartPosition = 2 + StartPosition_SequenceStart StartPosition = 3 + StartPosition_First StartPosition = 4 +) + +var StartPosition_name = map[int32]string{ + 0: "NewOnly", + 1: "LastReceived", + 2: "TimeDeltaStart", + 3: "SequenceStart", + 4: "First", +} +var StartPosition_value = map[string]int32{ + "NewOnly": 0, + "LastReceived": 1, + "TimeDeltaStart": 2, + "SequenceStart": 3, + "First": 4, +} + +func (x StartPosition) String() string { + return proto.EnumName(StartPosition_name, int32(x)) +} +func (StartPosition) EnumDescriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{0} } + +// How messages are delivered to the STAN cluster +type PubMsg struct { + ClientID string `protobuf:"bytes,1,opt,name=clientID,proto3" json:"clientID,omitempty"` + Guid string `protobuf:"bytes,2,opt,name=guid,proto3" json:"guid,omitempty"` + Subject string `protobuf:"bytes,3,opt,name=subject,proto3" json:"subject,omitempty"` + Reply string `protobuf:"bytes,4,opt,name=reply,proto3" json:"reply,omitempty"` + Data []byte `protobuf:"bytes,5,opt,name=data,proto3" json:"data,omitempty"` + ConnID []byte `protobuf:"bytes,6,opt,name=connID,proto3" json:"connID,omitempty"` + Sha256 []byte `protobuf:"bytes,10,opt,name=sha256,proto3" json:"sha256,omitempty"` +} + +func (m *PubMsg) Reset() { *m = PubMsg{} } +func (m *PubMsg) String() string { return proto.CompactTextString(m) } +func (*PubMsg) ProtoMessage() {} +func (*PubMsg) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{0} } + +// Used to ACK to publishers +type PubAck struct { + Guid string `protobuf:"bytes,1,opt,name=guid,proto3" json:"guid,omitempty"` + Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"` +} + +func (m *PubAck) Reset() { *m = PubAck{} } +func (m *PubAck) String() string { return proto.CompactTextString(m) } +func 
(*PubAck) ProtoMessage() {} +func (*PubAck) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{1} } + +// Msg struct. Sequence is assigned for global ordering by +// the cluster after the publisher has been acknowledged. +type MsgProto struct { + Sequence uint64 `protobuf:"varint,1,opt,name=sequence,proto3" json:"sequence,omitempty"` + Subject string `protobuf:"bytes,2,opt,name=subject,proto3" json:"subject,omitempty"` + Reply string `protobuf:"bytes,3,opt,name=reply,proto3" json:"reply,omitempty"` + Data []byte `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` + Timestamp int64 `protobuf:"varint,5,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + Redelivered bool `protobuf:"varint,6,opt,name=redelivered,proto3" json:"redelivered,omitempty"` + CRC32 uint32 `protobuf:"varint,10,opt,name=CRC32,proto3" json:"CRC32,omitempty"` +} + +func (m *MsgProto) Reset() { *m = MsgProto{} } +func (m *MsgProto) String() string { return proto.CompactTextString(m) } +func (*MsgProto) ProtoMessage() {} +func (*MsgProto) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{2} } + +// Ack will deliver an ack for a delivered msg. 
+type Ack struct { + Subject string `protobuf:"bytes,1,opt,name=subject,proto3" json:"subject,omitempty"` + Sequence uint64 `protobuf:"varint,2,opt,name=sequence,proto3" json:"sequence,omitempty"` +} + +func (m *Ack) Reset() { *m = Ack{} } +func (m *Ack) String() string { return proto.CompactTextString(m) } +func (*Ack) ProtoMessage() {} +func (*Ack) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{3} } + +// Connection Request +type ConnectRequest struct { + ClientID string `protobuf:"bytes,1,opt,name=clientID,proto3" json:"clientID,omitempty"` + HeartbeatInbox string `protobuf:"bytes,2,opt,name=heartbeatInbox,proto3" json:"heartbeatInbox,omitempty"` + Protocol int32 `protobuf:"varint,3,opt,name=protocol,proto3" json:"protocol,omitempty"` + ConnID []byte `protobuf:"bytes,4,opt,name=connID,proto3" json:"connID,omitempty"` + PingInterval int32 `protobuf:"varint,5,opt,name=pingInterval,proto3" json:"pingInterval,omitempty"` + PingMaxOut int32 `protobuf:"varint,6,opt,name=pingMaxOut,proto3" json:"pingMaxOut,omitempty"` +} + +func (m *ConnectRequest) Reset() { *m = ConnectRequest{} } +func (m *ConnectRequest) String() string { return proto.CompactTextString(m) } +func (*ConnectRequest) ProtoMessage() {} +func (*ConnectRequest) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{4} } + +// Response to a client connect +type ConnectResponse struct { + PubPrefix string `protobuf:"bytes,1,opt,name=pubPrefix,proto3" json:"pubPrefix,omitempty"` + SubRequests string `protobuf:"bytes,2,opt,name=subRequests,proto3" json:"subRequests,omitempty"` + UnsubRequests string `protobuf:"bytes,3,opt,name=unsubRequests,proto3" json:"unsubRequests,omitempty"` + CloseRequests string `protobuf:"bytes,4,opt,name=closeRequests,proto3" json:"closeRequests,omitempty"` + Error string `protobuf:"bytes,5,opt,name=error,proto3" json:"error,omitempty"` + SubCloseRequests string `protobuf:"bytes,6,opt,name=subCloseRequests,proto3" json:"subCloseRequests,omitempty"` + 
PingRequests string `protobuf:"bytes,7,opt,name=pingRequests,proto3" json:"pingRequests,omitempty"` + PingInterval int32 `protobuf:"varint,8,opt,name=pingInterval,proto3" json:"pingInterval,omitempty"` + PingMaxOut int32 `protobuf:"varint,9,opt,name=pingMaxOut,proto3" json:"pingMaxOut,omitempty"` + Protocol int32 `protobuf:"varint,10,opt,name=protocol,proto3" json:"protocol,omitempty"` + PublicKey string `protobuf:"bytes,100,opt,name=publicKey,proto3" json:"publicKey,omitempty"` +} + +func (m *ConnectResponse) Reset() { *m = ConnectResponse{} } +func (m *ConnectResponse) String() string { return proto.CompactTextString(m) } +func (*ConnectResponse) ProtoMessage() {} +func (*ConnectResponse) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{5} } + +// PING from client to server +type Ping struct { + ConnID []byte `protobuf:"bytes,1,opt,name=connID,proto3" json:"connID,omitempty"` +} + +func (m *Ping) Reset() { *m = Ping{} } +func (m *Ping) String() string { return proto.CompactTextString(m) } +func (*Ping) ProtoMessage() {} +func (*Ping) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{6} } + +// PING response from the server +type PingResponse struct { + Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` +} + +func (m *PingResponse) Reset() { *m = PingResponse{} } +func (m *PingResponse) String() string { return proto.CompactTextString(m) } +func (*PingResponse) ProtoMessage() {} +func (*PingResponse) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{7} } + +// Protocol for a client to subscribe +type SubscriptionRequest struct { + ClientID string `protobuf:"bytes,1,opt,name=clientID,proto3" json:"clientID,omitempty"` + Subject string `protobuf:"bytes,2,opt,name=subject,proto3" json:"subject,omitempty"` + QGroup string `protobuf:"bytes,3,opt,name=qGroup,proto3" json:"qGroup,omitempty"` + Inbox string `protobuf:"bytes,4,opt,name=inbox,proto3" json:"inbox,omitempty"` + MaxInFlight 
int32 `protobuf:"varint,5,opt,name=maxInFlight,proto3" json:"maxInFlight,omitempty"` + AckWaitInSecs int32 `protobuf:"varint,6,opt,name=ackWaitInSecs,proto3" json:"ackWaitInSecs,omitempty"` + DurableName string `protobuf:"bytes,7,opt,name=durableName,proto3" json:"durableName,omitempty"` + StartPosition StartPosition `protobuf:"varint,10,opt,name=startPosition,proto3,enum=pb.StartPosition" json:"startPosition,omitempty"` + StartSequence uint64 `protobuf:"varint,11,opt,name=startSequence,proto3" json:"startSequence,omitempty"` + StartTimeDelta int64 `protobuf:"varint,12,opt,name=startTimeDelta,proto3" json:"startTimeDelta,omitempty"` +} + +func (m *SubscriptionRequest) Reset() { *m = SubscriptionRequest{} } +func (m *SubscriptionRequest) String() string { return proto.CompactTextString(m) } +func (*SubscriptionRequest) ProtoMessage() {} +func (*SubscriptionRequest) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{8} } + +// Response for SubscriptionRequest and UnsubscribeRequests +type SubscriptionResponse struct { + AckInbox string `protobuf:"bytes,2,opt,name=ackInbox,proto3" json:"ackInbox,omitempty"` + Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"` +} + +func (m *SubscriptionResponse) Reset() { *m = SubscriptionResponse{} } +func (m *SubscriptionResponse) String() string { return proto.CompactTextString(m) } +func (*SubscriptionResponse) ProtoMessage() {} +func (*SubscriptionResponse) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{9} } + +// Protocol for a clients to unsubscribe. 
Will return a SubscriptionResponse +type UnsubscribeRequest struct { + ClientID string `protobuf:"bytes,1,opt,name=clientID,proto3" json:"clientID,omitempty"` + Subject string `protobuf:"bytes,2,opt,name=subject,proto3" json:"subject,omitempty"` + Inbox string `protobuf:"bytes,3,opt,name=inbox,proto3" json:"inbox,omitempty"` + DurableName string `protobuf:"bytes,4,opt,name=durableName,proto3" json:"durableName,omitempty"` +} + +func (m *UnsubscribeRequest) Reset() { *m = UnsubscribeRequest{} } +func (m *UnsubscribeRequest) String() string { return proto.CompactTextString(m) } +func (*UnsubscribeRequest) ProtoMessage() {} +func (*UnsubscribeRequest) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{10} } + +// Protocol for a client to close a connection +type CloseRequest struct { + ClientID string `protobuf:"bytes,1,opt,name=clientID,proto3" json:"clientID,omitempty"` +} + +func (m *CloseRequest) Reset() { *m = CloseRequest{} } +func (m *CloseRequest) String() string { return proto.CompactTextString(m) } +func (*CloseRequest) ProtoMessage() {} +func (*CloseRequest) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{11} } + +// Response for CloseRequest +type CloseResponse struct { + Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` +} + +func (m *CloseResponse) Reset() { *m = CloseResponse{} } +func (m *CloseResponse) String() string { return proto.CompactTextString(m) } +func (*CloseResponse) ProtoMessage() {} +func (*CloseResponse) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{12} } + +func init() { + proto.RegisterType((*PubMsg)(nil), "pb.PubMsg") + proto.RegisterType((*PubAck)(nil), "pb.PubAck") + proto.RegisterType((*MsgProto)(nil), "pb.MsgProto") + proto.RegisterType((*Ack)(nil), "pb.Ack") + proto.RegisterType((*ConnectRequest)(nil), "pb.ConnectRequest") + proto.RegisterType((*ConnectResponse)(nil), "pb.ConnectResponse") + proto.RegisterType((*Ping)(nil), "pb.Ping") + 
proto.RegisterType((*PingResponse)(nil), "pb.PingResponse") + proto.RegisterType((*SubscriptionRequest)(nil), "pb.SubscriptionRequest") + proto.RegisterType((*SubscriptionResponse)(nil), "pb.SubscriptionResponse") + proto.RegisterType((*UnsubscribeRequest)(nil), "pb.UnsubscribeRequest") + proto.RegisterType((*CloseRequest)(nil), "pb.CloseRequest") + proto.RegisterType((*CloseResponse)(nil), "pb.CloseResponse") + proto.RegisterEnum("pb.StartPosition", StartPosition_name, StartPosition_value) +} +func (m *PubMsg) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PubMsg) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ClientID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ClientID))) + i += copy(dAtA[i:], m.ClientID) + } + if len(m.Guid) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Guid))) + i += copy(dAtA[i:], m.Guid) + } + if len(m.Subject) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Subject))) + i += copy(dAtA[i:], m.Subject) + } + if len(m.Reply) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Reply))) + i += copy(dAtA[i:], m.Reply) + } + if len(m.Data) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + if len(m.ConnID) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ConnID))) + i += copy(dAtA[i:], m.ConnID) + } + if len(m.Sha256) > 0 { + dAtA[i] = 0x52 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Sha256))) + i += copy(dAtA[i:], m.Sha256) + } + return i, nil +} + +func (m *PubAck) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return 
dAtA[:n], nil +} + +func (m *PubAck) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Guid) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Guid))) + i += copy(dAtA[i:], m.Guid) + } + if len(m.Error) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Error))) + i += copy(dAtA[i:], m.Error) + } + return i, nil +} + +func (m *MsgProto) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *MsgProto) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Sequence != 0 { + dAtA[i] = 0x8 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Sequence)) + } + if len(m.Subject) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Subject))) + i += copy(dAtA[i:], m.Subject) + } + if len(m.Reply) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Reply))) + i += copy(dAtA[i:], m.Reply) + } + if len(m.Data) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + if m.Timestamp != 0 { + dAtA[i] = 0x28 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Timestamp)) + } + if m.Redelivered { + dAtA[i] = 0x30 + i++ + if m.Redelivered { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + if m.CRC32 != 0 { + dAtA[i] = 0x50 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.CRC32)) + } + return i, nil +} + +func (m *Ack) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Ack) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Subject) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Subject))) + i += 
copy(dAtA[i:], m.Subject) + } + if m.Sequence != 0 { + dAtA[i] = 0x10 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Sequence)) + } + return i, nil +} + +func (m *ConnectRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ConnectRequest) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ClientID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ClientID))) + i += copy(dAtA[i:], m.ClientID) + } + if len(m.HeartbeatInbox) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.HeartbeatInbox))) + i += copy(dAtA[i:], m.HeartbeatInbox) + } + if m.Protocol != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Protocol)) + } + if len(m.ConnID) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ConnID))) + i += copy(dAtA[i:], m.ConnID) + } + if m.PingInterval != 0 { + dAtA[i] = 0x28 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.PingInterval)) + } + if m.PingMaxOut != 0 { + dAtA[i] = 0x30 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.PingMaxOut)) + } + return i, nil +} + +func (m *ConnectResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ConnectResponse) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.PubPrefix) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.PubPrefix))) + i += copy(dAtA[i:], m.PubPrefix) + } + if len(m.SubRequests) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.SubRequests))) + i += copy(dAtA[i:], m.SubRequests) + } + if len(m.UnsubRequests) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, 
uint64(len(m.UnsubRequests))) + i += copy(dAtA[i:], m.UnsubRequests) + } + if len(m.CloseRequests) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.CloseRequests))) + i += copy(dAtA[i:], m.CloseRequests) + } + if len(m.Error) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Error))) + i += copy(dAtA[i:], m.Error) + } + if len(m.SubCloseRequests) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.SubCloseRequests))) + i += copy(dAtA[i:], m.SubCloseRequests) + } + if len(m.PingRequests) > 0 { + dAtA[i] = 0x3a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.PingRequests))) + i += copy(dAtA[i:], m.PingRequests) + } + if m.PingInterval != 0 { + dAtA[i] = 0x40 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.PingInterval)) + } + if m.PingMaxOut != 0 { + dAtA[i] = 0x48 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.PingMaxOut)) + } + if m.Protocol != 0 { + dAtA[i] = 0x50 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Protocol)) + } + if len(m.PublicKey) > 0 { + dAtA[i] = 0xa2 + i++ + dAtA[i] = 0x6 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.PublicKey))) + i += copy(dAtA[i:], m.PublicKey) + } + return i, nil +} + +func (m *Ping) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Ping) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ConnID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ConnID))) + i += copy(dAtA[i:], m.ConnID) + } + return i, nil +} + +func (m *PingResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PingResponse) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l 
+ if len(m.Error) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Error))) + i += copy(dAtA[i:], m.Error) + } + return i, nil +} + +func (m *SubscriptionRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SubscriptionRequest) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ClientID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ClientID))) + i += copy(dAtA[i:], m.ClientID) + } + if len(m.Subject) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Subject))) + i += copy(dAtA[i:], m.Subject) + } + if len(m.QGroup) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.QGroup))) + i += copy(dAtA[i:], m.QGroup) + } + if len(m.Inbox) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Inbox))) + i += copy(dAtA[i:], m.Inbox) + } + if m.MaxInFlight != 0 { + dAtA[i] = 0x28 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.MaxInFlight)) + } + if m.AckWaitInSecs != 0 { + dAtA[i] = 0x30 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.AckWaitInSecs)) + } + if len(m.DurableName) > 0 { + dAtA[i] = 0x3a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.DurableName))) + i += copy(dAtA[i:], m.DurableName) + } + if m.StartPosition != 0 { + dAtA[i] = 0x50 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.StartPosition)) + } + if m.StartSequence != 0 { + dAtA[i] = 0x58 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.StartSequence)) + } + if m.StartTimeDelta != 0 { + dAtA[i] = 0x60 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.StartTimeDelta)) + } + return i, nil +} + +func (m *SubscriptionResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + 
} + return dAtA[:n], nil +} + +func (m *SubscriptionResponse) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.AckInbox) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.AckInbox))) + i += copy(dAtA[i:], m.AckInbox) + } + if len(m.Error) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Error))) + i += copy(dAtA[i:], m.Error) + } + return i, nil +} + +func (m *UnsubscribeRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *UnsubscribeRequest) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ClientID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ClientID))) + i += copy(dAtA[i:], m.ClientID) + } + if len(m.Subject) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Subject))) + i += copy(dAtA[i:], m.Subject) + } + if len(m.Inbox) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Inbox))) + i += copy(dAtA[i:], m.Inbox) + } + if len(m.DurableName) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.DurableName))) + i += copy(dAtA[i:], m.DurableName) + } + return i, nil +} + +func (m *CloseRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *CloseRequest) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ClientID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ClientID))) + i += copy(dAtA[i:], m.ClientID) + } + return i, nil +} + +func (m *CloseResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) 
+ if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *CloseResponse) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Error) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Error))) + i += copy(dAtA[i:], m.Error) + } + return i, nil +} + +func encodeVarintProtocol(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *PubMsg) Size() (n int) { + var l int + _ = l + l = len(m.ClientID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Guid) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Subject) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Reply) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.ConnID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Sha256) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *PubAck) Size() (n int) { + var l int + _ = l + l = len(m.Guid) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Error) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *MsgProto) Size() (n int) { + var l int + _ = l + if m.Sequence != 0 { + n += 1 + sovProtocol(uint64(m.Sequence)) + } + l = len(m.Subject) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Reply) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.Timestamp != 0 { + n += 1 + sovProtocol(uint64(m.Timestamp)) + } + if m.Redelivered { + n += 2 + } + if m.CRC32 != 0 { + n += 1 + sovProtocol(uint64(m.CRC32)) + } + return n +} + +func (m *Ack) Size() (n int) { + var l int + _ = l + l = len(m.Subject) + if l > 0 { + n += 1 + l + 
sovProtocol(uint64(l)) + } + if m.Sequence != 0 { + n += 1 + sovProtocol(uint64(m.Sequence)) + } + return n +} + +func (m *ConnectRequest) Size() (n int) { + var l int + _ = l + l = len(m.ClientID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.HeartbeatInbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.Protocol != 0 { + n += 1 + sovProtocol(uint64(m.Protocol)) + } + l = len(m.ConnID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.PingInterval != 0 { + n += 1 + sovProtocol(uint64(m.PingInterval)) + } + if m.PingMaxOut != 0 { + n += 1 + sovProtocol(uint64(m.PingMaxOut)) + } + return n +} + +func (m *ConnectResponse) Size() (n int) { + var l int + _ = l + l = len(m.PubPrefix) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.SubRequests) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.UnsubRequests) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.CloseRequests) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Error) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.SubCloseRequests) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.PingRequests) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.PingInterval != 0 { + n += 1 + sovProtocol(uint64(m.PingInterval)) + } + if m.PingMaxOut != 0 { + n += 1 + sovProtocol(uint64(m.PingMaxOut)) + } + if m.Protocol != 0 { + n += 1 + sovProtocol(uint64(m.Protocol)) + } + l = len(m.PublicKey) + if l > 0 { + n += 2 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *Ping) Size() (n int) { + var l int + _ = l + l = len(m.ConnID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *PingResponse) Size() (n int) { + var l int + _ = l + l = len(m.Error) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *SubscriptionRequest) Size() (n int) { + var l int + _ = l + l = len(m.ClientID) + if l > 0 { + n += 1 + 
l + sovProtocol(uint64(l)) + } + l = len(m.Subject) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.QGroup) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Inbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.MaxInFlight != 0 { + n += 1 + sovProtocol(uint64(m.MaxInFlight)) + } + if m.AckWaitInSecs != 0 { + n += 1 + sovProtocol(uint64(m.AckWaitInSecs)) + } + l = len(m.DurableName) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.StartPosition != 0 { + n += 1 + sovProtocol(uint64(m.StartPosition)) + } + if m.StartSequence != 0 { + n += 1 + sovProtocol(uint64(m.StartSequence)) + } + if m.StartTimeDelta != 0 { + n += 1 + sovProtocol(uint64(m.StartTimeDelta)) + } + return n +} + +func (m *SubscriptionResponse) Size() (n int) { + var l int + _ = l + l = len(m.AckInbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Error) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *UnsubscribeRequest) Size() (n int) { + var l int + _ = l + l = len(m.ClientID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Subject) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Inbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.DurableName) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *CloseRequest) Size() (n int) { + var l int + _ = l + l = len(m.ClientID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *CloseResponse) Size() (n int) { + var l int + _ = l + l = len(m.Error) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func sovProtocol(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozProtocol(x uint64) (n int) { + return sovProtocol(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *PubMsg) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + 
preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: PubMsg: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PubMsg: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClientID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Guid", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Guid = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Subject", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift 
+= 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Subject = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reply", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Reply = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ConnID", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ConnID = append(m.ConnID[:0], dAtA[iNdEx:postIndex]...) + if m.ConnID == nil { + m.ConnID = []byte{} + } + iNdEx = postIndex + case 10: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Sha256", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Sha256 = append(m.Sha256[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Sha256 == nil { + m.Sha256 = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *PubAck) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: PubAck: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PubAck: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Guid", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Guid = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + 
stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Error = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *MsgProto) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: MsgProto: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: MsgProto: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Sequence", wireType) + } + m.Sequence = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Sequence |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Subject", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + 
iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Subject = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reply", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Reply = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + case 5: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Timestamp", wireType) + } + m.Timestamp = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Timestamp |= (int64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 6: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Redelivered", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Redelivered = bool(v != 0) + case 10: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field CRC32", wireType) + } + m.CRC32 = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.CRC32 |= (uint32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Ack) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + 
return fmt.Errorf("proto: Ack: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Ack: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Subject", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Subject = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Sequence", wireType) + } + m.Sequence = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Sequence |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ConnectRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { 
+ return fmt.Errorf("proto: ConnectRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ConnectRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClientID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field HeartbeatInbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.HeartbeatInbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Protocol", wireType) + } + m.Protocol = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Protocol |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ConnID", wireType) + } + 
var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ConnID = append(m.ConnID[:0], dAtA[iNdEx:postIndex]...) + if m.ConnID == nil { + m.ConnID = []byte{} + } + iNdEx = postIndex + case 5: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field PingInterval", wireType) + } + m.PingInterval = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.PingInterval |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 6: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field PingMaxOut", wireType) + } + m.PingMaxOut = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.PingMaxOut |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ConnectResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 
{ + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ConnectResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ConnectResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field PubPrefix", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.PubPrefix = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SubRequests", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SubRequests = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field UnsubRequests", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + 
intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.UnsubRequests = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field CloseRequests", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.CloseRequests = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Error = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SubCloseRequests", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return 
ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SubCloseRequests = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field PingRequests", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.PingRequests = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field PingInterval", wireType) + } + m.PingInterval = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.PingInterval |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 9: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field PingMaxOut", wireType) + } + m.PingMaxOut = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.PingMaxOut |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 10: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Protocol", wireType) + } + m.Protocol = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Protocol |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + 
case 100: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field PublicKey", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.PublicKey = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Ping) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Ping: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Ping: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ConnID", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if 
byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ConnID = append(m.ConnID[:0], dAtA[iNdEx:postIndex]...) + if m.ConnID == nil { + m.ConnID = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *PingResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: PingResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PingResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Error = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return 
err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SubscriptionRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SubscriptionRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SubscriptionRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClientID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Subject", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 
{ + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Subject = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field QGroup", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.QGroup = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Inbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Inbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 5: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field MaxInFlight", wireType) + } + m.MaxInFlight = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.MaxInFlight |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 6: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field AckWaitInSecs", wireType) + } + m.AckWaitInSecs = 0 + for shift 
:= uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.AckWaitInSecs |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field DurableName", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.DurableName = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 10: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field StartPosition", wireType) + } + m.StartPosition = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.StartPosition |= (StartPosition(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 11: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field StartSequence", wireType) + } + m.StartSequence = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.StartSequence |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 12: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field StartTimeDelta", wireType) + } + m.StartTimeDelta = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + 
iNdEx++ + m.StartTimeDelta |= (int64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SubscriptionResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SubscriptionResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SubscriptionResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AckInbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AckInbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + 
return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Error = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *UnsubscribeRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: UnsubscribeRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: UnsubscribeRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClientID = 
string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Subject", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Subject = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Inbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Inbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field DurableName", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.DurableName = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err 
:= skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *CloseRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: CloseRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: CloseRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClientID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *CloseResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 
0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: CloseResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: CloseResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Error = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipProtocol(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := 
uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthProtocol + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipProtocol(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthProtocol = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowProtocol = fmt.Errorf("proto: integer overflow") +) + +func init() { proto.RegisterFile("protocol.proto", fileDescriptorProtocol) } + +var fileDescriptorProtocol = []byte{ + // 823 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x54, 0xc1, 0x8e, 0xe3, 0x44, + 0x10, 0x1d, 0xc7, 0x76, 0x26, 0x53, 0x93, 0x84, 0x6c, 0x13, 0xad, 0xac, 0x08, 0x45, 0x23, 0x6b, + 0x41, 0xab, 0x91, 0xc8, 0x4a, 0x59, 0x01, 0x07, 0x4e, 0x90, 0xd1, 0x42, 0x04, 0xb3, 0x1b, 0x39, + 0x20, 0xae, 0xb4, 0x9d, 0x5e, 
0xa7, 0x19, 0xa7, 0xed, 0x75, 0xb7, 0x87, 0xcc, 0x11, 0xbe, 0x80, + 0x0f, 0xe1, 0x23, 0x38, 0xee, 0x81, 0x03, 0x9f, 0x00, 0xc3, 0x8d, 0xaf, 0x40, 0x5d, 0x76, 0x9c, + 0x76, 0x86, 0x1d, 0x90, 0xb8, 0xf5, 0x7b, 0xa9, 0x6e, 0xd7, 0x7b, 0xaf, 0x2a, 0xd0, 0xcf, 0xf2, + 0x54, 0xa5, 0x51, 0x9a, 0x4c, 0xf0, 0x40, 0x5a, 0x59, 0x38, 0x7a, 0x3f, 0xe6, 0x6a, 0x5d, 0x84, + 0x93, 0x28, 0xdd, 0x3c, 0x89, 0xd3, 0x38, 0x7d, 0x82, 0x3f, 0x85, 0xc5, 0x4b, 0x44, 0x08, 0xf0, + 0x54, 0x5e, 0xf1, 0x7f, 0xb6, 0xa0, 0xbd, 0x28, 0xc2, 0x4b, 0x19, 0x93, 0x11, 0x74, 0xa2, 0x84, + 0x33, 0xa1, 0xe6, 0x17, 0x9e, 0x75, 0x66, 0x3d, 0x3e, 0x09, 0x6a, 0x4c, 0x08, 0x38, 0x71, 0xc1, + 0x57, 0x5e, 0x0b, 0x79, 0x3c, 0x13, 0x0f, 0x8e, 0x65, 0x11, 0x7e, 0xc7, 0x22, 0xe5, 0xd9, 0x48, + 0xef, 0x20, 0x19, 0x82, 0x9b, 0xb3, 0x2c, 0xb9, 0xf1, 0x1c, 0xe4, 0x4b, 0xa0, 0xdf, 0x58, 0x51, + 0x45, 0x3d, 0xf7, 0xcc, 0x7a, 0xdc, 0x0d, 0xf0, 0x4c, 0x1e, 0x42, 0x3b, 0x4a, 0x85, 0x98, 0x5f, + 0x78, 0x6d, 0x64, 0x2b, 0xa4, 0x79, 0xb9, 0xa6, 0xd3, 0x0f, 0x3e, 0xf4, 0xa0, 0xe4, 0x4b, 0xe4, + 0x4f, 0xb1, 0xdb, 0x4f, 0xa2, 0xab, 0xba, 0x23, 0xcb, 0xe8, 0x68, 0x08, 0x2e, 0xcb, 0xf3, 0x34, + 0xaf, 0xda, 0x2c, 0x81, 0xff, 0x8b, 0x05, 0x9d, 0x4b, 0x19, 0x2f, 0xd0, 0xa2, 0x11, 0x74, 0x24, + 0x7b, 0x55, 0x30, 0x11, 0x31, 0xbc, 0xea, 0x04, 0x35, 0x36, 0x05, 0xb5, 0xde, 0x20, 0xc8, 0xfe, + 0x27, 0x41, 0x8e, 0x21, 0xe8, 0x1d, 0x38, 0x51, 0x7c, 0xc3, 0xa4, 0xa2, 0x9b, 0x0c, 0x95, 0xda, + 0xc1, 0x9e, 0x20, 0x67, 0x70, 0x9a, 0xb3, 0x15, 0x4b, 0xf8, 0x35, 0xcb, 0xd9, 0x0a, 0x35, 0x77, + 0x02, 0x93, 0xd2, 0x5f, 0x9a, 0x05, 0xb3, 0xa7, 0x53, 0xd4, 0xdd, 0x0b, 0x4a, 0xe0, 0x7f, 0x0c, + 0xb6, 0xd6, 0x6c, 0x34, 0x68, 0x35, 0x1b, 0x34, 0x65, 0xb5, 0x9a, 0xb2, 0xfc, 0x5f, 0x2d, 0xe8, + 0xcf, 0x52, 0x21, 0x58, 0xa4, 0x02, 0xcd, 0x49, 0x75, 0x6f, 0xd4, 0xef, 0x41, 0x7f, 0xcd, 0x68, + 0xae, 0x42, 0x46, 0xd5, 0x5c, 0x84, 0xe9, 0xb6, 0x32, 0xe3, 0x80, 0xd5, 0x6f, 0xec, 0xc6, 0x0f, + 0x6d, 0x71, 0x83, 0x1a, 0x1b, 0xb1, 0x3a, 0x8d, 0x58, 0x7d, 0xe8, 
0x66, 0x5c, 0xc4, 0x73, 0xa1, + 0x58, 0x7e, 0x4d, 0x13, 0x34, 0xc8, 0x0d, 0x1a, 0x1c, 0x19, 0x03, 0x68, 0x7c, 0x49, 0xb7, 0x2f, + 0x0a, 0x85, 0x16, 0xb9, 0x81, 0xc1, 0xf8, 0x3f, 0xd8, 0xf0, 0x56, 0x2d, 0x47, 0x66, 0xa9, 0x90, + 0x4c, 0xbb, 0x9e, 0x15, 0xe1, 0x22, 0x67, 0x2f, 0xf9, 0xb6, 0x12, 0xb4, 0x27, 0xb4, 0xeb, 0xb2, + 0x08, 0x2b, 0xed, 0xb2, 0x92, 0x63, 0x52, 0xe4, 0x11, 0xf4, 0x0a, 0x61, 0xd6, 0x94, 0x39, 0x37, + 0x49, 0x5d, 0x15, 0x25, 0xa9, 0x64, 0x75, 0x55, 0x39, 0xde, 0x4d, 0x72, 0x3f, 0x84, 0xae, 0x31, + 0x84, 0xe4, 0x1c, 0x06, 0xb2, 0x08, 0x67, 0x8d, 0xeb, 0x6d, 0x2c, 0xb8, 0xc3, 0xef, 0x5c, 0xaa, + 0xeb, 0x8e, 0xb1, 0xae, 0xc1, 0xdd, 0x71, 0xb2, 0xf3, 0xaf, 0x4e, 0x9e, 0x1c, 0x3a, 0xd9, 0x48, + 0x10, 0x0e, 0x12, 0x2c, 0x1d, 0x4d, 0x78, 0xf4, 0x05, 0xbb, 0xf1, 0x56, 0xb5, 0xa3, 0x25, 0xe1, + 0x8f, 0xc1, 0x59, 0x70, 0x11, 0x1b, 0x39, 0x5b, 0x66, 0xce, 0xfe, 0x23, 0xe8, 0x2e, 0xb0, 0xdb, + 0x2a, 0x9f, 0xda, 0x13, 0xcb, 0x5c, 0xcc, 0xbf, 0x5a, 0xf0, 0xf6, 0xb2, 0x08, 0x65, 0x94, 0xf3, + 0x4c, 0xf1, 0x54, 0xfc, 0x97, 0xe9, 0x7c, 0xf3, 0x8e, 0x3e, 0x84, 0xf6, 0xab, 0xcf, 0xf2, 0xb4, + 0xc8, 0xaa, 0xf0, 0x2a, 0xa4, 0xbf, 0xcd, 0x71, 0x8c, 0xab, 0x3f, 0x23, 0x04, 0x7a, 0x26, 0x36, + 0x74, 0x3b, 0x17, 0xcf, 0x12, 0x1e, 0xaf, 0x55, 0x35, 0x88, 0x26, 0xa5, 0xd3, 0xa6, 0xd1, 0xd5, + 0x37, 0x94, 0xab, 0xb9, 0x58, 0xb2, 0x48, 0x56, 0xa3, 0xd8, 0x24, 0xf5, 0x3b, 0xab, 0x22, 0xa7, + 0x61, 0xc2, 0x9e, 0xd3, 0x0d, 0xab, 0xa2, 0x32, 0x29, 0xf2, 0x11, 0xf4, 0xa4, 0xa2, 0xb9, 0x5a, + 0xa4, 0x92, 0x6b, 0x95, 0x68, 0x75, 0x7f, 0xfa, 0x60, 0x92, 0x85, 0x93, 0xa5, 0xf9, 0x43, 0xd0, + 0xac, 0xd3, 0x0d, 0x20, 0xb1, 0xdc, 0x2d, 0xf6, 0x29, 0x2e, 0x76, 0x93, 0xd4, 0xeb, 0x8a, 0xc4, + 0x57, 0x7c, 0xc3, 0x2e, 0x58, 0xa2, 0xa8, 0xd7, 0xc5, 0x7f, 0x9d, 0x03, 0xd6, 0xff, 0x1c, 0x86, + 0x4d, 0xaf, 0xab, 0x68, 0x46, 0xd0, 0xa1, 0xd1, 0x95, 0xb9, 0xe8, 0x35, 0xde, 0xc7, 0x66, 0x9b, + 0xb1, 0xfd, 0x68, 0x01, 0xf9, 0x5a, 0x2f, 0x86, 0x7e, 0x2c, 0x64, 0xff, 0x2f, 0xb5, 0x3a, 0x1d, + 0xfb, 
0x20, 0x1d, 0xd3, 0x55, 0xe7, 0x8e, 0xab, 0xfe, 0x39, 0x74, 0xcd, 0xa5, 0xb9, 0xef, 0xeb, + 0xfe, 0xbb, 0xd0, 0xab, 0x6a, 0xef, 0x1b, 0xc7, 0xf3, 0x6f, 0xa1, 0xd7, 0xc8, 0x83, 0x9c, 0xc2, + 0xf1, 0x73, 0xf6, 0xfd, 0x0b, 0x91, 0xdc, 0x0c, 0x8e, 0xc8, 0x00, 0xba, 0x5f, 0x52, 0xa9, 0x02, + 0x16, 0x31, 0x7e, 0xcd, 0x56, 0x03, 0x8b, 0x10, 0xe8, 0xd7, 0xf6, 0xe2, 0xc5, 0x41, 0x8b, 0x3c, + 0x80, 0xde, 0x2e, 0x99, 0x92, 0xb2, 0xc9, 0x09, 0xb8, 0xcf, 0x78, 0x2e, 0xd5, 0xc0, 0xf9, 0x74, + 0xf8, 0xfa, 0x8f, 0xf1, 0xd1, 0xeb, 0xdb, 0xb1, 0xf5, 0xdb, 0xed, 0xd8, 0xfa, 0xfd, 0x76, 0x6c, + 0xfd, 0xf4, 0xe7, 0xf8, 0x28, 0x6c, 0xe3, 0xd2, 0x3d, 0xfd, 0x3b, 0x00, 0x00, 0xff, 0xff, 0x66, + 0xcd, 0xb9, 0x27, 0xce, 0x07, 0x00, 0x00, +} diff --git a/vendor/github.com/nats-io/go-nats-streaming/stan.go b/vendor/github.com/nats-io/go-nats-streaming/stan.go new file mode 100644 index 00000000000..167ab6dd341 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats-streaming/stan.go @@ -0,0 +1,761 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package stan is a Go client for the NATS Streaming messaging system (https://nats.io). 
+package stan + +import ( + "errors" + "fmt" + "runtime" + "sync" + "time" + + "github.com/nats-io/go-nats" + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nuid" +) + +// Version is the NATS Streaming Go Client version +const Version = "0.4.0" + +const ( + // DefaultNatsURL is the default URL the client connects to + DefaultNatsURL = "nats://localhost:4222" + // DefaultConnectWait is the default timeout used for the connect operation + DefaultConnectWait = 2 * time.Second + // DefaultDiscoverPrefix is the prefix subject used to connect to the NATS Streaming server + DefaultDiscoverPrefix = "_STAN.discover" + // DefaultACKPrefix is the prefix subject used to send ACKs to the NATS Streaming server + DefaultACKPrefix = "_STAN.acks" + // DefaultMaxPubAcksInflight is the default maximum number of published messages + // without outstanding ACKs from the server + DefaultMaxPubAcksInflight = 16384 + // DefaultPingInterval is the default interval (in seconds) at which a connection sends a PING to the server + DefaultPingInterval = 5 + // DefaultPingMaxOut is the number of PINGs without a response before the connection is considered lost. + DefaultPingMaxOut = 3 +) + +// Conn represents a connection to the NATS Streaming subsystem. It can Publish and +// Subscribe to messages within the NATS Streaming cluster. +type Conn interface { + // Publish + Publish(subject string, data []byte) error + PublishAsync(subject string, data []byte, ah AckHandler) (string, error) + + // Subscribe + Subscribe(subject string, cb MsgHandler, opts ...SubscriptionOption) (Subscription, error) + + // QueueSubscribe + QueueSubscribe(subject, qgroup string, cb MsgHandler, opts ...SubscriptionOption) (Subscription, error) + + // Close + Close() error + + // NatsConn returns the underlying NATS conn. Use this with care. For + // example, closing the wrapped NATS conn will put the NATS Streaming Conn + // in an invalid state. 
+ NatsConn() *nats.Conn +} + +const ( + // Client send connID in ConnectRequest and PubMsg, and server + // listens and responds to client PINGs. The validity of the + // connection (based on connID) is checked on incoming PINGs. + protocolOne = int32(1) +) + +// Errors +var ( + ErrConnectReqTimeout = errors.New("stan: connect request timeout") + ErrCloseReqTimeout = errors.New("stan: close request timeout") + ErrSubReqTimeout = errors.New("stan: subscribe request timeout") + ErrUnsubReqTimeout = errors.New("stan: unsubscribe request timeout") + ErrConnectionClosed = errors.New("stan: connection closed") + ErrTimeout = errors.New("stan: publish ack timeout") + ErrBadAck = errors.New("stan: malformed ack") + ErrBadSubscription = errors.New("stan: invalid subscription") + ErrBadConnection = errors.New("stan: invalid connection") + ErrManualAck = errors.New("stan: cannot manually ack in auto-ack mode") + ErrNilMsg = errors.New("stan: nil message") + ErrNoServerSupport = errors.New("stan: not supported by server") + ErrMaxPings = errors.New("stan: connection lost due to PING failure") +) + +var testAllowMillisecInPings = false + +// AckHandler is used for Async Publishing to provide status of the ack. +// The func will be passed the GUID and any error state. No error means the +// message was successfully received by NATS Streaming. +type AckHandler func(string, error) + +// ConnectionLostHandler is used to be notified if the Streaming connection +// is closed due to unexpected errors. +type ConnectionLostHandler func(Conn, error) + +// Options can be used to a create a customized connection. 
+type Options struct { + NatsURL string + NatsConn *nats.Conn + ConnectTimeout time.Duration + AckTimeout time.Duration + DiscoverPrefix string + MaxPubAcksInflight int + PingIterval int // In seconds + PingMaxOut int + ConnectionLostCB ConnectionLostHandler +} + +// DefaultOptions are the NATS Streaming client's default options +var DefaultOptions = Options{ + NatsURL: DefaultNatsURL, + ConnectTimeout: DefaultConnectWait, + AckTimeout: DefaultAckWait, + DiscoverPrefix: DefaultDiscoverPrefix, + MaxPubAcksInflight: DefaultMaxPubAcksInflight, + PingIterval: DefaultPingInterval, + PingMaxOut: DefaultPingMaxOut, +} + +// Option is a function on the options for a connection. +type Option func(*Options) error + +// NatsURL is an Option to set the URL the client should connect to. +func NatsURL(u string) Option { + return func(o *Options) error { + o.NatsURL = u + return nil + } +} + +// ConnectWait is an Option to set the timeout for establishing a connection. +func ConnectWait(t time.Duration) Option { + return func(o *Options) error { + o.ConnectTimeout = t + return nil + } +} + +// PubAckWait is an Option to set the timeout for waiting for an ACK for a +// published message. +func PubAckWait(t time.Duration) Option { + return func(o *Options) error { + o.AckTimeout = t + return nil + } +} + +// MaxPubAcksInflight is an Option to set the maximum number of published +// messages without outstanding ACKs from the server. +func MaxPubAcksInflight(max int) Option { + return func(o *Options) error { + o.MaxPubAcksInflight = max + return nil + } +} + +// NatsConn is an Option to set the underlying NATS connection to be used +// by a NATS Streaming Conn object. +func NatsConn(nc *nats.Conn) Option { + return func(o *Options) error { + o.NatsConn = nc + return nil + } +} + +// Pings is an Option to set the ping interval and max out values. +// The interval needs to be at least 1 and represents the number +// of seconds. 
+// The maxOut needs to be at least 2, since the count of sent PINGs +// increase whenever a PING is sent and reset to 0 when a response +// is received. Setting to 1 would cause the library to close the +// connection right away. +func Pings(interval, maxOut int) Option { + return func(o *Options) error { + // For tests, we may pass negative value that will be interpreted + // by the library as milliseconds. If this test boolean is set, + // do not check values. + if !testAllowMillisecInPings { + if interval < 1 || maxOut <= 2 { + return fmt.Errorf("Invalid ping values: interval=%v (min>0) maxOut=%v (min=2)", interval, maxOut) + } + } + o.PingIterval = interval + o.PingMaxOut = maxOut + return nil + } +} + +// SetConnectionLostHandler is an Option to set the connection lost handler. +// This callback will be invoked should the client permanently lose +// contact with the server (or another client replaces it while being +// disconnected). The callback will not be invoked on normal Conn.Close(). +func SetConnectionLostHandler(handler ConnectionLostHandler) Option { + return func(o *Options) error { + o.ConnectionLostCB = handler + return nil + } +} + +// A conn represents a bare connection to a stan cluster. +type conn struct { + sync.RWMutex + clientID string + connID []byte // This is a NUID that uniquely identify connections. + pubPrefix string // Publish prefix set by stan, append our subject. + subRequests string // Subject to send subscription requests. + unsubRequests string // Subject to send unsubscribe requests. + subCloseRequests string // Subject to send subscription close requests. + closeRequests string // Subject to send close requests. + ackSubject string // publish acks + ackSubscription *nats.Subscription + hbSubscription *nats.Subscription + subMap map[string]*subscription + pubAckMap map[string]*ack + pubAckChan chan (struct{}) + opts Options + nc *nats.Conn + ncOwned bool // NATS Streaming created the connection, so needs to close it. 
+ pubNUID *nuid.NUID // NUID generator for published messages. + connLostCB ConnectionLostHandler + + pingMu sync.Mutex + pingSub *nats.Subscription + pingTimer *time.Timer + pingBytes []byte + pingRequests string + pingInbox string + pingInterval time.Duration + pingMaxOut int + pingOut int +} + +// Closure for ack contexts. +type ack struct { + t *time.Timer + ah AckHandler + ch chan error +} + +// Connect will form a connection to the NATS Streaming subsystem. +// Note that clientID can contain only alphanumeric and `-` or `_` characters. +func Connect(stanClusterID, clientID string, options ...Option) (Conn, error) { + // Process Options + c := conn{clientID: clientID, opts: DefaultOptions, connID: []byte(nuid.Next()), pubNUID: nuid.New()} + for _, opt := range options { + if err := opt(&c.opts); err != nil { + return nil, err + } + } + // Check if the user has provided a connection as an option + c.nc = c.opts.NatsConn + // Create a NATS connection if it doesn't exist. + if c.nc == nil { + // We will set the max reconnect attempts to -1 (infinite) + // and the reconnect buffer to -1 to prevent any buffering + // (which may cause a published message to be flushed on + // reconnect while the API may have returned an error due + // to PubAck timeout. + nc, err := nats.Connect(c.opts.NatsURL, + nats.Name(clientID), + nats.MaxReconnects(-1), + nats.ReconnectBufSize(-1)) + if err != nil { + return nil, err + } + c.nc = nc + c.ncOwned = true + } else if !c.nc.IsConnected() { + // Bail if the custom NATS connection is disconnected + return nil, ErrBadConnection + } + + // Create a heartbeat inbox + hbInbox := nats.NewInbox() + var err error + if c.hbSubscription, err = c.nc.Subscribe(hbInbox, c.processHeartBeat); err != nil { + c.Close() + return nil, err + } + + // Prepare a subscription on ping responses, even if we are not + // going to need it, so that if that fails, it fails before initiating + // a connection. 
+ pingSub, err := c.nc.Subscribe(nats.NewInbox(), c.processPingResponse) + if err != nil { + c.Close() + return nil, err + } + + // Send Request to discover the cluster + discoverSubject := c.opts.DiscoverPrefix + "." + stanClusterID + req := &pb.ConnectRequest{ + ClientID: clientID, + HeartbeatInbox: hbInbox, + ConnID: c.connID, + Protocol: protocolOne, + PingInterval: int32(c.opts.PingIterval), + PingMaxOut: int32(c.opts.PingMaxOut), + } + b, _ := req.Marshal() + reply, err := c.nc.Request(discoverSubject, b, c.opts.ConnectTimeout) + if err != nil { + c.Close() + if err == nats.ErrTimeout { + return nil, ErrConnectReqTimeout + } + return nil, err + } + // Process the response, grab server pubPrefix + cr := &pb.ConnectResponse{} + err = cr.Unmarshal(reply.Data) + if err != nil { + c.Close() + return nil, err + } + if cr.Error != "" { + c.Close() + return nil, errors.New(cr.Error) + } + + // Capture cluster configuration endpoints to publish and subscribe/unsubscribe. + c.pubPrefix = cr.PubPrefix + c.subRequests = cr.SubRequests + c.unsubRequests = cr.UnsubRequests + c.subCloseRequests = cr.SubCloseRequests + c.closeRequests = cr.CloseRequests + + // Setup the ACK subscription + c.ackSubject = DefaultACKPrefix + "." + nuid.Next() + if c.ackSubscription, err = c.nc.Subscribe(c.ackSubject, c.processAck); err != nil { + c.Close() + return nil, err + } + c.ackSubscription.SetPendingLimits(1024*1024, 32*1024*1024) + c.pubAckMap = make(map[string]*ack) + + // Create Subscription map + c.subMap = make(map[string]*subscription) + + c.pubAckChan = make(chan struct{}, c.opts.MaxPubAcksInflight) + + // Capture the connection error cb + c.connLostCB = c.opts.ConnectionLostCB + + unsubPingSub := true + // Do this with servers which are at least at protcolOne. 
+ if cr.Protocol >= protocolOne { + // Note that in the future server may override client ping + // interval value sent in ConnectRequest, so use the + // value in ConnectResponse to decide if we send PINGs + // and at what interval. + // In tests, the interval could be negative to indicate + // milliseconds. + if cr.PingInterval != 0 { + unsubPingSub = false + + // These will be immutable. + c.pingRequests = cr.PingRequests + c.pingInbox = pingSub.Subject + // In test, it is possible that we get a negative value + // to represent milliseconds. + if testAllowMillisecInPings && cr.PingInterval < 0 { + c.pingInterval = time.Duration(cr.PingInterval*-1) * time.Millisecond + } else { + // PingInterval is otherwise assumed to be in seconds. + c.pingInterval = time.Duration(cr.PingInterval) * time.Second + } + c.pingMaxOut = int(cr.PingMaxOut) + c.pingBytes, _ = (&pb.Ping{ConnID: c.connID}).Marshal() + c.pingSub = pingSub + // Set the timer now that we are set. Use lock to create + // synchronization point. + c.pingMu.Lock() + c.pingTimer = time.AfterFunc(c.pingInterval, c.pingServer) + c.pingMu.Unlock() + } + } + if unsubPingSub { + pingSub.Unsubscribe() + } + + // Attach a finalizer + runtime.SetFinalizer(&c, func(sc *conn) { sc.Close() }) + + return &c, nil +} + +// Sends a PING (containing the connection's ID) to the server at intervals +// specified by PingInterval option when connection is created. +// Everytime a PING is sent, the number of outstanding PINGs is increased. +// If the total number is > than the PingMaxOut option, then the connection +// is closed, and connection error callback invoked if one was specified. +func (sc *conn) pingServer() { + sc.pingMu.Lock() + // In case the timer fired while we were stopping it. 
+ if sc.pingTimer == nil { + sc.pingMu.Unlock() + return + } + sc.pingOut++ + if sc.pingOut > sc.pingMaxOut { + sc.pingMu.Unlock() + sc.closeDueToPing(ErrMaxPings) + return + } + sc.pingTimer.Reset(sc.pingInterval) + nc := sc.nc + sc.pingMu.Unlock() + // Send the PING now. If the NATS connection is reported closed, + // we are done. + if err := nc.PublishRequest(sc.pingRequests, sc.pingInbox, sc.pingBytes); err == nats.ErrConnectionClosed { + sc.closeDueToPing(err) + } +} + +// Receives PING responses from the server. +// If the response contains an error message, the connection is closed +// and the connection error callback is invoked (if one is specified). +// If no error, the number of ping out is reset to 0. There is no +// decrement by one since for a given PING, the client may received +// many responses when servers are running in channel partitioning mode. +// Regardless, any positive response from the server ought to signal +// that the connection is ok. +func (sc *conn) processPingResponse(m *nats.Msg) { + // No data means OK (we don't have to call Unmarshal) + if len(m.Data) > 0 { + pingResp := &pb.PingResponse{} + if err := pingResp.Unmarshal(m.Data); err != nil { + return + } + if pingResp.Error != "" { + sc.closeDueToPing(errors.New(pingResp.Error)) + return + } + } + // Do not attempt to decrement, simply reset to 0. + sc.pingMu.Lock() + sc.pingOut = 0 + sc.pingMu.Unlock() +} + +// Closes a connection and invoke the connection error callback if one +// was registered when the connection was created. +func (sc *conn) closeDueToPing(err error) { + sc.Lock() + if sc.nc == nil { + sc.Unlock() + return + } + // Stop timer, unsubscribe, fail the pubs, etc.. + sc.cleanupOnClose(err) + // No need to send Close prototol, so simply close the underlying + // NATS connection (if we own it, and if not already closed) + if sc.ncOwned && !sc.nc.IsClosed() { + sc.nc.Close() + } + // Mark this streaming connection as closed. Do this under pingMu lock. 
+ sc.pingMu.Lock() + sc.nc = nil + sc.pingMu.Unlock() + // Capture callback (even though this is immutable). + cb := sc.connLostCB + sc.Unlock() + if cb != nil { + // Execute in separate go routine. + go cb(sc, err) + } +} + +// Do some cleanup when connection is lost or closed. +// Connection lock is held on entry, and sc.nc is guaranteed not to be nil. +func (sc *conn) cleanupOnClose(err error) { + sc.pingMu.Lock() + if sc.pingTimer != nil { + sc.pingTimer.Stop() + sc.pingTimer = nil + } + sc.pingMu.Unlock() + + // Unsubscribe only if the NATS connection is not already closed... + if !sc.nc.IsClosed() { + if sc.ackSubscription != nil { + sc.ackSubscription.Unsubscribe() + } + if sc.pingSub != nil { + sc.pingSub.Unsubscribe() + } + } + // Fail all pending pubs + for guid, pubAck := range sc.pubAckMap { + if pubAck.t != nil { + pubAck.t.Stop() + } + if pubAck.ah != nil { + pubAck.ah(guid, err) + } else if pubAck.ch != nil { + pubAck.ch <- err + } + delete(sc.pubAckMap, guid) + if len(sc.pubAckChan) > 0 { + <-sc.pubAckChan + } + } +} + +// Close a connection to the stan system. +func (sc *conn) Close() error { + sc.Lock() + defer sc.Unlock() + + if sc.nc == nil { + // We are already closed. + return nil + } + + // Capture for NATS calls below. + nc := sc.nc + if sc.ncOwned { + defer nc.Close() + } + + // Now close ourselves. + sc.cleanupOnClose(ErrConnectionClosed) + + // Signals we are closed. + // Do this also under pingMu lock so that we don't need + // to grab sc's lock in pingServer. 
+ sc.pingMu.Lock() + sc.nc = nil + sc.pingMu.Unlock() + + req := &pb.CloseRequest{ClientID: sc.clientID} + b, _ := req.Marshal() + reply, err := nc.Request(sc.closeRequests, b, sc.opts.ConnectTimeout) + if err != nil { + if err == nats.ErrTimeout { + return ErrCloseReqTimeout + } + return err + } + cr := &pb.CloseResponse{} + err = cr.Unmarshal(reply.Data) + if err != nil { + return err + } + if cr.Error != "" { + return errors.New(cr.Error) + } + return nil +} + +// NatsConn returns the underlying NATS conn. Use this with care. For example, +// closing the wrapped NATS conn will put the NATS Streaming Conn in an invalid +// state. +func (sc *conn) NatsConn() *nats.Conn { + sc.RLock() + nc := sc.nc + sc.RUnlock() + return nc +} + +// Process a heartbeat from the NATS Streaming cluster +func (sc *conn) processHeartBeat(m *nats.Msg) { + // No payload assumed, just reply. + sc.RLock() + nc := sc.nc + sc.RUnlock() + if nc != nil { + nc.Publish(m.Reply, nil) + } +} + +// Process an ack from the NATS Streaming cluster +func (sc *conn) processAck(m *nats.Msg) { + pa := &pb.PubAck{} + err := pa.Unmarshal(m.Data) + if err != nil { + panic(fmt.Errorf("Error during ack unmarshal: %v", err)) + } + + // Remove + a := sc.removeAck(pa.Guid) + if a != nil { + // Capture error if it exists. + if pa.Error != "" { + err = errors.New(pa.Error) + } + if a.ah != nil { + // Perform the ackHandler callback + a.ah(pa.Guid, err) + } else if a.ch != nil { + // Send to channel directly + a.ch <- err + } + } +} + +// Publish will publish to the cluster and wait for an ACK. +func (sc *conn) Publish(subject string, data []byte) error { + // Need to make this a buffered channel of 1 in case + // a publish call is blocked in pubAckChan but cleanupOnClose() + // is trying to push the error to this channel. 
+ ch := make(chan error, 1) + _, err := sc.publishAsync(subject, data, nil, ch) + if err == nil { + err = <-ch + } + return err +} + +// PublishAsync will publish to the cluster on pubPrefix+subject and asynchronously +// process the ACK or error state. It will return the GUID for the message being sent. +func (sc *conn) PublishAsync(subject string, data []byte, ah AckHandler) (string, error) { + return sc.publishAsync(subject, data, ah, nil) +} + +func (sc *conn) publishAsync(subject string, data []byte, ah AckHandler, ch chan error) (string, error) { + a := &ack{ah: ah, ch: ch} + sc.Lock() + if sc.nc == nil { + sc.Unlock() + return "", ErrConnectionClosed + } + + subj := sc.pubPrefix + "." + subject + // This is only what we need from PubMsg in the timer below, + // so do this so that pe doesn't escape. + peGUID := sc.pubNUID.Next() + // We send connID regardless of server we connect to. Older server + // will simply not decode it. + pe := &pb.PubMsg{ClientID: sc.clientID, Guid: peGUID, Subject: subject, Data: data, ConnID: sc.connID} + b, _ := pe.Marshal() + + // Map ack to guid. + sc.pubAckMap[peGUID] = a + // snapshot + ackSubject := sc.ackSubject + ackTimeout := sc.opts.AckTimeout + pac := sc.pubAckChan + nc := sc.nc + sc.Unlock() + + // Use the buffered channel to control the number of outstanding acks. + pac <- struct{}{} + + err := nc.PublishRequest(subj, ackSubject, b) + + // Setup the timer for expiration. + sc.Lock() + if err != nil || sc.nc == nil { + sc.Unlock() + // If we got and error on publish or the connection has been closed, + // we need to return an error only if: + // - we can remove the pubAck from the map + // - we can't, but this is an async pub with no provided AckHandler + removed := sc.removeAck(peGUID) != nil + if removed || (ch == nil && ah == nil) { + if err == nil { + err = ErrConnectionClosed + } + return "", err + } + // pubAck was removed from cleanupOnClose() and error will be sent + // to appropriate go channel (ah or ch). 
+ return peGUID, nil + } + a.t = time.AfterFunc(ackTimeout, func() { + pubAck := sc.removeAck(peGUID) + // processAck could get here before and handle the ack. + // If that's the case, we would get nil here and simply return. + if pubAck == nil { + return + } + if pubAck.ah != nil { + pubAck.ah(peGUID, ErrTimeout) + } else if a.ch != nil { + pubAck.ch <- ErrTimeout + } + }) + sc.Unlock() + + return peGUID, nil +} + +// removeAck removes the ack from the pubAckMap and cancels any state, e.g. timers +func (sc *conn) removeAck(guid string) *ack { + var t *time.Timer + sc.Lock() + a := sc.pubAckMap[guid] + if a != nil { + t = a.t + delete(sc.pubAckMap, guid) + } + pac := sc.pubAckChan + sc.Unlock() + + // Cancel timer if needed. + if t != nil { + t.Stop() + } + + // Remove from channel to unblock PublishAsync + if a != nil && len(pac) > 0 { + <-pac + } + return a +} + +// Process an msg from the NATS Streaming cluster +func (sc *conn) processMsg(raw *nats.Msg) { + msg := &Msg{} + err := msg.Unmarshal(raw.Data) + if err != nil { + panic(fmt.Errorf("Error processing unmarshal for msg: %v", err)) + } + // Lookup the subscription + sc.RLock() + nc := sc.nc + isClosed := nc == nil + sub := sc.subMap[raw.Subject] + sc.RUnlock() + + // Check if sub is no longer valid or connection has been closed. + if sub == nil || isClosed { + return + } + + // Store in msg for backlink + msg.Sub = sub + + sub.RLock() + cb := sub.cb + ackSubject := sub.ackInbox + isManualAck := sub.opts.ManualAcks + subsc := sub.sc // Can be nil if sub has been unsubscribed. + sub.RUnlock() + + // Perform the callback + if cb != nil && subsc != nil { + cb(msg) + } + + // Process auto-ack + if !isManualAck && nc != nil { + ack := &pb.Ack{Subject: msg.Subject, Sequence: msg.Sequence} + b, _ := ack.Marshal() + // FIXME(dlc) - Async error handler? Retry? 
+ nc.Publish(ackSubject, b) + } +} diff --git a/vendor/github.com/nats-io/go-nats-streaming/sub.go b/vendor/github.com/nats-io/go-nats-streaming/sub.go new file mode 100644 index 00000000000..f35763a55a1 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats-streaming/sub.go @@ -0,0 +1,472 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package stan is a Go client for the NATS Streaming messaging system (https://nats.io). +package stan + +import ( + "errors" + "sync" + "time" + + "github.com/nats-io/go-nats" + "github.com/nats-io/go-nats-streaming/pb" +) + +const ( + // DefaultAckWait indicates how long the server should wait for an ACK before resending a message + DefaultAckWait = 30 * time.Second + // DefaultMaxInflight indicates how many messages with outstanding ACKs the server can send + DefaultMaxInflight = 1024 +) + +// Msg is the client defined message, which includes proto, then back link to subscription. +type Msg struct { + pb.MsgProto // MsgProto: Seq, Subject, Reply[opt], Data, Timestamp, CRC32[opt] + Sub Subscription +} + +// Subscriptions and Options + +// Subscription represents a subscription within the NATS Streaming cluster. Subscriptions +// will be rate matched and follow at-least delivery semantics. 
+type Subscription interface { + ClearMaxPending() error + Delivered() (int64, error) + Dropped() (int, error) + IsValid() bool + MaxPending() (int, int, error) + Pending() (int, int, error) + PendingLimits() (int, int, error) + SetPendingLimits(msgLimit, bytesLimit int) error + // Unsubscribe removes interest in the subscription. + // For durables, it means that the durable interest is also removed from + // the server. Restarting a durable with the same name will not resume + // the subscription, it will be considered a new one. + Unsubscribe() error + + // Close removes this subscriber from the server, but unlike Unsubscribe(), + // the durable interest is not removed. If the client has connected to a server + // for which this feature is not available, Close() will return a ErrNoServerSupport + // error. + Close() error +} + +// A subscription represents a subscription to a stan cluster. +type subscription struct { + sync.RWMutex + sc *conn + subject string + qgroup string + inbox string + ackInbox string + inboxSub *nats.Subscription + opts SubscriptionOptions + cb MsgHandler +} + +// SubscriptionOption is a function on the options for a subscription. +type SubscriptionOption func(*SubscriptionOptions) error + +// MsgHandler is a callback function that processes messages delivered to +// asynchronous subscribers. +type MsgHandler func(msg *Msg) + +// SubscriptionOptions are used to control the Subscription's behavior. +type SubscriptionOptions struct { + // DurableName, if set will survive client restarts. + DurableName string + // Controls the number of messages the cluster will have inflight without an ACK. + MaxInflight int + // Controls the time the cluster will wait for an ACK for a given message. + AckWait time.Duration + // StartPosition enum from proto. + StartAt pb.StartPosition + // Optional start sequence number. + StartSequence uint64 + // Optional start time. 
+ StartTime time.Time + // Option to do Manual Acks + ManualAcks bool +} + +// DefaultSubscriptionOptions are the default subscriptions' options +var DefaultSubscriptionOptions = SubscriptionOptions{ + MaxInflight: DefaultMaxInflight, + AckWait: DefaultAckWait, +} + +// MaxInflight is an Option to set the maximum number of messages the cluster will send +// without an ACK. +func MaxInflight(m int) SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.MaxInflight = m + return nil + } +} + +// AckWait is an Option to set the timeout for waiting for an ACK from the cluster's +// point of view for delivered messages. +func AckWait(t time.Duration) SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.AckWait = t + return nil + } +} + +// StartAt sets the desired start position for the message stream. +func StartAt(sp pb.StartPosition) SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.StartAt = sp + return nil + } +} + +// StartAtSequence sets the desired start sequence position and state. +func StartAtSequence(seq uint64) SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.StartAt = pb.StartPosition_SequenceStart + o.StartSequence = seq + return nil + } +} + +// StartAtTime sets the desired start time position and state. +func StartAtTime(start time.Time) SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.StartAt = pb.StartPosition_TimeDeltaStart + o.StartTime = start + return nil + } +} + +// StartAtTimeDelta sets the desired start time position and state using the delta. +func StartAtTimeDelta(ago time.Duration) SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.StartAt = pb.StartPosition_TimeDeltaStart + o.StartTime = time.Now().Add(-ago) + return nil + } +} + +// StartWithLastReceived is a helper function to set start position to last received. 
+func StartWithLastReceived() SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.StartAt = pb.StartPosition_LastReceived + return nil + } +} + +// DeliverAllAvailable will deliver all messages available. +func DeliverAllAvailable() SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.StartAt = pb.StartPosition_First + return nil + } +} + +// SetManualAckMode will allow clients to control their own acks to delivered messages. +func SetManualAckMode() SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.ManualAcks = true + return nil + } +} + +// DurableName sets the DurableName for the subcriber. +func DurableName(name string) SubscriptionOption { + return func(o *SubscriptionOptions) error { + o.DurableName = name + return nil + } +} + +// Subscribe will perform a subscription with the given options to the NATS Streaming cluster. +func (sc *conn) Subscribe(subject string, cb MsgHandler, options ...SubscriptionOption) (Subscription, error) { + return sc.subscribe(subject, "", cb, options...) +} + +// QueueSubscribe will perform a queue subscription with the given options to the NATS Streaming cluster. +func (sc *conn) QueueSubscribe(subject, qgroup string, cb MsgHandler, options ...SubscriptionOption) (Subscription, error) { + return sc.subscribe(subject, qgroup, cb, options...) +} + +// subscribe will perform a subscription with the given options to the NATS Streaming cluster. +func (sc *conn) subscribe(subject, qgroup string, cb MsgHandler, options ...SubscriptionOption) (Subscription, error) { + sub := &subscription{subject: subject, qgroup: qgroup, inbox: nats.NewInbox(), cb: cb, sc: sc, opts: DefaultSubscriptionOptions} + for _, opt := range options { + if err := opt(&sub.opts); err != nil { + return nil, err + } + } + sc.Lock() + if sc.nc == nil { + sc.Unlock() + return nil, ErrConnectionClosed + } + + // Register subscription. 
+ sc.subMap[sub.inbox] = sub + nc := sc.nc + sc.Unlock() + + // Hold lock throughout. + sub.Lock() + defer sub.Unlock() + + // Listen for actual messages. + nsub, err := nc.Subscribe(sub.inbox, sc.processMsg) + if err != nil { + return nil, err + } + sub.inboxSub = nsub + + // Create a subscription request + // FIXME(dlc) add others. + sr := &pb.SubscriptionRequest{ + ClientID: sc.clientID, + Subject: subject, + QGroup: qgroup, + Inbox: sub.inbox, + MaxInFlight: int32(sub.opts.MaxInflight), + AckWaitInSecs: int32(sub.opts.AckWait / time.Second), + StartPosition: sub.opts.StartAt, + DurableName: sub.opts.DurableName, + } + + // Conditionals + switch sr.StartPosition { + case pb.StartPosition_TimeDeltaStart: + sr.StartTimeDelta = time.Now().UnixNano() - sub.opts.StartTime.UnixNano() + case pb.StartPosition_SequenceStart: + sr.StartSequence = sub.opts.StartSequence + } + + b, _ := sr.Marshal() + reply, err := sc.nc.Request(sc.subRequests, b, sc.opts.ConnectTimeout) + if err != nil { + sub.inboxSub.Unsubscribe() + if err == nats.ErrTimeout { + err = ErrSubReqTimeout + } + return nil, err + } + r := &pb.SubscriptionResponse{} + if err := r.Unmarshal(reply.Data); err != nil { + sub.inboxSub.Unsubscribe() + return nil, err + } + if r.Error != "" { + sub.inboxSub.Unsubscribe() + return nil, errors.New(r.Error) + } + sub.ackInbox = r.AckInbox + + return sub, nil +} + +// ClearMaxPending resets the maximums seen so far. +func (sub *subscription) ClearMaxPending() error { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return ErrBadSubscription + } + return sub.inboxSub.ClearMaxPending() +} + +// Delivered returns the number of delivered messages for this subscription. +func (sub *subscription) Delivered() (int64, error) { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return -1, ErrBadSubscription + } + return sub.inboxSub.Delivered() +} + +// Dropped returns the number of known dropped messages for this subscription. 
+// This will correspond to messages dropped by violations of PendingLimits. If +// the server declares the connection a SlowConsumer, this number may not be +// valid. +func (sub *subscription) Dropped() (int, error) { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return -1, ErrBadSubscription + } + return sub.inboxSub.Dropped() +} + +// IsValid returns a boolean indicating whether the subscription +// is still active. This will return false if the subscription has +// already been closed. +func (sub *subscription) IsValid() bool { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return false + } + return sub.inboxSub.IsValid() +} + +// MaxPending returns the maximum number of queued messages and queued bytes seen so far. +func (sub *subscription) MaxPending() (int, int, error) { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return -1, -1, ErrBadSubscription + } + return sub.inboxSub.MaxPending() +} + +// Pending returns the number of queued messages and queued bytes in the client for this subscription. +func (sub *subscription) Pending() (int, int, error) { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return -1, -1, ErrBadSubscription + } + return sub.inboxSub.Pending() +} + +// PendingLimits returns the current limits for this subscription. +// If no error is returned, a negative value indicates that the +// given metric is not limited. +func (sub *subscription) PendingLimits() (int, int, error) { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return -1, -1, ErrBadSubscription + } + return sub.inboxSub.PendingLimits() +} + +// SetPendingLimits sets the limits for pending msgs and bytes for this subscription. +// Zero is not allowed. Any negative value means that the given metric is not limited. 
+func (sub *subscription) SetPendingLimits(msgLimit, bytesLimit int) error { + sub.Lock() + defer sub.Unlock() + if sub.inboxSub == nil { + return ErrBadSubscription + } + return sub.inboxSub.SetPendingLimits(msgLimit, bytesLimit) +} + +// closeOrUnsubscribe performs either close or unsubsribe based on +// given boolean. +func (sub *subscription) closeOrUnsubscribe(doClose bool) error { + sub.Lock() + sc := sub.sc + if sc == nil { + // Already closed. + sub.Unlock() + return ErrBadSubscription + } + sub.sc = nil + sub.inboxSub.Unsubscribe() + sub.inboxSub = nil + sub.Unlock() + + sc.Lock() + if sc.nc == nil { + sc.Unlock() + return ErrConnectionClosed + } + + delete(sc.subMap, sub.inbox) + reqSubject := sc.unsubRequests + if doClose { + reqSubject = sc.subCloseRequests + if reqSubject == "" { + sc.Unlock() + return ErrNoServerSupport + } + } + + // Snapshot connection to avoid data race, since the connection may be + // closing while we try to send the request + nc := sc.nc + sc.Unlock() + + usr := &pb.UnsubscribeRequest{ + ClientID: sc.clientID, + Subject: sub.subject, + Inbox: sub.ackInbox, + } + b, _ := usr.Marshal() + reply, err := nc.Request(reqSubject, b, sc.opts.ConnectTimeout) + if err != nil { + if err == nats.ErrTimeout { + if doClose { + return ErrCloseReqTimeout + } + return ErrUnsubReqTimeout + } + return err + } + r := &pb.SubscriptionResponse{} + if err := r.Unmarshal(reply.Data); err != nil { + return err + } + if r.Error != "" { + return errors.New(r.Error) + } + + return nil +} + +// Unsubscribe implements the Subscription interface +func (sub *subscription) Unsubscribe() error { + return sub.closeOrUnsubscribe(false) +} + +// Close implements the Subscription interface +func (sub *subscription) Close() error { + return sub.closeOrUnsubscribe(true) +} + +// Ack manually acknowledges a message. +// The subscriber had to be created with SetManualAckMode() option. 
+func (msg *Msg) Ack() error { + if msg == nil { + return ErrNilMsg + } + // Look up subscription (cannot be nil) + sub := msg.Sub.(*subscription) + sub.RLock() + ackSubject := sub.ackInbox + isManualAck := sub.opts.ManualAcks + sc := sub.sc + sub.RUnlock() + + // Check for error conditions. + if !isManualAck { + return ErrManualAck + } + if sc == nil { + return ErrBadSubscription + } + // Get nc from the connection (needs locking to avoid race) + sc.RLock() + nc := sc.nc + sc.RUnlock() + if nc == nil { + return ErrBadConnection + } + + // Ack here. + ack := &pb.Ack{Subject: msg.Subject, Sequence: msg.Sequence} + b, _ := ack.Marshal() + return nc.Publish(ackSubject, b) +} diff --git a/vendor/github.com/nats-io/go-nats/LICENSE b/vendor/github.com/nats-io/go-nats/LICENSE new file mode 100644 index 00000000000..261eeb9e9f8 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/nats-io/go-nats/context.go b/vendor/github.com/nats-io/go-nats/context.go new file mode 100644 index 00000000000..3d753564883 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/context.go @@ -0,0 +1,184 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.7 + +// A Go client for the NATS messaging system (https://nats.io). +package nats + +import ( + "context" + "fmt" + "reflect" +) + +// RequestWithContext takes a context, a subject and payload +// in bytes and request expecting a single response. +func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) { + if ctx == nil { + return nil, ErrInvalidContext + } + if nc == nil { + return nil, ErrInvalidConnection + } + // Check whether the context is done already before making + // the request. 
+ if ctx.Err() != nil { + return nil, ctx.Err() + } + + nc.mu.Lock() + // If user wants the old style. + if nc.Opts.UseOldRequestStyle { + nc.mu.Unlock() + return nc.oldRequestWithContext(ctx, subj, data) + } + + // Do setup for the new style. + if nc.respMap == nil { + // _INBOX wildcard + nc.respSub = fmt.Sprintf("%s.*", NewInbox()) + nc.respMap = make(map[string]chan *Msg) + } + // Create literal Inbox and map to a chan msg. + mch := make(chan *Msg, RequestChanLen) + respInbox := nc.newRespInbox() + token := respToken(respInbox) + nc.respMap[token] = mch + createSub := nc.respMux == nil + ginbox := nc.respSub + nc.mu.Unlock() + + if createSub { + // Make sure scoped subscription is setup only once. + var err error + nc.respSetup.Do(func() { err = nc.createRespMux(ginbox) }) + if err != nil { + return nil, err + } + } + + err := nc.PublishRequest(subj, respInbox, data) + if err != nil { + return nil, err + } + + var ok bool + var msg *Msg + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + case <-ctx.Done(): + nc.mu.Lock() + delete(nc.respMap, token) + nc.mu.Unlock() + return nil, ctx.Err() + } + + return msg, nil +} + +// oldRequestWithContext utilizes inbox and subscription per request. +func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) { + inbox := NewInbox() + ch := make(chan *Msg, RequestChanLen) + + s, err := nc.subscribe(inbox, _EMPTY_, nil, ch) + if err != nil { + return nil, err + } + s.AutoUnsubscribe(1) + defer s.Unsubscribe() + + err = nc.PublishRequest(subj, inbox, data) + if err != nil { + return nil, err + } + + return s.NextMsgWithContext(ctx) +} + +// NextMsgWithContext takes a context and returns the next message +// available to a synchronous subscriber, blocking until it is delivered +// or context gets canceled. 
+func (s *Subscription) NextMsgWithContext(ctx context.Context) (*Msg, error) { + if ctx == nil { + return nil, ErrInvalidContext + } + if s == nil { + return nil, ErrBadSubscription + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mu.Lock() + err := s.validateNextMsgState() + if err != nil { + s.mu.Unlock() + return nil, err + } + + mch := s.mch + s.mu.Unlock() + + var ok bool + var msg *Msg + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + err := s.processNextMsgDelivered(msg) + if err != nil { + return nil, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + + return msg, nil +} + +// RequestWithContext will create an Inbox and perform a Request +// using the provided cancellation context with the Inbox reply +// for the data v. A response will be decoded into the vPtrResponse. +func (c *EncodedConn) RequestWithContext(ctx context.Context, subject string, v interface{}, vPtr interface{}) error { + if ctx == nil { + return ErrInvalidContext + } + + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + m, err := c.Conn.RequestWithContext(ctx, subject, b) + if err != nil { + return err + } + if reflect.TypeOf(vPtr) == emptyMsgType { + mPtr := vPtr.(*Msg) + *mPtr = *m + } else { + err := c.Enc.Decode(m.Subject, m.Data, vPtr) + if err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/nats-io/go-nats/enc.go b/vendor/github.com/nats-io/go-nats/enc.go new file mode 100644 index 00000000000..0f06acc1d50 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/enc.go @@ -0,0 +1,269 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nats + +import ( + "errors" + "fmt" + "reflect" + "sync" + "time" + + // Default Encoders + . "github.com/nats-io/go-nats/encoders/builtin" +) + +// Encoder interface is for all register encoders +type Encoder interface { + Encode(subject string, v interface{}) ([]byte, error) + Decode(subject string, data []byte, vPtr interface{}) error +} + +var encMap map[string]Encoder +var encLock sync.Mutex + +// Indexe names into the Registered Encoders. +const ( + JSON_ENCODER = "json" + GOB_ENCODER = "gob" + DEFAULT_ENCODER = "default" +) + +func init() { + encMap = make(map[string]Encoder) + // Register json, gob and default encoder + RegisterEncoder(JSON_ENCODER, &JsonEncoder{}) + RegisterEncoder(GOB_ENCODER, &GobEncoder{}) + RegisterEncoder(DEFAULT_ENCODER, &DefaultEncoder{}) +} + +// EncodedConn are the preferred way to interface with NATS. They wrap a bare connection to +// a nats server and have an extendable encoder system that will encode and decode messages +// from raw Go types. +type EncodedConn struct { + Conn *Conn + Enc Encoder +} + +// NewEncodedConn will wrap an existing Connection and utilize the appropriate registered +// encoder. 
+func NewEncodedConn(c *Conn, encType string) (*EncodedConn, error) { + if c == nil { + return nil, errors.New("nats: Nil Connection") + } + if c.IsClosed() { + return nil, ErrConnectionClosed + } + ec := &EncodedConn{Conn: c, Enc: EncoderForType(encType)} + if ec.Enc == nil { + return nil, fmt.Errorf("No encoder registered for '%s'", encType) + } + return ec, nil +} + +// RegisterEncoder will register the encType with the given Encoder. Useful for customization. +func RegisterEncoder(encType string, enc Encoder) { + encLock.Lock() + defer encLock.Unlock() + encMap[encType] = enc +} + +// EncoderForType will return the registered Encoder for the encType. +func EncoderForType(encType string) Encoder { + encLock.Lock() + defer encLock.Unlock() + return encMap[encType] +} + +// Publish publishes the data argument to the given subject. The data argument +// will be encoded using the associated encoder. +func (c *EncodedConn) Publish(subject string, v interface{}) error { + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + return c.Conn.publish(subject, _EMPTY_, b) +} + +// PublishRequest will perform a Publish() expecting a response on the +// reply subject. Use Request() for automatically waiting for a response +// inline. +func (c *EncodedConn) PublishRequest(subject, reply string, v interface{}) error { + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + return c.Conn.publish(subject, reply, b) +} + +// Request will create an Inbox and perform a Request() call +// with the Inbox reply for the data v. A response will be +// decoded into the vPtrResponse. 
+func (c *EncodedConn) Request(subject string, v interface{}, vPtr interface{}, timeout time.Duration) error { + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + m, err := c.Conn.Request(subject, b, timeout) + if err != nil { + return err + } + if reflect.TypeOf(vPtr) == emptyMsgType { + mPtr := vPtr.(*Msg) + *mPtr = *m + } else { + err = c.Enc.Decode(m.Subject, m.Data, vPtr) + } + return err +} + +// Handler is a specific callback used for Subscribe. It is generalized to +// an interface{}, but we will discover its format and arguments at runtime +// and perform the correct callback, including de-marshaling JSON strings +// back into the appropriate struct based on the signature of the Handler. +// +// Handlers are expected to have one of four signatures. +// +// type person struct { +// Name string `json:"name,omitempty"` +// Age uint `json:"age,omitempty"` +// } +// +// handler := func(m *Msg) +// handler := func(p *person) +// handler := func(subject string, o *obj) +// handler := func(subject, reply string, o *obj) +// +// These forms allow a callback to request a raw Msg ptr, where the processing +// of the message from the wire is untouched. Process a JSON representation +// and demarshal it into the given struct, e.g. person. +// There are also variants where the callback wants either the subject, or the +// subject and the reply subject. +type Handler interface{} + +// Dissect the cb Handler's signature +func argInfo(cb Handler) (reflect.Type, int) { + cbType := reflect.TypeOf(cb) + if cbType.Kind() != reflect.Func { + panic("nats: Handler needs to be a func") + } + numArgs := cbType.NumIn() + if numArgs == 0 { + return nil, numArgs + } + return cbType.In(numArgs - 1), numArgs +} + +var emptyMsgType = reflect.TypeOf(&Msg{}) + +// Subscribe will create a subscription on the given subject and process incoming +// messages using the specified Handler. 
The Handler should be a func that matches +// a signature from the description of Handler from above. +func (c *EncodedConn) Subscribe(subject string, cb Handler) (*Subscription, error) { + return c.subscribe(subject, _EMPTY_, cb) +} + +// QueueSubscribe will create a queue subscription on the given subject and process +// incoming messages using the specified Handler. The Handler should be a func that +// matches a signature from the description of Handler from above. +func (c *EncodedConn) QueueSubscribe(subject, queue string, cb Handler) (*Subscription, error) { + return c.subscribe(subject, queue, cb) +} + +// Internal implementation that all public functions will use. +func (c *EncodedConn) subscribe(subject, queue string, cb Handler) (*Subscription, error) { + if cb == nil { + return nil, errors.New("nats: Handler required for EncodedConn Subscription") + } + argType, numArgs := argInfo(cb) + if argType == nil { + return nil, errors.New("nats: Handler requires at least one argument") + } + + cbValue := reflect.ValueOf(cb) + wantsRaw := (argType == emptyMsgType) + + natsCB := func(m *Msg) { + var oV []reflect.Value + if wantsRaw { + oV = []reflect.Value{reflect.ValueOf(m)} + } else { + var oPtr reflect.Value + if argType.Kind() != reflect.Ptr { + oPtr = reflect.New(argType) + } else { + oPtr = reflect.New(argType.Elem()) + } + if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil { + if c.Conn.Opts.AsyncErrorCB != nil { + c.Conn.ach.push(func() { + c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, errors.New("nats: Got an error trying to unmarshal: "+err.Error())) + }) + } + return + } + if argType.Kind() != reflect.Ptr { + oPtr = reflect.Indirect(oPtr) + } + + // Callback Arity + switch numArgs { + case 1: + oV = []reflect.Value{oPtr} + case 2: + subV := reflect.ValueOf(m.Subject) + oV = []reflect.Value{subV, oPtr} + case 3: + subV := reflect.ValueOf(m.Subject) + replyV := reflect.ValueOf(m.Reply) + oV = []reflect.Value{subV, replyV, oPtr} + } + + } + 
cbValue.Call(oV) + } + + return c.Conn.subscribe(subject, queue, natsCB, nil) +} + +// FlushTimeout allows a Flush operation to have an associated timeout. +func (c *EncodedConn) FlushTimeout(timeout time.Duration) (err error) { + return c.Conn.FlushTimeout(timeout) +} + +// Flush will perform a round trip to the server and return when it +// receives the internal reply. +func (c *EncodedConn) Flush() error { + return c.Conn.Flush() +} + +// Close will close the connection to the server. This call will release +// all blocking calls, such as Flush(), etc. +func (c *EncodedConn) Close() { + c.Conn.Close() +} + +// Drain will put a connection into a drain state. All subscriptions will +// immediately be put into a drain state. Upon completion, the publishers +// will be drained and can not publish any additional messages. Upon draining +// of the publishers, the connection will be closed. Use the ClosedCB() +// option to know when the connection has moved from draining to closed. +func (c *EncodedConn) Drain() error { + return c.Conn.Drain() +} + +// LastError reports the last error encountered via the Connection. +func (c *EncodedConn) LastError() error { + return c.Conn.err +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/builtin/default_enc.go b/vendor/github.com/nats-io/go-nats/encoders/builtin/default_enc.go new file mode 100644 index 00000000000..46d918eea64 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/builtin/default_enc.go @@ -0,0 +1,117 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package builtin + +import ( + "bytes" + "fmt" + "reflect" + "strconv" + "unsafe" +) + +// DefaultEncoder implementation for EncodedConn. +// This encoder will leave []byte and string untouched, but will attempt to +// turn numbers into appropriate strings that can be decoded. It will also +// propely encoded and decode bools. If will encode a struct, but if you want +// to properly handle structures you should use JsonEncoder. +type DefaultEncoder struct { + // Empty +} + +var trueB = []byte("true") +var falseB = []byte("false") +var nilB = []byte("") + +// Encode +func (je *DefaultEncoder) Encode(subject string, v interface{}) ([]byte, error) { + switch arg := v.(type) { + case string: + bytes := *(*[]byte)(unsafe.Pointer(&arg)) + return bytes, nil + case []byte: + return arg, nil + case bool: + if arg { + return trueB, nil + } else { + return falseB, nil + } + case nil: + return nilB, nil + default: + var buf bytes.Buffer + fmt.Fprintf(&buf, "%+v", arg) + return buf.Bytes(), nil + } +} + +// Decode +func (je *DefaultEncoder) Decode(subject string, data []byte, vPtr interface{}) error { + // Figure out what it's pointing to... 
+ sData := *(*string)(unsafe.Pointer(&data)) + switch arg := vPtr.(type) { + case *string: + *arg = sData + return nil + case *[]byte: + *arg = data + return nil + case *int: + n, err := strconv.ParseInt(sData, 10, 64) + if err != nil { + return err + } + *arg = int(n) + return nil + case *int32: + n, err := strconv.ParseInt(sData, 10, 64) + if err != nil { + return err + } + *arg = int32(n) + return nil + case *int64: + n, err := strconv.ParseInt(sData, 10, 64) + if err != nil { + return err + } + *arg = int64(n) + return nil + case *float32: + n, err := strconv.ParseFloat(sData, 32) + if err != nil { + return err + } + *arg = float32(n) + return nil + case *float64: + n, err := strconv.ParseFloat(sData, 64) + if err != nil { + return err + } + *arg = float64(n) + return nil + case *bool: + b, err := strconv.ParseBool(sData) + if err != nil { + return err + } + *arg = b + return nil + default: + vt := reflect.TypeOf(arg).Elem() + return fmt.Errorf("nats: Default Encoder can't decode to type %s", vt) + } +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/builtin/gob_enc.go b/vendor/github.com/nats-io/go-nats/encoders/builtin/gob_enc.go new file mode 100644 index 00000000000..632bcbd395d --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/builtin/gob_enc.go @@ -0,0 +1,45 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package builtin + +import ( + "bytes" + "encoding/gob" +) + +// GobEncoder is a Go specific GOB Encoder implementation for EncodedConn. +// This encoder will use the builtin encoding/gob to Marshal +// and Unmarshal most types, including structs. +type GobEncoder struct { + // Empty +} + +// FIXME(dlc) - This could probably be more efficient. + +// Encode +func (ge *GobEncoder) Encode(subject string, v interface{}) ([]byte, error) { + b := new(bytes.Buffer) + enc := gob.NewEncoder(b) + if err := enc.Encode(v); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// Decode +func (ge *GobEncoder) Decode(subject string, data []byte, vPtr interface{}) (err error) { + dec := gob.NewDecoder(bytes.NewBuffer(data)) + err = dec.Decode(vPtr) + return +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/builtin/json_enc.go b/vendor/github.com/nats-io/go-nats/encoders/builtin/json_enc.go new file mode 100644 index 00000000000..c9670f3131d --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/builtin/json_enc.go @@ -0,0 +1,56 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package builtin + +import ( + "encoding/json" + "strings" +) + +// JsonEncoder is a JSON Encoder implementation for EncodedConn. +// This encoder will use the builtin encoding/json to Marshal +// and Unmarshal most types, including structs. 
+type JsonEncoder struct { + // Empty +} + +// Encode +func (je *JsonEncoder) Encode(subject string, v interface{}) ([]byte, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + return b, nil +} + +// Decode +func (je *JsonEncoder) Decode(subject string, data []byte, vPtr interface{}) (err error) { + switch arg := vPtr.(type) { + case *string: + // If they want a string and it is a JSON string, strip quotes + // This allows someone to send a struct but receive as a plain string + // This cast should be efficient for Go 1.3 and beyond. + str := string(data) + if strings.HasPrefix(str, `"`) && strings.HasSuffix(str, `"`) { + *arg = str[1 : len(str)-1] + } else { + *arg = str + } + case *[]byte: + *arg = data + default: + err = json.Unmarshal(data, arg) + } + return +} diff --git a/vendor/github.com/nats-io/go-nats/nats.go b/vendor/github.com/nats-io/go-nats/nats.go new file mode 100644 index 00000000000..0b33a7aa77a --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/nats.go @@ -0,0 +1,3477 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A Go client for the NATS messaging system (https://nats.io). 
+package nats + +import ( + "bufio" + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net" + "net/url" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/nats-io/go-nats/util" + "github.com/nats-io/nuid" +) + +// Default Constants +const ( + Version = "1.6.0" + DefaultURL = "nats://localhost:4222" + DefaultPort = 4222 + DefaultMaxReconnect = 60 + DefaultReconnectWait = 2 * time.Second + DefaultTimeout = 2 * time.Second + DefaultPingInterval = 2 * time.Minute + DefaultMaxPingOut = 2 + DefaultMaxChanLen = 8192 // 8k + DefaultReconnectBufSize = 8 * 1024 * 1024 // 8MB + RequestChanLen = 8 + DefaultDrainTimeout = 30 * time.Second + LangString = "go" +) + +// STALE_CONNECTION is for detection and proper handling of stale connections. +const STALE_CONNECTION = "stale connection" + +// PERMISSIONS_ERR is for when nats server subject authorization has failed. +const PERMISSIONS_ERR = "permissions violation" + +// AUTHORIZATION_ERR is for when nats server user authorization has failed. 
+const AUTHORIZATION_ERR = "authorization violation" + +// Errors +var ( + ErrConnectionClosed = errors.New("nats: connection closed") + ErrConnectionDraining = errors.New("nats: connection draining") + ErrDrainTimeout = errors.New("nats: draining connection timed out") + ErrConnectionReconnecting = errors.New("nats: connection reconnecting") + ErrSecureConnRequired = errors.New("nats: secure connection required") + ErrSecureConnWanted = errors.New("nats: secure connection not available") + ErrBadSubscription = errors.New("nats: invalid subscription") + ErrTypeSubscription = errors.New("nats: invalid subscription type") + ErrBadSubject = errors.New("nats: invalid subject") + ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped") + ErrTimeout = errors.New("nats: timeout") + ErrBadTimeout = errors.New("nats: timeout invalid") + ErrAuthorization = errors.New("nats: authorization violation") + ErrNoServers = errors.New("nats: no servers available for connection") + ErrJsonParse = errors.New("nats: connect message, json parse error") + ErrChanArg = errors.New("nats: argument needs to be a channel type") + ErrMaxPayload = errors.New("nats: maximum payload exceeded") + ErrMaxMessages = errors.New("nats: maximum messages delivered") + ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription") + ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed") + ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received") + ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded") + ErrInvalidConnection = errors.New("nats: invalid connection") + ErrInvalidMsg = errors.New("nats: invalid message or message nil") + ErrInvalidArg = errors.New("nats: invalid argument") + ErrInvalidContext = errors.New("nats: invalid context") + ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server") + ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION) +) + +// 
GetDefaultOptions returns default configuration options for the client. +func GetDefaultOptions() Options { + return Options{ + AllowReconnect: true, + MaxReconnect: DefaultMaxReconnect, + ReconnectWait: DefaultReconnectWait, + Timeout: DefaultTimeout, + PingInterval: DefaultPingInterval, + MaxPingsOut: DefaultMaxPingOut, + SubChanLen: DefaultMaxChanLen, + ReconnectBufSize: DefaultReconnectBufSize, + DrainTimeout: DefaultDrainTimeout, + } +} + +// DEPRECATED: Use GetDefaultOptions() instead. +// DefaultOptions is not safe for use by multiple clients. +// For details see #308. +var DefaultOptions = GetDefaultOptions() + +// Status represents the state of the connection. +type Status int + +const ( + DISCONNECTED = Status(iota) + CONNECTED + CLOSED + RECONNECTING + CONNECTING + DRAINING_SUBS + DRAINING_PUBS +) + +// ConnHandler is used for asynchronous events such as +// disconnected and closed connections. +type ConnHandler func(*Conn) + +// ErrHandler is used to process asynchronous errors encountered +// while processing inbound messages. +type ErrHandler func(*Conn, *Subscription, error) + +// asyncCB is used to preserve order for async callbacks. +type asyncCB struct { + f func() + next *asyncCB +} + +type asyncCallbacksHandler struct { + mu sync.Mutex + cond *sync.Cond + head *asyncCB + tail *asyncCB +} + +// Option is a function on the options for a connection. +type Option func(*Options) error + +// CustomDialer can be used to specify any dialer, not necessarily +// a *net.Dialer. +type CustomDialer interface { + Dial(network, address string) (net.Conn, error) +} + +// Options can be used to create a customized connection. +type Options struct { + + // Url represents a single NATS server url to which the client + // will be connecting. If the Servers option is also set, it + // then becomes the first server in the Servers array. + Url string + + // Servers is a configured set of servers which this client + // will use when attempting to connect. 
+ Servers []string + + // NoRandomize configures whether we will randomize the + // server pool. + NoRandomize bool + + // NoEcho configures whether the server will echo back messages + // that are sent on this connection if we also have matching subscriptions. + // Note this is supported on servers >= version 1.2. Proto 1 or greater. + NoEcho bool + + // Name is an optional name label which will be sent to the server + // on CONNECT to identify the client. + Name string + + // Verbose signals the server to send an OK ack for commands + // successfully processed by the server. + Verbose bool + + // Pedantic signals the server whether it should be doing further + // validation of subjects. + Pedantic bool + + // Secure enables TLS secure connections that skip server + // verification by default. NOT RECOMMENDED. + Secure bool + + // TLSConfig is a custom TLS configuration to use for secure + // transports. + TLSConfig *tls.Config + + // AllowReconnect enables reconnection logic to be used when we + // encounter a disconnect from the current server. + AllowReconnect bool + + // MaxReconnect sets the number of reconnect attempts that will be + // tried before giving up. If negative, then it will never give up + // trying to reconnect. + MaxReconnect int + + // ReconnectWait sets the time to backoff after attempting a reconnect + // to a server that we were already connected to previously. + ReconnectWait time.Duration + + // Timeout sets the timeout for a Dial operation on a connection. + Timeout time.Duration + + // DrainTimeout sets the timeout for a Drain Operation to complete. + DrainTimeout time.Duration + + // FlusherTimeout is the maximum time to wait for the flusher loop + // to be able to finish writing to the underlying connection. + FlusherTimeout time.Duration + + // PingInterval is the period at which the client will be sending ping + // commands to the server, disabled if 0 or negative. 
+ PingInterval time.Duration + + // MaxPingsOut is the maximum number of pending ping commands that can + // be awaiting a response before raising an ErrStaleConnection error. + MaxPingsOut int + + // ClosedCB sets the closed handler that is called when a client will + // no longer be connected. + ClosedCB ConnHandler + + // DisconnectedCB sets the disconnected handler that is called + // whenever the connection is disconnected. + DisconnectedCB ConnHandler + + // ReconnectedCB sets the reconnected handler called whenever + // the connection is successfully reconnected. + ReconnectedCB ConnHandler + + // DiscoveredServersCB sets the callback that is invoked whenever a new + // server has joined the cluster. + DiscoveredServersCB ConnHandler + + // AsyncErrorCB sets the async error handler (e.g. slow consumer errors) + AsyncErrorCB ErrHandler + + // ReconnectBufSize is the size of the backing bufio during reconnect. + // Once this has been exhausted publish operations will return an error. + ReconnectBufSize int + + // SubChanLen is the size of the buffered channel used between the socket + // Go routine and the message delivery for SyncSubscriptions. + // NOTE: This does not affect AsyncSubscriptions which are + // dictated by PendingLimits() + SubChanLen int + + // User sets the username to be used when connecting to the server. + User string + + // Password sets the password to be used when connecting to a server. + Password string + + // Token sets the token to be used when connecting to a server. + Token string + + // Dialer allows a custom net.Dialer when forming connections. + // DEPRECATED: should use CustomDialer instead. + Dialer *net.Dialer + + // CustomDialer allows to specify a custom dialer (not necessarily + // a *net.Dialer). + CustomDialer CustomDialer + + // UseOldRequestStyle forces the old method of Requests that utilize + // a new Inbox and a new Subscription for each request. 
+ UseOldRequestStyle bool +} + +const ( + // Scratch storage for assembling protocol headers + scratchSize = 512 + + // The size of the bufio reader/writer on top of the socket. + defaultBufSize = 32768 + + // The buffered size of the flush "kick" channel + flushChanSize = 1024 + + // Default server pool size + srvPoolSize = 4 + + // NUID size + nuidSize = 22 + + // Default port used if none is specified in given URL(s) + defaultPortString = "4222" +) + +// A Conn represents a bare connection to a nats-server. +// It can send and receive []byte payloads. +type Conn struct { + // Keep all members for which we use atomic at the beginning of the + // struct and make sure they are all 64bits (or use padding if necessary). + // atomic.* functions crash on 32bit machines if operand is not aligned + // at 64bit. See https://github.com/golang/go/issues/599 + Statistics + mu sync.Mutex + // Opts holds the configuration of the Conn. + // Modifying the configuration of a running Conn is a race. + Opts Options + wg sync.WaitGroup + url *url.URL + conn net.Conn + srvPool []*srv + urls map[string]struct{} // Keep track of all known URLs (used by processInfo) + bw *bufio.Writer + pending *bytes.Buffer + fch chan struct{} + info serverInfo + ssid int64 + subsMu sync.RWMutex + subs map[int64]*Subscription + ach *asyncCallbacksHandler + pongs []chan struct{} + scratch [scratchSize]byte + status Status + initc bool // true if the connection is performing the initial connect + err error + ps *parseState + ptmr *time.Timer + pout int + + // New style response handler + respSub string // The wildcard subject + respMux *Subscription // A single response subscription + respMap map[string]chan *Msg // Request map for the response msg channels + respSetup sync.Once // Ensures response subscription occurs once +} + +// A Subscription represents interest in a given subject. +type Subscription struct { + mu sync.Mutex + sid int64 + + // Subject that represents this subscription. 
This can be different + // than the received subject inside a Msg if this is a wildcard. + Subject string + + // Optional queue group name. If present, all subscriptions with the + // same name will form a distributed queue, and each message will + // only be processed by one member of the group. + Queue string + + delivered uint64 + max uint64 + conn *Conn + mcb MsgHandler + mch chan *Msg + closed bool + sc bool + connClosed bool + + // Type of Subscription + typ SubscriptionType + + // Async linked list + pHead *Msg + pTail *Msg + pCond *sync.Cond + + // Pending stats, async subscriptions, high-speed etc. + pMsgs int + pBytes int + pMsgsMax int + pBytesMax int + pMsgsLimit int + pBytesLimit int + dropped int +} + +// Msg is a structure used by Subscribers and PublishMsg(). +type Msg struct { + Subject string + Reply string + Data []byte + Sub *Subscription + next *Msg + barrier *barrierInfo +} + +type barrierInfo struct { + refs int64 + f func() +} + +// Tracks various stats received and sent on this connection, +// including counts for messages and bytes. +type Statistics struct { + InMsgs uint64 + OutMsgs uint64 + InBytes uint64 + OutBytes uint64 + Reconnects uint64 +} + +// Tracks individual backend servers. +type srv struct { + url *url.URL + didConnect bool + reconnects int + lastAttempt time.Time + isImplicit bool +} + +type serverInfo struct { + Id string `json:"server_id"` + Host string `json:"host"` + Port uint `json:"port"` + Version string `json:"version"` + AuthRequired bool `json:"auth_required"` + TLSRequired bool `json:"tls_required"` + MaxPayload int64 `json:"max_payload"` + ConnectURLs []string `json:"connect_urls,omitempty"` + Proto int `json:"proto,omitempty"` +} + +const ( + // clientProtoZero is the original client protocol from 2009. + // http://nats.io/documentation/internals/nats-protocol/ + /* clientProtoZero */ _ = iota + // clientProtoInfo signals a client can receive more then the original INFO block. 
+ // This can be used to update clients on other cluster members, etc. + clientProtoInfo +) + +type connectInfo struct { + Verbose bool `json:"verbose"` + Pedantic bool `json:"pedantic"` + User string `json:"user,omitempty"` + Pass string `json:"pass,omitempty"` + Token string `json:"auth_token,omitempty"` + TLS bool `json:"tls_required"` + Name string `json:"name"` + Lang string `json:"lang"` + Version string `json:"version"` + Protocol int `json:"protocol"` + Echo bool `json:"echo"` +} + +// MsgHandler is a callback function that processes messages delivered to +// asynchronous subscribers. +type MsgHandler func(msg *Msg) + +// Connect will attempt to connect to the NATS system. +// The url can contain username/password semantics. e.g. nats://derek:pass@localhost:4222 +// Comma separated arrays are also supported, e.g. urlA, urlB. +// Options start with the defaults but can be overridden. +func Connect(url string, options ...Option) (*Conn, error) { + opts := GetDefaultOptions() + opts.Servers = processUrlString(url) + for _, opt := range options { + if opt != nil { + if err := opt(&opts); err != nil { + return nil, err + } + } + } + return opts.Connect() +} + +// Options that can be passed to Connect. + +// Name is an Option to set the client name. +func Name(name string) Option { + return func(o *Options) error { + o.Name = name + return nil + } +} + +// Secure is an Option to enable TLS secure connections that skip server verification by default. +// Pass a TLS Configuration for proper TLS. +func Secure(tls ...*tls.Config) Option { + return func(o *Options) error { + o.Secure = true + // Use of variadic just simplifies testing scenarios. We only take the first one. + // fixme(DLC) - Could panic if more than one. Could also do TLS option. + if len(tls) > 1 { + return ErrMultipleTLSConfigs + } + if len(tls) == 1 { + o.TLSConfig = tls[0] + } + return nil + } +} + +// RootCAs is a helper option to provide the RootCAs pool from a list of filenames. 
If Secure is +// not already set this will set it as well. +func RootCAs(file ...string) Option { + return func(o *Options) error { + pool := x509.NewCertPool() + for _, f := range file { + rootPEM, err := ioutil.ReadFile(f) + if err != nil || rootPEM == nil { + return fmt.Errorf("nats: error loading or parsing rootCA file: %v", err) + } + ok := pool.AppendCertsFromPEM(rootPEM) + if !ok { + return fmt.Errorf("nats: failed to parse root certificate from %q", f) + } + } + if o.TLSConfig == nil { + o.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + o.TLSConfig.RootCAs = pool + o.Secure = true + return nil + } +} + +// ClientCert is a helper option to provide the client certificate from a file. If Secure is +// not already set this will set it as well +func ClientCert(certFile, keyFile string) Option { + return func(o *Options) error { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return fmt.Errorf("nats: error loading client certificate: %v", err) + } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fmt.Errorf("nats: error parsing client certificate: %v", err) + } + if o.TLSConfig == nil { + o.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + o.TLSConfig.Certificates = []tls.Certificate{cert} + o.Secure = true + return nil + } +} + +// NoReconnect is an Option to turn off reconnect behavior. +func NoReconnect() Option { + return func(o *Options) error { + o.AllowReconnect = false + return nil + } +} + +// DontRandomize is an Option to turn off randomizing the server pool. +func DontRandomize() Option { + return func(o *Options) error { + o.NoRandomize = true + return nil + } +} + +// NoEcho is an Option to turn off messages echoing back from a server. +// Note this is supported on servers >= version 1.2. Proto 1 or greater. 
+func NoEcho() Option { + return func(o *Options) error { + o.NoEcho = true + return nil + } +} + +// ReconnectWait is an Option to set the wait time between reconnect attempts. +func ReconnectWait(t time.Duration) Option { + return func(o *Options) error { + o.ReconnectWait = t + return nil + } +} + +// MaxReconnects is an Option to set the maximum number of reconnect attempts. +func MaxReconnects(max int) Option { + return func(o *Options) error { + o.MaxReconnect = max + return nil + } +} + +// PingInterval is an Option to set the period for client ping commands +func PingInterval(t time.Duration) Option { + return func(o *Options) error { + o.PingInterval = t + return nil + } +} + +// ReconnectBufSize sets the buffer size of messages kept while busy reconnecting +func ReconnectBufSize(size int) Option { + return func(o *Options) error { + o.ReconnectBufSize = size + return nil + } +} + +// Timeout is an Option to set the timeout for Dial on a connection. +func Timeout(t time.Duration) Option { + return func(o *Options) error { + o.Timeout = t + return nil + } +} + +// DrainTimeout is an Option to set the timeout for draining a connection. +func DrainTimeout(t time.Duration) Option { + return func(o *Options) error { + o.DrainTimeout = t + return nil + } +} + +// DisconnectHandler is an Option to set the disconnected handler. +func DisconnectHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.DisconnectedCB = cb + return nil + } +} + +// ReconnectHandler is an Option to set the reconnected handler. +func ReconnectHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.ReconnectedCB = cb + return nil + } +} + +// ClosedHandler is an Option to set the closed handler. +func ClosedHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.ClosedCB = cb + return nil + } +} + +// DiscoveredServersHandler is an Option to set the new servers handler. 
+func DiscoveredServersHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.DiscoveredServersCB = cb + return nil + } +} + +// ErrorHandler is an Option to set the async error handler. +func ErrorHandler(cb ErrHandler) Option { + return func(o *Options) error { + o.AsyncErrorCB = cb + return nil + } +} + +// UserInfo is an Option to set the username and password to +// use when not included directly in the URLs. +func UserInfo(user, password string) Option { + return func(o *Options) error { + o.User = user + o.Password = password + return nil + } +} + +// Token is an Option to set the token to use when not included +// directly in the URLs. +func Token(token string) Option { + return func(o *Options) error { + o.Token = token + return nil + } +} + +// Dialer is an Option to set the dialer which will be used when +// attempting to establish a connection. +// DEPRECATED: Should use CustomDialer instead. +func Dialer(dialer *net.Dialer) Option { + return func(o *Options) error { + o.Dialer = dialer + return nil + } +} + +// SetCustomDialer is an Option to set a custom dialer which will be +// used when attempting to establish a connection. If both Dialer +// and CustomDialer are specified, CustomDialer takes precedence. +func SetCustomDialer(dialer CustomDialer) Option { + return func(o *Options) error { + o.CustomDialer = dialer + return nil + } +} + +// UseOldRequestStyle is an Option to force usage of the old Request style. +func UseOldRequestStyle() Option { + return func(o *Options) error { + o.UseOldRequestStyle = true + return nil + } +} + +// Handler processing + +// SetDisconnectHandler will set the disconnect event handler. +func (nc *Conn) SetDisconnectHandler(dcb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.DisconnectedCB = dcb +} + +// SetReconnectHandler will set the reconnect event handler. 
+func (nc *Conn) SetReconnectHandler(rcb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.ReconnectedCB = rcb +} + +// SetDiscoveredServersHandler will set the discovered servers handler. +func (nc *Conn) SetDiscoveredServersHandler(dscb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.DiscoveredServersCB = dscb +} + +// SetClosedHandler will set the reconnect event handler. +func (nc *Conn) SetClosedHandler(cb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.ClosedCB = cb +} + +// SetErrorHandler will set the async error handler. +func (nc *Conn) SetErrorHandler(cb ErrHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.AsyncErrorCB = cb +} + +// Process the url string argument to Connect. Return an array of +// urls, even if only one. +func processUrlString(url string) []string { + urls := strings.Split(url, ",") + for i, s := range urls { + urls[i] = strings.TrimSpace(s) + } + return urls +} + +// Connect will attempt to connect to a NATS server with multiple options. +func (o Options) Connect() (*Conn, error) { + nc := &Conn{Opts: o} + + // Some default options processing. + if nc.Opts.MaxPingsOut == 0 { + nc.Opts.MaxPingsOut = DefaultMaxPingOut + } + // Allow old default for channel length to work correctly. + if nc.Opts.SubChanLen == 0 { + nc.Opts.SubChanLen = DefaultMaxChanLen + } + // Default ReconnectBufSize + if nc.Opts.ReconnectBufSize == 0 { + nc.Opts.ReconnectBufSize = DefaultReconnectBufSize + } + // Ensure that Timeout is not 0 + if nc.Opts.Timeout == 0 { + nc.Opts.Timeout = DefaultTimeout + } + + // Allow custom Dialer for connecting using DialTimeout by default + if nc.Opts.Dialer == nil { + nc.Opts.Dialer = &net.Dialer{ + Timeout: nc.Opts.Timeout, + } + } + + if err := nc.setupServerPool(); err != nil { + return nil, err + } + + // Create the async callback handler. 
+ nc.ach = &asyncCallbacksHandler{} + nc.ach.cond = sync.NewCond(&nc.ach.mu) + + if err := nc.connect(); err != nil { + return nil, err + } + + // Spin up the async cb dispatcher on success + go nc.ach.asyncCBDispatcher() + + return nc, nil +} + +const ( + _CRLF_ = "\r\n" + _EMPTY_ = "" + _SPC_ = " " + _PUB_P_ = "PUB " +) + +const ( + _OK_OP_ = "+OK" + _ERR_OP_ = "-ERR" + _PONG_OP_ = "PONG" + _INFO_OP_ = "INFO" +) + +const ( + conProto = "CONNECT %s" + _CRLF_ + pingProto = "PING" + _CRLF_ + pongProto = "PONG" + _CRLF_ + subProto = "SUB %s %s %d" + _CRLF_ + unsubProto = "UNSUB %d %s" + _CRLF_ + okProto = _OK_OP_ + _CRLF_ +) + +// Return the currently selected server +func (nc *Conn) currentServer() (int, *srv) { + for i, s := range nc.srvPool { + if s == nil { + continue + } + if s.url == nc.url { + return i, s + } + } + return -1, nil +} + +// Pop the current server and put onto the end of the list. Select head of list as long +// as number of reconnect attempts under MaxReconnect. +func (nc *Conn) selectNextServer() (*srv, error) { + i, s := nc.currentServer() + if i < 0 { + return nil, ErrNoServers + } + sp := nc.srvPool + num := len(sp) + copy(sp[i:num-1], sp[i+1:num]) + maxReconnect := nc.Opts.MaxReconnect + if maxReconnect < 0 || s.reconnects < maxReconnect { + nc.srvPool[num-1] = s + } else { + nc.srvPool = sp[0 : num-1] + } + if len(nc.srvPool) <= 0 { + nc.url = nil + return nil, ErrNoServers + } + nc.url = nc.srvPool[0].url + return nc.srvPool[0], nil +} + +// Will assign the correct server to the nc.Url +func (nc *Conn) pickServer() error { + nc.url = nil + if len(nc.srvPool) <= 0 { + return ErrNoServers + } + for _, s := range nc.srvPool { + if s != nil { + nc.url = s.url + return nil + } + } + return ErrNoServers +} + +const tlsScheme = "tls" + +// Create the server pool using the options given. +// We will place a Url option first, followed by any +// Server Options. We will randomize the server pool unless +// the NoRandomize flag is set. 
+func (nc *Conn) setupServerPool() error { + nc.srvPool = make([]*srv, 0, srvPoolSize) + nc.urls = make(map[string]struct{}, srvPoolSize) + + // Create srv objects from each url string in nc.Opts.Servers + // and add them to the pool + for _, urlString := range nc.Opts.Servers { + if err := nc.addURLToPool(urlString, false); err != nil { + return err + } + } + + // Randomize if allowed to + if !nc.Opts.NoRandomize { + nc.shufflePool() + } + + // Normally, if this one is set, Options.Servers should not be, + // but we always allowed that, so continue to do so. + if nc.Opts.Url != _EMPTY_ { + // Add to the end of the array + if err := nc.addURLToPool(nc.Opts.Url, false); err != nil { + return err + } + // Then swap it with first to guarantee that Options.Url is tried first. + last := len(nc.srvPool) - 1 + if last > 0 { + nc.srvPool[0], nc.srvPool[last] = nc.srvPool[last], nc.srvPool[0] + } + } else if len(nc.srvPool) <= 0 { + // Place default URL if pool is empty. + if err := nc.addURLToPool(DefaultURL, false); err != nil { + return err + } + } + + // Check for Scheme hint to move to TLS mode. + for _, srv := range nc.srvPool { + if srv.url.Scheme == tlsScheme { + // FIXME(dlc), this is for all in the pool, should be case by case. + nc.Opts.Secure = true + if nc.Opts.TLSConfig == nil { + nc.Opts.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + } + } + + return nc.pickServer() +} + +// addURLToPool adds an entry to the server pool +func (nc *Conn) addURLToPool(sURL string, implicit bool) error { + if !strings.Contains(sURL, "://") { + sURL = "nats://" + sURL + } + var ( + u *url.URL + err error + ) + for i := 0; i < 2; i++ { + u, err = url.Parse(sURL) + if err != nil { + return err + } + if u.Port() != "" { + break + } + // In case given URL is of the form "localhost:", just add + // the port number at the end, otherwise, add ":4222". 
+ if sURL[len(sURL)-1] != ':' { + sURL += ":" + } + sURL += defaultPortString + } + s := &srv{url: u, isImplicit: implicit} + nc.srvPool = append(nc.srvPool, s) + nc.urls[u.Host] = struct{}{} + return nil +} + +// shufflePool swaps randomly elements in the server pool +func (nc *Conn) shufflePool() { + if len(nc.srvPool) <= 1 { + return + } + source := rand.NewSource(time.Now().UnixNano()) + r := rand.New(source) + for i := range nc.srvPool { + j := r.Intn(i + 1) + nc.srvPool[i], nc.srvPool[j] = nc.srvPool[j], nc.srvPool[i] + } +} + +// createConn will connect to the server and wrap the appropriate +// bufio structures. It will do the right thing when an existing +// connection is in place. +func (nc *Conn) createConn() (err error) { + if nc.Opts.Timeout < 0 { + return ErrBadTimeout + } + if _, cur := nc.currentServer(); cur == nil { + return ErrNoServers + } else { + cur.lastAttempt = time.Now() + } + + // CustomDialer takes precedence. If not set, use Opts.Dialer which + // is set to a default *net.Dialer (in Connect()) if not explicitly + // set by the user. + dialer := nc.Opts.CustomDialer + if dialer == nil { + dialer = nc.Opts.Dialer + } + nc.conn, err = dialer.Dial("tcp", nc.url.Host) + if err != nil { + return err + } + + // No clue why, but this stalls and kills performance on Mac (Mavericks). + // https://code.google.com/p/go/issues/detail?id=6930 + //if ip, ok := nc.conn.(*net.TCPConn); ok { + // ip.SetReadBuffer(defaultBufSize) + //} + + if nc.pending != nil && nc.bw != nil { + // Move to pending buffer. + nc.bw.Flush() + } + nc.bw = bufio.NewWriterSize(nc.conn, defaultBufSize) + return nil +} + +// makeTLSConn will wrap an existing Conn using TLS +func (nc *Conn) makeTLSConn() { + // Allow the user to configure their own tls.Config structure, otherwise + // default to InsecureSkipVerify. + // TODO(dlc) - We should make the more secure version the default. 
+ if nc.Opts.TLSConfig != nil { + tlsCopy := util.CloneTLSConfig(nc.Opts.TLSConfig) + // If its blank we will override it with the current host + if tlsCopy.ServerName == _EMPTY_ { + h, _, _ := net.SplitHostPort(nc.url.Host) + tlsCopy.ServerName = h + } + nc.conn = tls.Client(nc.conn, tlsCopy) + } else { + nc.conn = tls.Client(nc.conn, &tls.Config{InsecureSkipVerify: true}) + } + conn := nc.conn.(*tls.Conn) + conn.Handshake() + nc.bw = bufio.NewWriterSize(nc.conn, defaultBufSize) +} + +// waitForExits will wait for all socket watcher Go routines to +// be shutdown before proceeding. +func (nc *Conn) waitForExits() { + // Kick old flusher forcefully. + select { + case nc.fch <- struct{}{}: + default: + } + + // Wait for any previous go routines. + nc.wg.Wait() +} + +// Report the connected server's Url +func (nc *Conn) ConnectedUrl() string { + if nc == nil { + return _EMPTY_ + } + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.status != CONNECTED { + return _EMPTY_ + } + return nc.url.String() +} + +// Report the connected server's Id +func (nc *Conn) ConnectedServerId() string { + if nc == nil { + return _EMPTY_ + } + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.status != CONNECTED { + return _EMPTY_ + } + return nc.info.Id +} + +// Low level setup for structs, etc +func (nc *Conn) setup() { + nc.subs = make(map[int64]*Subscription) + nc.pongs = make([]chan struct{}, 0, 8) + + nc.fch = make(chan struct{}, flushChanSize) + + // Setup scratch outbound buffer for PUB + pub := nc.scratch[:len(_PUB_P_)] + copy(pub, _PUB_P_) +} + +// Process a connected connection and initialize properly. +func (nc *Conn) processConnectInit() error { + + // Set our deadline for the whole connect process + nc.conn.SetDeadline(time.Now().Add(nc.Opts.Timeout)) + defer nc.conn.SetDeadline(time.Time{}) + + // Set our status to connecting. 
+ nc.status = CONNECTING + + // Process the INFO protocol received from the server + err := nc.processExpectedInfo() + if err != nil { + return err + } + + // Send the CONNECT protocol along with the initial PING protocol. + // Wait for the PONG response (or any error that we get from the server). + err = nc.sendConnect() + if err != nil { + return err + } + + // Reset the number of PING sent out + nc.pout = 0 + + // Start or reset Timer + if nc.Opts.PingInterval > 0 { + if nc.ptmr == nil { + nc.ptmr = time.AfterFunc(nc.Opts.PingInterval, nc.processPingTimer) + } else { + nc.ptmr.Reset(nc.Opts.PingInterval) + } + } + + // Start the readLoop and flusher go routines, we will wait on both on a reconnect event. + nc.wg.Add(2) + go nc.readLoop() + go nc.flusher() + + return nil +} + +// Main connect function. Will connect to the nats-server +func (nc *Conn) connect() error { + var returnedErr error + + // Create actual socket connection + // For first connect we walk all servers in the pool and try + // to connect immediately. + nc.mu.Lock() + nc.initc = true + // The pool may change inside the loop iteration due to INFO protocol. + for i := 0; i < len(nc.srvPool); i++ { + nc.url = nc.srvPool[i].url + + if err := nc.createConn(); err == nil { + // This was moved out of processConnectInit() because + // that function is now invoked from doReconnect() too. 
+ nc.setup() + + err = nc.processConnectInit() + + if err == nil { + nc.srvPool[i].didConnect = true + nc.srvPool[i].reconnects = 0 + returnedErr = nil + break + } else { + returnedErr = err + nc.mu.Unlock() + nc.close(DISCONNECTED, false) + nc.mu.Lock() + nc.url = nil + } + } else { + // Cancel out default connection refused, will trigger the + // No servers error conditional + if strings.Contains(err.Error(), "connection refused") { + returnedErr = nil + } + } + } + nc.initc = false + defer nc.mu.Unlock() + + if returnedErr == nil && nc.status != CONNECTED { + returnedErr = ErrNoServers + } + return returnedErr +} + +// This will check to see if the connection should be +// secure. This can be dictated from either end and should +// only be called after the INIT protocol has been received. +func (nc *Conn) checkForSecure() error { + // Check to see if we need to engage TLS + o := nc.Opts + + // Check for mismatch in setups + if o.Secure && !nc.info.TLSRequired { + return ErrSecureConnWanted + } else if nc.info.TLSRequired && !o.Secure { + // Switch to Secure since server needs TLS. + o.Secure = true + } + + // Need to rewrap with bufio + if o.Secure { + nc.makeTLSConn() + } + return nil +} + +// processExpectedInfo will look for the expected first INFO message +// sent when a connection is established. The lock should be held entering. +func (nc *Conn) processExpectedInfo() error { + + c := &control{} + + // Read the protocol + err := nc.readOp(c) + if err != nil { + return err + } + + // The nats protocol should send INFO first always. + if c.op != _INFO_OP_ { + return ErrNoInfoReceived + } + + // Parse the protocol + if err := nc.processInfo(c.args); err != nil { + return err + } + + return nc.checkForSecure() +} + +// Sends a protocol control message by queuing into the bufio writer +// and kicking the flush Go routine. These writes are protected. 
+func (nc *Conn) sendProto(proto string) { + nc.mu.Lock() + nc.bw.WriteString(proto) + nc.kickFlusher() + nc.mu.Unlock() +} + +// Generate a connect protocol message, issuing user/password if +// applicable. The lock is assumed to be held upon entering. +func (nc *Conn) connectProto() (string, error) { + o := nc.Opts + var user, pass, token string + u := nc.url.User + if u != nil { + // if no password, assume username is authToken + if _, ok := u.Password(); !ok { + token = u.Username() + } else { + user = u.Username() + pass, _ = u.Password() + } + } else { + // Take from options (possibly all empty strings) + user = nc.Opts.User + pass = nc.Opts.Password + token = nc.Opts.Token + } + + cinfo := connectInfo{o.Verbose, o.Pedantic, user, pass, token, + o.Secure, o.Name, LangString, Version, clientProtoInfo, !o.NoEcho} + + b, err := json.Marshal(cinfo) + if err != nil { + return _EMPTY_, ErrJsonParse + } + + // Check if NoEcho is set and we have a server that supports it. + if o.NoEcho && nc.info.Proto < 1 { + return _EMPTY_, ErrNoEchoNotSupported + } + + return fmt.Sprintf(conProto, b), nil +} + +// normalizeErr removes the prefix -ERR, trim spaces and remove the quotes. +func normalizeErr(line string) string { + s := strings.ToLower(strings.TrimSpace(strings.TrimPrefix(line, _ERR_OP_))) + s = strings.TrimLeft(strings.TrimRight(s, "'"), "'") + return s +} + +// Send a connect protocol message to the server, issue user/password if +// applicable. Will wait for a flush to return from the server for error +// processing. 
+func (nc *Conn) sendConnect() error { + + // Construct the CONNECT protocol string + cProto, err := nc.connectProto() + if err != nil { + return err + } + + // Write the protocol into the buffer + _, err = nc.bw.WriteString(cProto) + if err != nil { + return err + } + + // Add to the buffer the PING protocol + _, err = nc.bw.WriteString(pingProto) + if err != nil { + return err + } + + // Flush the buffer + err = nc.bw.Flush() + if err != nil { + return err + } + + // We don't want to read more than we need here, otherwise + // we would need to transfer the excess read data to the readLoop. + // Since in normal situations we just are looking for a PONG\r\n, + // reading byte-by-byte here is ok. + proto, err := nc.readProto() + if err != nil { + return err + } + + // If opts.Verbose is set, handle +OK + if nc.Opts.Verbose && proto == okProto { + // Read the rest now... + proto, err = nc.readProto() + if err != nil { + return err + } + } + + // We expect a PONG + if proto != pongProto { + // But it could be something else, like -ERR + + // Since we no longer use ReadLine(), trim the trailing "\r\n" + proto = strings.TrimRight(proto, "\r\n") + + // If it's a server error... + if strings.HasPrefix(proto, _ERR_OP_) { + // Remove -ERR, trim spaces and quotes, and convert to lower case. + proto = normalizeErr(proto) + return errors.New("nats: " + proto) + } + + // Notify that we got an unexpected protocol. + return fmt.Errorf("nats: expected '%s', got '%s'", _PONG_OP_, proto) + } + + // This is where we are truly connected. + nc.status = CONNECTED + + return nil +} + +// reads a protocol one byte at a time. 
+func (nc *Conn) readProto() (string, error) { + var ( + _buf = [10]byte{} + buf = _buf[:0] + b = [1]byte{} + protoEnd = byte('\n') + ) + for { + if _, err := nc.conn.Read(b[:1]); err != nil { + // Do not report EOF error + if err == io.EOF { + return string(buf), nil + } + return "", err + } + buf = append(buf, b[0]) + if b[0] == protoEnd { + return string(buf), nil + } + } +} + +// A control protocol line. +type control struct { + op, args string +} + +// Read a control line and process the intended op. +func (nc *Conn) readOp(c *control) error { + br := bufio.NewReaderSize(nc.conn, defaultBufSize) + line, err := br.ReadString('\n') + if err != nil { + return err + } + parseControl(line, c) + return nil +} + +// Parse a control line from the server. +func parseControl(line string, c *control) { + toks := strings.SplitN(line, _SPC_, 2) + if len(toks) == 1 { + c.op = strings.TrimSpace(toks[0]) + c.args = _EMPTY_ + } else if len(toks) == 2 { + c.op, c.args = strings.TrimSpace(toks[0]), strings.TrimSpace(toks[1]) + } else { + c.op = _EMPTY_ + } +} + +// flushReconnectPending will push the pending items that were +// gathered while we were in a RECONNECTING state to the socket. +func (nc *Conn) flushReconnectPendingItems() { + if nc.pending == nil { + return + } + if nc.pending.Len() > 0 { + nc.bw.Write(nc.pending.Bytes()) + } +} + +// Stops the ping timer if set. +// Connection lock is held on entry. +func (nc *Conn) stopPingTimer() { + if nc.ptmr != nil { + nc.ptmr.Stop() + } +} + +// Try to reconnect using the option parameters. +// This function assumes we are allowed to reconnect. +func (nc *Conn) doReconnect() { + // We want to make sure we have the other watchers shutdown properly + // here before we proceed past this point. + nc.waitForExits() + + // FIXME(dlc) - We have an issue here if we have + // outstanding flush points (pongs) and they were not + // sent out, but are still in the pipe. 
+ + // Hold the lock manually and release where needed below, + // can't do defer here. + nc.mu.Lock() + + // Clear any queued pongs, e.g. pending flush calls. + nc.clearPendingFlushCalls() + + // Clear any errors. + nc.err = nil + // Perform appropriate callback if needed for a disconnect. + if nc.Opts.DisconnectedCB != nil { + nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) }) + } + + // This is used to wait on go routines exit if we start them in the loop + // but an error occurs after that. + waitForGoRoutines := false + + for len(nc.srvPool) > 0 { + cur, err := nc.selectNextServer() + if err != nil { + nc.err = err + break + } + + sleepTime := int64(0) + + // Sleep appropriate amount of time before the + // connection attempt if connecting to same server + // we just got disconnected from.. + if time.Since(cur.lastAttempt) < nc.Opts.ReconnectWait { + sleepTime = int64(nc.Opts.ReconnectWait - time.Since(cur.lastAttempt)) + } + + // On Windows, createConn() will take more than a second when no + // server is running at that address. So it could be that the + // time elapsed between reconnect attempts is always > than + // the set option. Release the lock to give a chance to a parallel + // nc.Close() to break the loop. + nc.mu.Unlock() + if sleepTime <= 0 { + runtime.Gosched() + } else { + time.Sleep(time.Duration(sleepTime)) + } + // If the readLoop, etc.. go routines were started, wait for them to complete. + if waitForGoRoutines { + nc.waitForExits() + waitForGoRoutines = false + } + nc.mu.Lock() + + // Check if we have been closed first. + if nc.isClosed() { + break + } + + // Mark that we tried a reconnect + cur.reconnects++ + + // Try to create a new connection + err = nc.createConn() + + // Not yet connected, retry... 
+ // Continue to hold the lock + if err != nil { + nc.err = nil + continue + } + + // We are reconnected + nc.Reconnects++ + + // Process connect logic + if nc.err = nc.processConnectInit(); nc.err != nil { + nc.status = RECONNECTING + // Reset the buffered writer to the pending buffer + // (was set to a buffered writer on nc.conn in createConn) + nc.bw.Reset(nc.pending) + continue + } + + // Clear out server stats for the server we connected to.. + cur.didConnect = true + cur.reconnects = 0 + + // Send existing subscription state + nc.resendSubscriptions() + + // Now send off and clear pending buffer + nc.flushReconnectPendingItems() + + // Flush the buffer + nc.err = nc.bw.Flush() + if nc.err != nil { + nc.status = RECONNECTING + // Reset the buffered writer to the pending buffer (bytes.Buffer). + nc.bw.Reset(nc.pending) + // Stop the ping timer (if set) + nc.stopPingTimer() + // Since processConnectInit() returned without error, the + // go routines were started, so wait for them to return + // on the next iteration (after releasing the lock). + waitForGoRoutines = true + continue + } + + // Done with the pending buffer + nc.pending = nil + + // This is where we are truly connected. + nc.status = CONNECTED + + // Queue up the reconnect callback. + if nc.Opts.ReconnectedCB != nil { + nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) }) + } + // Release lock here, we will return below. + nc.mu.Unlock() + + // Make sure to flush everything + nc.Flush() + + return + } + + // Call into close.. We have no servers left.. + if nc.err == nil { + nc.err = ErrNoServers + } + nc.mu.Unlock() + nc.Close() +} + +// processOpErr handles errors from reading or parsing the protocol. +// The lock should not be held entering this function. 
+func (nc *Conn) processOpErr(err error) { + nc.mu.Lock() + if nc.isConnecting() || nc.isClosed() || nc.isReconnecting() { + nc.mu.Unlock() + return + } + + if nc.Opts.AllowReconnect && nc.status == CONNECTED { + // Set our new status + nc.status = RECONNECTING + // Stop ping timer if set + nc.stopPingTimer() + if nc.conn != nil { + nc.bw.Flush() + nc.conn.Close() + nc.conn = nil + } + + // Create pending buffer before reconnecting. + nc.pending = new(bytes.Buffer) + nc.bw.Reset(nc.pending) + + go nc.doReconnect() + nc.mu.Unlock() + return + } + + nc.status = DISCONNECTED + nc.err = err + nc.mu.Unlock() + nc.Close() +} + +// dispatch is responsible for calling any async callbacks +func (ac *asyncCallbacksHandler) asyncCBDispatcher() { + for { + ac.mu.Lock() + // Protect for spurious wakeups. We should get out of the + // wait only if there is an element to pop from the list. + for ac.head == nil { + ac.cond.Wait() + } + cur := ac.head + ac.head = cur.next + if cur == ac.tail { + ac.tail = nil + } + ac.mu.Unlock() + + // This signals that the dispatcher has been closed and all + // previous callbacks have been dispatched. + if cur.f == nil { + return + } + // Invoke callback outside of handler's lock + cur.f() + } +} + +// Add the given function to the tail of the list and +// signals the dispatcher. +func (ac *asyncCallbacksHandler) push(f func()) { + ac.pushOrClose(f, false) +} + +// Signals that we are closing... +func (ac *asyncCallbacksHandler) close() { + ac.pushOrClose(nil, true) +} + +// Add the given function to the tail of the list and +// signals the dispatcher. +func (ac *asyncCallbacksHandler) pushOrClose(f func(), close bool) { + ac.mu.Lock() + defer ac.mu.Unlock() + // Make sure that library is not calling push with nil function, + // since this is used to notify the dispatcher that it should stop. 
+ if !close && f == nil { + panic("pushing a nil callback") + } + cb := &asyncCB{f: f} + if ac.tail != nil { + ac.tail.next = cb + } else { + ac.head = cb + } + ac.tail = cb + if close { + ac.cond.Broadcast() + } else { + ac.cond.Signal() + } +} + +// readLoop() will sit on the socket reading and processing the +// protocol from the server. It will dispatch appropriately based +// on the op type. +func (nc *Conn) readLoop() { + // Release the wait group on exit + defer nc.wg.Done() + + // Create a parseState if needed. + nc.mu.Lock() + if nc.ps == nil { + nc.ps = &parseState{} + } + nc.mu.Unlock() + + // Stack based buffer. + b := make([]byte, defaultBufSize) + + for { + // FIXME(dlc): RWLock here? + nc.mu.Lock() + sb := nc.isClosed() || nc.isReconnecting() + if sb { + nc.ps = &parseState{} + } + conn := nc.conn + nc.mu.Unlock() + + if sb || conn == nil { + break + } + + n, err := conn.Read(b) + if err != nil { + nc.processOpErr(err) + break + } + + if err := nc.parse(b[:n]); err != nil { + nc.processOpErr(err) + break + } + } + // Clear the parseState here.. + nc.mu.Lock() + nc.ps = nil + nc.mu.Unlock() +} + +// waitForMsgs waits on the conditional shared with readLoop and processMsg. +// It is used to deliver messages to asynchronous subscribers. +func (nc *Conn) waitForMsgs(s *Subscription) { + var closed bool + var delivered, max uint64 + + // Used to account for adjustments to sub.pBytes when we wrap back around. + msgLen := -1 + + for { + s.mu.Lock() + // Do accounting for last msg delivered here so we only lock once + // and drain state trips after callback has returned. 
+ if msgLen >= 0 { + s.pMsgs-- + s.pBytes -= msgLen + msgLen = -1 + } + + if s.pHead == nil && !s.closed { + s.pCond.Wait() + } + // Pop the msg off the list + m := s.pHead + if m != nil { + s.pHead = m.next + if s.pHead == nil { + s.pTail = nil + } + if m.barrier != nil { + s.mu.Unlock() + if atomic.AddInt64(&m.barrier.refs, -1) == 0 { + m.barrier.f() + } + continue + } + msgLen = len(m.Data) + } + mcb := s.mcb + max = s.max + closed = s.closed + if !s.closed { + s.delivered++ + delivered = s.delivered + } + s.mu.Unlock() + + if closed { + break + } + + // Deliver the message. + if m != nil && (max == 0 || delivered <= max) { + mcb(m) + } + // If we have hit the max for delivered msgs, remove sub. + if max > 0 && delivered >= max { + nc.mu.Lock() + nc.removeSub(s) + nc.mu.Unlock() + break + } + } + // Check for barrier messages + s.mu.Lock() + for m := s.pHead; m != nil; m = s.pHead { + if m.barrier != nil { + s.mu.Unlock() + if atomic.AddInt64(&m.barrier.refs, -1) == 0 { + m.barrier.f() + } + s.mu.Lock() + } + s.pHead = m.next + } + s.mu.Unlock() +} + +// processMsg is called by parse and will place the msg on the +// appropriate channel/pending queue for processing. If the channel is full, +// or the pending queue is over the pending limits, the connection is +// considered a slow consumer. +func (nc *Conn) processMsg(data []byte) { + // Don't lock the connection to avoid server cutting us off if the + // flusher is holding the connection lock, trying to send to the server + // that is itself trying to send data to us. + nc.subsMu.RLock() + + // Stats + nc.InMsgs++ + nc.InBytes += uint64(len(data)) + + sub := nc.subs[nc.ps.ma.sid] + if sub == nil { + nc.subsMu.RUnlock() + return + } + + // Copy them into string + subj := string(nc.ps.ma.subject) + reply := string(nc.ps.ma.reply) + + // Doing message create outside of the sub's lock to reduce contention. + // It's possible that we end-up not using the message, but that's ok. 
+ + // FIXME(dlc): Need to copy, should/can do COW? + msgPayload := make([]byte, len(data)) + copy(msgPayload, data) + + // FIXME(dlc): Should we recycle these containers? + m := &Msg{Data: msgPayload, Subject: subj, Reply: reply, Sub: sub} + + sub.mu.Lock() + + // Subscription internal stats (applicable only for non ChanSubscription's) + if sub.typ != ChanSubscription { + sub.pMsgs++ + if sub.pMsgs > sub.pMsgsMax { + sub.pMsgsMax = sub.pMsgs + } + sub.pBytes += len(m.Data) + if sub.pBytes > sub.pBytesMax { + sub.pBytesMax = sub.pBytes + } + + // Check for a Slow Consumer + if (sub.pMsgsLimit > 0 && sub.pMsgs > sub.pMsgsLimit) || + (sub.pBytesLimit > 0 && sub.pBytes > sub.pBytesLimit) { + goto slowConsumer + } + } + + // We have two modes of delivery. One is the channel, used by channel + // subscribers and syncSubscribers, the other is a linked list for async. + if sub.mch != nil { + select { + case sub.mch <- m: + default: + goto slowConsumer + } + } else { + // Push onto the async pList + if sub.pHead == nil { + sub.pHead = m + sub.pTail = m + sub.pCond.Signal() + } else { + sub.pTail.next = m + sub.pTail = m + } + } + + // Clear SlowConsumer status. + sub.sc = false + + sub.mu.Unlock() + nc.subsMu.RUnlock() + return + +slowConsumer: + sub.dropped++ + sc := !sub.sc + sub.sc = true + // Undo stats from above + if sub.typ != ChanSubscription { + sub.pMsgs-- + sub.pBytes -= len(m.Data) + } + sub.mu.Unlock() + nc.subsMu.RUnlock() + if sc { + // Now we need connection's lock and we may end-up in the situation + // that we were trying to avoid, except that in this case, the client + // is already experiencing client-side slow consumer situation. + nc.mu.Lock() + nc.err = ErrSlowConsumer + if nc.Opts.AsyncErrorCB != nil { + nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, sub, ErrSlowConsumer) }) + } + nc.mu.Unlock() + } +} + +// processPermissionsViolation is called when the server signals a subject +// permissions violation on either publish or subscribe. 
+func (nc *Conn) processPermissionsViolation(err string) { + nc.mu.Lock() + // create error here so we can pass it as a closure to the async cb dispatcher. + e := errors.New("nats: " + err) + nc.err = e + if nc.Opts.AsyncErrorCB != nil { + nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, e) }) + } + nc.mu.Unlock() +} + +// processAuthorizationViolation is called when the server signals a user +// authorization violation. +func (nc *Conn) processAuthorizationViolation(err string) { + nc.mu.Lock() + nc.err = ErrAuthorization + if nc.Opts.AsyncErrorCB != nil { + nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, ErrAuthorization) }) + } + nc.mu.Unlock() +} + +// flusher is a separate Go routine that will process flush requests for the write +// bufio. This allows coalescing of writes to the underlying socket. +func (nc *Conn) flusher() { + // Release the wait group + defer nc.wg.Done() + + // snapshot the bw and conn since they can change from underneath of us. + nc.mu.Lock() + bw := nc.bw + conn := nc.conn + fch := nc.fch + flusherTimeout := nc.Opts.FlusherTimeout + nc.mu.Unlock() + + if conn == nil || bw == nil { + return + } + + for { + if _, ok := <-fch; !ok { + return + } + nc.mu.Lock() + + // Check to see if we should bail out. + if !nc.isConnected() || nc.isConnecting() || bw != nc.bw || conn != nc.conn { + nc.mu.Unlock() + return + } + if bw.Buffered() > 0 { + // Allow customizing how long we should wait for a flush to be done + // to prevent unhealthy connections blocking the client for too long. + if flusherTimeout > 0 { + conn.SetWriteDeadline(time.Now().Add(flusherTimeout)) + } + + if err := bw.Flush(); err != nil { + if nc.err == nil { + nc.err = err + } + } + conn.SetWriteDeadline(time.Time{}) + } + nc.mu.Unlock() + } +} + +// processPing will send an immediate pong protocol response to the +// server. The server uses this mechanism to detect dead clients. 
+func (nc *Conn) processPing() { + nc.sendProto(pongProto) +} + +// processPong is used to process responses to the client's ping +// messages. We use pings for the flush mechanism as well. +func (nc *Conn) processPong() { + var ch chan struct{} + + nc.mu.Lock() + if len(nc.pongs) > 0 { + ch = nc.pongs[0] + nc.pongs = nc.pongs[1:] + } + nc.pout = 0 + nc.mu.Unlock() + if ch != nil { + ch <- struct{}{} + } +} + +// processOK is a placeholder for processing OK messages. +func (nc *Conn) processOK() { + // do nothing +} + +// processInfo is used to parse the info messages sent +// from the server. +// This function may update the server pool. +func (nc *Conn) processInfo(info string) error { + if info == _EMPTY_ { + return nil + } + ncInfo := serverInfo{} + if err := json.Unmarshal([]byte(info), &ncInfo); err != nil { + return err + } + // Copy content into connection's info structure. + nc.info = ncInfo + // The array could be empty/not present on initial connect, + // if advertise is disabled on that server, or servers that + // did not include themselves in the async INFO protocol. + // If empty, do not remove the implicit servers from the pool. + if len(ncInfo.ConnectURLs) == 0 { + return nil + } + // Note about pool randomization: when the pool was first created, + // it was randomized (if allowed). We keep the order the same (removing + // implicit servers that are no longer sent to us). New URLs are sent + // to us in no specific order so don't need extra randomization. + hasNew := false + // This is what we got from the server we are connected to. 
+ urls := nc.info.ConnectURLs + // Transform that to a map for easy lookups + tmp := make(map[string]struct{}, len(urls)) + for _, curl := range urls { + tmp[curl] = struct{}{} + } + // Walk the pool and removed the implicit servers that are no longer in the + // given array/map + sp := nc.srvPool + for i := 0; i < len(sp); i++ { + srv := sp[i] + curl := srv.url.Host + // Check if this URL is in the INFO protocol + _, inInfo := tmp[curl] + // Remove from the temp map so that at the end we are left with only + // new (or restarted) servers that need to be added to the pool. + delete(tmp, curl) + // Keep servers that were set through Options, but also the one that + // we are currently connected to (even if it is a discovered server). + if !srv.isImplicit || srv.url == nc.url { + continue + } + if !inInfo { + // Remove from server pool. Keep current order. + copy(sp[i:], sp[i+1:]) + nc.srvPool = sp[:len(sp)-1] + sp = nc.srvPool + i-- + } + } + // If there are any left in the tmp map, these are new (or restarted) servers + // and need to be added to the pool. + for curl := range tmp { + // Before adding, check if this is a new (as in never seen) URL. + // This is used to figure out if we invoke the DiscoveredServersCB + if _, present := nc.urls[curl]; !present { + hasNew = true + } + nc.addURLToPool(fmt.Sprintf("nats://%s", curl), true) + } + if hasNew && !nc.initc && nc.Opts.DiscoveredServersCB != nil { + nc.ach.push(func() { nc.Opts.DiscoveredServersCB(nc) }) + } + return nil +} + +// processAsyncInfo does the same than processInfo, but is called +// from the parser. Calls processInfo under connection's lock +// protection. +func (nc *Conn) processAsyncInfo(info []byte) { + nc.mu.Lock() + // Ignore errors, we will simply not update the server pool... + nc.processInfo(string(info)) + nc.mu.Unlock() +} + +// LastError reports the last error encountered via the connection. 
+// It can be used reliably within ClosedCB in order to find out reason +// why connection was closed for example. +func (nc *Conn) LastError() error { + if nc == nil { + return ErrInvalidConnection + } + nc.mu.Lock() + err := nc.err + nc.mu.Unlock() + return err +} + +// processErr processes any error messages from the server and +// sets the connection's lastError. +func (nc *Conn) processErr(e string) { + // Trim, remove quotes, convert to lower case. + e = normalizeErr(e) + + // FIXME(dlc) - process Slow Consumer signals special. + if e == STALE_CONNECTION { + nc.processOpErr(ErrStaleConnection) + } else if strings.HasPrefix(e, PERMISSIONS_ERR) { + nc.processPermissionsViolation(e) + } else if strings.HasPrefix(e, AUTHORIZATION_ERR) { + nc.processAuthorizationViolation(e) + } else { + nc.mu.Lock() + nc.err = errors.New("nats: " + e) + nc.mu.Unlock() + nc.Close() + } +} + +// kickFlusher will send a bool on a channel to kick the +// flush Go routine to flush data to the server. +func (nc *Conn) kickFlusher() { + if nc.bw != nil { + select { + case nc.fch <- struct{}{}: + default: + } + } +} + +// Publish publishes the data argument to the given subject. The data +// argument is left untouched and needs to be correctly interpreted on +// the receiver. +func (nc *Conn) Publish(subj string, data []byte) error { + return nc.publish(subj, _EMPTY_, data) +} + +// PublishMsg publishes the Msg structure, which includes the +// Subject, an optional Reply and an optional Data field. +func (nc *Conn) PublishMsg(m *Msg) error { + if m == nil { + return ErrInvalidMsg + } + return nc.publish(m.Subject, m.Reply, m.Data) +} + +// PublishRequest will perform a Publish() excpecting a response on the +// reply subject. Use Request() for automatically waiting for a response +// inline. 
+func (nc *Conn) PublishRequest(subj, reply string, data []byte) error { + return nc.publish(subj, reply, data) +} + +// Used for handrolled itoa +const digits = "0123456789" + +// publish is the internal function to publish messages to a nats-server. +// Sends a protocol data message by queuing into the bufio writer +// and kicking the flush go routine. These writes should be protected. +func (nc *Conn) publish(subj, reply string, data []byte) error { + if nc == nil { + return ErrInvalidConnection + } + if subj == "" { + return ErrBadSubject + } + nc.mu.Lock() + + if nc.isClosed() { + nc.mu.Unlock() + return ErrConnectionClosed + } + + if nc.isDrainingPubs() { + nc.mu.Unlock() + return ErrConnectionDraining + } + + // Proactively reject payloads over the threshold set by server. + msgSize := int64(len(data)) + if msgSize > nc.info.MaxPayload { + nc.mu.Unlock() + return ErrMaxPayload + } + + // Check if we are reconnecting, and if so check if + // we have exceeded our reconnect outbound buffer limits. + if nc.isReconnecting() { + // Flush to underlying buffer. + nc.bw.Flush() + // Check if we are over + if nc.pending.Len() >= nc.Opts.ReconnectBufSize { + nc.mu.Unlock() + return ErrReconnectBufExceeded + } + } + + msgh := nc.scratch[:len(_PUB_P_)] + msgh = append(msgh, subj...) + msgh = append(msgh, ' ') + if reply != "" { + msgh = append(msgh, reply...) + msgh = append(msgh, ' ') + } + + // We could be smarter here, but simple loop is ok, + // just avoid strconv in fast path + // FIXME(dlc) - Find a better way here. + // msgh = strconv.AppendInt(msgh, int64(len(data)), 10) + + var b [12]byte + var i = len(b) + if len(data) > 0 { + for l := len(data); l > 0; l /= 10 { + i -= 1 + b[i] = digits[l%10] + } + } else { + i -= 1 + b[i] = digits[0] + } + + msgh = append(msgh, b[i:]...) + msgh = append(msgh, _CRLF_...) 
+ + _, err := nc.bw.Write(msgh) + if err == nil { + _, err = nc.bw.Write(data) + } + if err == nil { + _, err = nc.bw.WriteString(_CRLF_) + } + if err != nil { + nc.mu.Unlock() + return err + } + + nc.OutMsgs++ + nc.OutBytes += uint64(len(data)) + + if len(nc.fch) == 0 { + nc.kickFlusher() + } + nc.mu.Unlock() + return nil +} + +// respHandler is the global response handler. It will look up +// the appropriate channel based on the last token and place +// the message on the channel if possible. +func (nc *Conn) respHandler(m *Msg) { + rt := respToken(m.Subject) + + nc.mu.Lock() + // Just return if closed. + if nc.isClosed() { + nc.mu.Unlock() + return + } + + // Grab mch + mch := nc.respMap[rt] + // Delete the key regardless, one response only. + // FIXME(dlc) - should we track responses past 1 + // just statistics wise? + delete(nc.respMap, rt) + nc.mu.Unlock() + + // Don't block, let Request timeout instead, mch is + // buffered and we should delete the key before a + // second response is processed. + select { + case mch <- m: + default: + return + } +} + +// Create the response subscription we will use for all +// new style responses. This will be on an _INBOX with an +// additional terminal token. The subscription will be on +// a wildcard. Caller is responsible for ensuring this is +// only called once. +func (nc *Conn) createRespMux(respSub string) error { + s, err := nc.Subscribe(respSub, nc.respHandler) + if err != nil { + return err + } + nc.mu.Lock() + nc.respMux = s + nc.mu.Unlock() + return nil +} + +// Request will send a request payload and deliver the response message, +// or an error, including a timeout if no message was received properly. +func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg, error) { + if nc == nil { + return nil, ErrInvalidConnection + } + + nc.mu.Lock() + // If user wants the old style. 
+ if nc.Opts.UseOldRequestStyle { + nc.mu.Unlock() + return nc.oldRequest(subj, data, timeout) + } + + // Do setup for the new style. + if nc.respMap == nil { + // _INBOX wildcard + nc.respSub = fmt.Sprintf("%s.*", NewInbox()) + nc.respMap = make(map[string]chan *Msg) + } + // Create literal Inbox and map to a chan msg. + mch := make(chan *Msg, RequestChanLen) + respInbox := nc.newRespInbox() + token := respToken(respInbox) + nc.respMap[token] = mch + createSub := nc.respMux == nil + ginbox := nc.respSub + nc.mu.Unlock() + + if createSub { + // Make sure scoped subscription is setup only once. + var err error + nc.respSetup.Do(func() { err = nc.createRespMux(ginbox) }) + if err != nil { + return nil, err + } + } + + if err := nc.PublishRequest(subj, respInbox, data); err != nil { + return nil, err + } + + t := globalTimerPool.Get(timeout) + defer globalTimerPool.Put(t) + + var ok bool + var msg *Msg + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + case <-t.C: + nc.mu.Lock() + delete(nc.respMap, token) + nc.mu.Unlock() + return nil, ErrTimeout + } + + return msg, nil +} + +// oldRequest will create an Inbox and perform a Request() call +// with the Inbox reply and return the first reply received. +// This is optimized for the case of multiple responses. +func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Msg, error) { + inbox := NewInbox() + ch := make(chan *Msg, RequestChanLen) + + s, err := nc.subscribe(inbox, _EMPTY_, nil, ch) + if err != nil { + return nil, err + } + s.AutoUnsubscribe(1) + defer s.Unsubscribe() + + err = nc.PublishRequest(subj, inbox, data) + if err != nil { + return nil, err + } + return s.NextMsg(timeout) +} + +// InboxPrefix is the prefix for all inbox subjects. +const InboxPrefix = "_INBOX." 
+const inboxPrefixLen = len(InboxPrefix) +const respInboxPrefixLen = inboxPrefixLen + nuidSize + 1 + +// NewInbox will return an inbox string which can be used for directed replies from +// subscribers. These are guaranteed to be unique, but can be shared and subscribed +// to by others. +func NewInbox() string { + var b [inboxPrefixLen + nuidSize]byte + pres := b[:inboxPrefixLen] + copy(pres, InboxPrefix) + ns := b[inboxPrefixLen:] + copy(ns, nuid.Next()) + return string(b[:]) +} + +// Creates a new literal response subject that will trigger +// the global subscription handler. +func (nc *Conn) newRespInbox() string { + var b [inboxPrefixLen + (2 * nuidSize) + 1]byte + pres := b[:respInboxPrefixLen] + copy(pres, nc.respSub) + ns := b[respInboxPrefixLen:] + copy(ns, nuid.Next()) + return string(b[:]) +} + +// respToken will return the last token of a literal response inbox +// which we use for the message channel lookup. +func respToken(respInbox string) string { + return respInbox[respInboxPrefixLen:] +} + +// Subscribe will express interest in the given subject. The subject +// can have wildcards (partial:*, full:>). Messages will be delivered +// to the associated MsgHandler. +func (nc *Conn) Subscribe(subj string, cb MsgHandler) (*Subscription, error) { + return nc.subscribe(subj, _EMPTY_, cb, nil) +} + +// ChanSubscribe will express interest in the given subject and place +// all messages received on the channel. +// You should not close the channel until sub.Unsubscribe() has been called. +func (nc *Conn) ChanSubscribe(subj string, ch chan *Msg) (*Subscription, error) { + return nc.subscribe(subj, _EMPTY_, nil, ch) +} + +// ChanQueueSubscribe will express interest in the given subject. +// All subscribers with the same queue name will form the queue group +// and only one member of the group will be selected to receive any given message, +// which will be placed on the channel. +// You should not close the channel until sub.Unsubscribe() has been called. 
+// Note: This is the same than QueueSubscribeSyncWithChan. +func (nc *Conn) ChanQueueSubscribe(subj, group string, ch chan *Msg) (*Subscription, error) { + return nc.subscribe(subj, group, nil, ch) +} + +// SubscribeSync will express interest on the given subject. Messages will +// be received synchronously using Subscription.NextMsg(). +func (nc *Conn) SubscribeSync(subj string) (*Subscription, error) { + if nc == nil { + return nil, ErrInvalidConnection + } + mch := make(chan *Msg, nc.Opts.SubChanLen) + s, e := nc.subscribe(subj, _EMPTY_, nil, mch) + if s != nil { + s.typ = SyncSubscription + } + return s, e +} + +// QueueSubscribe creates an asynchronous queue subscriber on the given subject. +// All subscribers with the same queue name will form the queue group and +// only one member of the group will be selected to receive any given +// message asynchronously. +func (nc *Conn) QueueSubscribe(subj, queue string, cb MsgHandler) (*Subscription, error) { + return nc.subscribe(subj, queue, cb, nil) +} + +// QueueSubscribeSync creates a synchronous queue subscriber on the given +// subject. All subscribers with the same queue name will form the queue +// group and only one member of the group will be selected to receive any +// given message synchronously using Subscription.NextMsg(). +func (nc *Conn) QueueSubscribeSync(subj, queue string) (*Subscription, error) { + mch := make(chan *Msg, nc.Opts.SubChanLen) + s, e := nc.subscribe(subj, queue, nil, mch) + if s != nil { + s.typ = SyncSubscription + } + return s, e +} + +// QueueSubscribeSyncWithChan will express interest in the given subject. +// All subscribers with the same queue name will form the queue group +// and only one member of the group will be selected to receive any given message, +// which will be placed on the channel. +// You should not close the channel until sub.Unsubscribe() has been called. +// Note: This is the same than ChanQueueSubscribe. 
+func (nc *Conn) QueueSubscribeSyncWithChan(subj, queue string, ch chan *Msg) (*Subscription, error) { + return nc.subscribe(subj, queue, nil, ch) +} + +// subscribe is the internal subscribe function that indicates interest in a subject. +func (nc *Conn) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg) (*Subscription, error) { + if nc == nil { + return nil, ErrInvalidConnection + } + nc.mu.Lock() + // ok here, but defer is generally expensive + defer nc.mu.Unlock() + defer nc.kickFlusher() + + // Check for some error conditions. + if nc.isClosed() { + return nil, ErrConnectionClosed + } + if nc.isDraining() { + return nil, ErrConnectionDraining + } + + if cb == nil && ch == nil { + return nil, ErrBadSubscription + } + + sub := &Subscription{Subject: subj, Queue: queue, mcb: cb, conn: nc} + // Set pending limits. + sub.pMsgsLimit = DefaultSubPendingMsgsLimit + sub.pBytesLimit = DefaultSubPendingBytesLimit + + // If we have an async callback, start up a sub specific + // Go routine to deliver the messages. + if cb != nil { + sub.typ = AsyncSubscription + sub.pCond = sync.NewCond(&sub.mu) + go nc.waitForMsgs(sub) + } else { + sub.typ = ChanSubscription + sub.mch = ch + } + + nc.subsMu.Lock() + nc.ssid++ + sub.sid = nc.ssid + nc.subs[sub.sid] = sub + nc.subsMu.Unlock() + + // We will send these for all subs when we reconnect + // so that we can suppress here. + if !nc.isReconnecting() { + fmt.Fprintf(nc.bw, subProto, subj, queue, sub.sid) + } + return sub, nil +} + +// NumSubscriptions returns active number of subscriptions. 
+func (nc *Conn) NumSubscriptions() int { + nc.mu.Lock() + defer nc.mu.Unlock() + return len(nc.subs) +} + +// Lock for nc should be held here upon entry +func (nc *Conn) removeSub(s *Subscription) { + nc.subsMu.Lock() + delete(nc.subs, s.sid) + nc.subsMu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() + // Release callers on NextMsg for SyncSubscription only + if s.mch != nil && s.typ == SyncSubscription { + close(s.mch) + } + s.mch = nil + + // Mark as invalid + s.conn = nil + s.closed = true + if s.pCond != nil { + s.pCond.Broadcast() + } +} + +// SubscriptionType is the type of the Subscription. +type SubscriptionType int + +// The different types of subscription types. +const ( + AsyncSubscription = SubscriptionType(iota) + SyncSubscription + ChanSubscription + NilSubscription +) + +// Type returns the type of Subscription. +func (s *Subscription) Type() SubscriptionType { + if s == nil { + return NilSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + return s.typ +} + +// IsValid returns a boolean indicating whether the subscription +// is still active. This will return false if the subscription has +// already been closed. +func (s *Subscription) IsValid() bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + return s.conn != nil +} + +// Drain will remove interest but continue callbacks until all messages +// have been processed. +func (s *Subscription) Drain() error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + conn := s.conn + s.mu.Unlock() + if conn == nil { + return ErrBadSubscription + } + return conn.unsubscribe(s, 0, true) +} + +// Unsubscribe will remove interest in the given subject. 
+func (s *Subscription) Unsubscribe() error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + conn := s.conn + s.mu.Unlock() + if conn == nil { + return ErrBadSubscription + } + if conn.IsDraining() { + return ErrConnectionDraining + } + return conn.unsubscribe(s, 0, false) +} + +// checkDrained will watch for a subscription to be fully drained +// and then remove it. +func (nc *Conn) checkDrained(sub *Subscription) { + if nc == nil || sub == nil { + return + } + + // This allows us to know that whatever we have in the client pending + // is correct and the server will not send additional information. + nc.Flush() + + // Once we are here we just wait for Pending to reach 0 or + // any other state to exit this go routine. + for { + // check connection is still valid. + if nc.IsClosed() { + return + } + + // Check subscription state + sub.mu.Lock() + conn := sub.conn + closed := sub.closed + pMsgs := sub.pMsgs + sub.mu.Unlock() + + if conn == nil || closed || pMsgs == 0 { + nc.mu.Lock() + nc.removeSub(sub) + nc.mu.Unlock() + return + } + + time.Sleep(100 * time.Millisecond) + } +} + +// AutoUnsubscribe will issue an automatic Unsubscribe that is +// processed by the server when max messages have been received. +// This can be useful when sending a request to an unknown number +// of subscribers. +func (s *Subscription) AutoUnsubscribe(max int) error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + conn := s.conn + s.mu.Unlock() + if conn == nil { + return ErrBadSubscription + } + return conn.unsubscribe(s, max, false) +} + +// unsubscribe performs the low level unsubscribe to the server. 
+// Use Subscription.Unsubscribe() +func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error { + nc.mu.Lock() + // ok here, but defer is expensive + defer nc.mu.Unlock() + defer nc.kickFlusher() + + if nc.isClosed() { + return ErrConnectionClosed + } + + nc.subsMu.RLock() + s := nc.subs[sub.sid] + nc.subsMu.RUnlock() + // Already unsubscribed + if s == nil { + return nil + } + + maxStr := _EMPTY_ + if max > 0 { + s.max = uint64(max) + maxStr = strconv.Itoa(max) + } else if !drainMode { + nc.removeSub(s) + } + + if drainMode { + go nc.checkDrained(sub) + } + + // We will send these for all subs when we reconnect + // so that we can suppress here. + if !nc.isReconnecting() { + fmt.Fprintf(nc.bw, unsubProto, s.sid, maxStr) + } + return nil +} + +// NextMsg will return the next message available to a synchronous subscriber +// or block until one is available. A timeout can be used to return when no +// message has been delivered. +func (s *Subscription) NextMsg(timeout time.Duration) (*Msg, error) { + if s == nil { + return nil, ErrBadSubscription + } + + s.mu.Lock() + err := s.validateNextMsgState() + if err != nil { + s.mu.Unlock() + return nil, err + } + + // snapshot + mch := s.mch + s.mu.Unlock() + + var ok bool + var msg *Msg + + t := globalTimerPool.Get(timeout) + defer globalTimerPool.Put(t) + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + err := s.processNextMsgDelivered(msg) + if err != nil { + return nil, err + } + case <-t.C: + return nil, ErrTimeout + } + + return msg, nil +} + +// validateNextMsgState checks whether the subscription is in a valid +// state to call NextMsg and be delivered another message synchronously. +// This should be called while holding the lock. 
+func (s *Subscription) validateNextMsgState() error { + if s.connClosed { + return ErrConnectionClosed + } + if s.mch == nil { + if s.max > 0 && s.delivered >= s.max { + return ErrMaxMessages + } else if s.closed { + return ErrBadSubscription + } + } + if s.mcb != nil { + return ErrSyncSubRequired + } + if s.sc { + s.sc = false + return ErrSlowConsumer + } + + return nil +} + +// processNextMsgDelivered takes a message and applies the needed +// accounting to the stats from the subscription, returning an +// error in case we have the maximum number of messages have been +// delivered already. It should not be called while holding the lock. +func (s *Subscription) processNextMsgDelivered(msg *Msg) error { + s.mu.Lock() + nc := s.conn + max := s.max + + // Update some stats. + s.delivered++ + delivered := s.delivered + if s.typ == SyncSubscription { + s.pMsgs-- + s.pBytes -= len(msg.Data) + } + s.mu.Unlock() + + if max > 0 { + if delivered > max { + return ErrMaxMessages + } + // Remove subscription if we have reached max. + if delivered == max { + nc.mu.Lock() + nc.removeSub(s) + nc.mu.Unlock() + } + } + + return nil +} + +// Queued returns the number of queued messages in the client for this subscription. +// DEPRECATED: Use Pending() +func (s *Subscription) QueuedMsgs() (int, error) { + m, _, err := s.Pending() + return int(m), err +} + +// Pending returns the number of queued messages and queued bytes in the client for this subscription. +func (s *Subscription) Pending() (int, int, error) { + if s == nil { + return -1, -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, -1, ErrBadSubscription + } + if s.typ == ChanSubscription { + return -1, -1, ErrTypeSubscription + } + return s.pMsgs, s.pBytes, nil +} + +// MaxPending returns the maximum number of queued messages and queued bytes seen so far. 
+func (s *Subscription) MaxPending() (int, int, error) { + if s == nil { + return -1, -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, -1, ErrBadSubscription + } + if s.typ == ChanSubscription { + return -1, -1, ErrTypeSubscription + } + return s.pMsgsMax, s.pBytesMax, nil +} + +// ClearMaxPending resets the maximums seen so far. +func (s *Subscription) ClearMaxPending() error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return ErrBadSubscription + } + if s.typ == ChanSubscription { + return ErrTypeSubscription + } + s.pMsgsMax, s.pBytesMax = 0, 0 + return nil +} + +// Pending Limits +const ( + DefaultSubPendingMsgsLimit = 65536 + DefaultSubPendingBytesLimit = 65536 * 1024 +) + +// PendingLimits returns the current limits for this subscription. +// If no error is returned, a negative value indicates that the +// given metric is not limited. +func (s *Subscription) PendingLimits() (int, int, error) { + if s == nil { + return -1, -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, -1, ErrBadSubscription + } + if s.typ == ChanSubscription { + return -1, -1, ErrTypeSubscription + } + return s.pMsgsLimit, s.pBytesLimit, nil +} + +// SetPendingLimits sets the limits for pending msgs and bytes for this subscription. +// Zero is not allowed. Any negative value means that the given metric is not limited. +func (s *Subscription) SetPendingLimits(msgLimit, bytesLimit int) error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return ErrBadSubscription + } + if s.typ == ChanSubscription { + return ErrTypeSubscription + } + if msgLimit == 0 || bytesLimit == 0 { + return ErrInvalidArg + } + s.pMsgsLimit, s.pBytesLimit = msgLimit, bytesLimit + return nil +} + +// Delivered returns the number of delivered messages for this subscription. 
+func (s *Subscription) Delivered() (int64, error) { + if s == nil { + return -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, ErrBadSubscription + } + return int64(s.delivered), nil +} + +// Dropped returns the number of known dropped messages for this subscription. +// This will correspond to messages dropped by violations of PendingLimits. If +// the server declares the connection a SlowConsumer, this number may not be +// valid. +func (s *Subscription) Dropped() (int, error) { + if s == nil { + return -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, ErrBadSubscription + } + return s.dropped, nil +} + +// FIXME: This is a hack +// removeFlushEntry is needed when we need to discard queued up responses +// for our pings as part of a flush call. This happens when we have a flush +// call outstanding and we call close. +func (nc *Conn) removeFlushEntry(ch chan struct{}) bool { + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.pongs == nil { + return false + } + for i, c := range nc.pongs { + if c == ch { + nc.pongs[i] = nil + return true + } + } + return false +} + +// The lock must be held entering this function. +func (nc *Conn) sendPing(ch chan struct{}) { + nc.pongs = append(nc.pongs, ch) + nc.bw.WriteString(pingProto) + // Flush in place. + nc.bw.Flush() +} + +// This will fire periodically and send a client origin +// ping to the server. Will also check that we have received +// responses from the server. +func (nc *Conn) processPingTimer() { + nc.mu.Lock() + + if nc.status != CONNECTED { + nc.mu.Unlock() + return + } + + // Check for violation + nc.pout++ + if nc.pout > nc.Opts.MaxPingsOut { + nc.mu.Unlock() + nc.processOpErr(ErrStaleConnection) + return + } + + nc.sendPing(nil) + nc.ptmr.Reset(nc.Opts.PingInterval) + nc.mu.Unlock() +} + +// FlushTimeout allows a Flush operation to have an associated timeout. 
+func (nc *Conn) FlushTimeout(timeout time.Duration) (err error) { + if nc == nil { + return ErrInvalidConnection + } + if timeout <= 0 { + return ErrBadTimeout + } + + nc.mu.Lock() + if nc.isClosed() { + nc.mu.Unlock() + return ErrConnectionClosed + } + t := globalTimerPool.Get(timeout) + defer globalTimerPool.Put(t) + + // Create a buffered channel to prevent chan send to block + // in processPong() if this code here times out just when + // PONG was received. + ch := make(chan struct{}, 1) + nc.sendPing(ch) + nc.mu.Unlock() + + select { + case _, ok := <-ch: + if !ok { + err = ErrConnectionClosed + } else { + close(ch) + } + case <-t.C: + err = ErrTimeout + } + + if err != nil { + nc.removeFlushEntry(ch) + } + return +} + +// Flush will perform a round trip to the server and return when it +// receives the internal reply. +func (nc *Conn) Flush() error { + return nc.FlushTimeout(60 * time.Second) +} + +// Buffered will return the number of bytes buffered to be sent to the server. +// FIXME(dlc) take into account disconnected state. +func (nc *Conn) Buffered() (int, error) { + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.isClosed() || nc.bw == nil { + return -1, ErrConnectionClosed + } + return nc.bw.Buffered(), nil +} + +// resendSubscriptions will send our subscription state back to the +// server. Used in reconnects +func (nc *Conn) resendSubscriptions() { + // Since we are going to send protocols to the server, we don't want to + // be holding the subsMu lock (which is used in processMsg). So copy + // the subscriptions in a temporary array. + nc.subsMu.RLock() + subs := make([]*Subscription, 0, len(nc.subs)) + for _, s := range nc.subs { + subs = append(subs, s) + } + nc.subsMu.RUnlock() + for _, s := range subs { + adjustedMax := uint64(0) + s.mu.Lock() + if s.max > 0 { + if s.delivered < s.max { + adjustedMax = s.max - s.delivered + } + + // adjustedMax could be 0 here if the number of delivered msgs + // reached the max, if so unsubscribe. 
+ if adjustedMax == 0 { + s.mu.Unlock() + fmt.Fprintf(nc.bw, unsubProto, s.sid, _EMPTY_) + continue + } + } + s.mu.Unlock() + + fmt.Fprintf(nc.bw, subProto, s.Subject, s.Queue, s.sid) + if adjustedMax > 0 { + maxStr := strconv.Itoa(int(adjustedMax)) + fmt.Fprintf(nc.bw, unsubProto, s.sid, maxStr) + } + } +} + +// This will clear any pending flush calls and release pending calls. +// Lock is assumed to be held by the caller. +func (nc *Conn) clearPendingFlushCalls() { + // Clear any queued pongs, e.g. pending flush calls. + for _, ch := range nc.pongs { + if ch != nil { + close(ch) + } + } + nc.pongs = nil +} + +// This will clear any pending Request calls. +// Lock is assumed to be held by the caller. +func (nc *Conn) clearPendingRequestCalls() { + if nc.respMap == nil { + return + } + for key, ch := range nc.respMap { + if ch != nil { + close(ch) + delete(nc.respMap, key) + } + } +} + +// Low level close call that will do correct cleanup and set +// desired status. Also controls whether user defined callbacks +// will be triggered. The lock should not be held entering this +// function. This function will handle the locking manually. +func (nc *Conn) close(status Status, doCBs bool) { + nc.mu.Lock() + if nc.isClosed() { + nc.status = status + nc.mu.Unlock() + return + } + nc.status = CLOSED + + // Kick the Go routines so they fall out. + nc.kickFlusher() + nc.mu.Unlock() + + nc.mu.Lock() + + // Clear any queued pongs, e.g. pending flush calls. + nc.clearPendingFlushCalls() + + // Clear any queued and blocking Requests. + nc.clearPendingRequestCalls() + + // Stop ping timer if set. + nc.stopPingTimer() + nc.ptmr = nil + + // Go ahead and make sure we have flushed the outbound + if nc.conn != nil { + nc.bw.Flush() + defer nc.conn.Close() + } + + // Close sync subscriber channels and release any + // pending NextMsg() calls. 
+ nc.subsMu.Lock() + for _, s := range nc.subs { + s.mu.Lock() + + // Release callers on NextMsg for SyncSubscription only + if s.mch != nil && s.typ == SyncSubscription { + close(s.mch) + } + s.mch = nil + // Mark as invalid, for signaling to deliverMsgs + s.closed = true + // Mark connection closed in subscription + s.connClosed = true + // If we have an async subscription, signals it to exit + if s.typ == AsyncSubscription && s.pCond != nil { + s.pCond.Signal() + } + + s.mu.Unlock() + } + nc.subs = nil + nc.subsMu.Unlock() + + nc.status = status + + // Perform appropriate callback if needed for a disconnect. + if doCBs { + if nc.Opts.DisconnectedCB != nil && nc.conn != nil { + nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) }) + } + if nc.Opts.ClosedCB != nil { + nc.ach.push(func() { nc.Opts.ClosedCB(nc) }) + } + nc.ach.close() + } + nc.mu.Unlock() +} + +// Close will close the connection to the server. This call will release +// all blocking calls, such as Flush() and NextMsg() +func (nc *Conn) Close() { + nc.close(CLOSED, true) +} + +// IsClosed tests if a Conn has been closed. +func (nc *Conn) IsClosed() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.isClosed() +} + +// IsReconnecting tests if a Conn is reconnecting. +func (nc *Conn) IsReconnecting() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.isReconnecting() +} + +// IsConnected tests if a Conn is connected. +func (nc *Conn) IsConnected() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.isConnected() +} + +// drainConnection will run in a separate Go routine and will +// flush all publishes and drain all active subscriptions. +func (nc *Conn) drainConnection() { + // Snapshot subs list. + nc.mu.Lock() + subs := make([]*Subscription, 0, len(nc.subs)) + for _, s := range nc.subs { + subs = append(subs, s) + } + errCB := nc.Opts.AsyncErrorCB + drainWait := nc.Opts.DrainTimeout + nc.mu.Unlock() + + // for pushing errors with context. 
+ pushErr := func(err error) { + nc.mu.Lock() + if errCB != nil { + nc.ach.push(func() { errCB(nc, nil, err) }) + } + nc.mu.Unlock() + } + + // Do subs first + for _, s := range subs { + if err := s.Drain(); err != nil { + // We will notify about these but continue. + pushErr(err) + } + } + + // Wait for the subscriptions to drop to zero. + timeout := time.Now().Add(drainWait) + for time.Now().Before(timeout) { + if nc.NumSubscriptions() == 0 { + break + } + time.Sleep(10 * time.Millisecond) + } + + // Check if we timed out. + if nc.NumSubscriptions() != 0 { + pushErr(ErrDrainTimeout) + } + + // Flip State + nc.mu.Lock() + nc.status = DRAINING_PUBS + nc.mu.Unlock() + + // Do publish drain via Flush() call. + err := nc.Flush() + if err != nil { + pushErr(err) + nc.Close() + return + } + + // Move to closed state. + nc.Close() +} + +// Drain will put a connection into a drain state. All subscriptions will +// immediately be put into a drain state. Upon completion, the publishers +// will be drained and can not publish any additional messages. Upon draining +// of the publishers, the connection will be closed. Use the ClosedCB() +// option to know when the connection has moved from draining to closed. +func (nc *Conn) Drain() error { + nc.mu.Lock() + defer nc.mu.Unlock() + + if nc.isClosed() { + return ErrConnectionClosed + } + if nc.isConnecting() || nc.isReconnecting() { + return ErrConnectionReconnecting + } + if nc.isDraining() { + return nil + } + + nc.status = DRAINING_SUBS + go nc.drainConnection() + return nil +} + +// IsDraining tests if a Conn is in the draining state. 
+func (nc *Conn) IsDraining() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.isDraining() +} + +// caller must lock +func (nc *Conn) getServers(implicitOnly bool) []string { + poolSize := len(nc.srvPool) + var servers = make([]string, 0) + for i := 0; i < poolSize; i++ { + if implicitOnly && !nc.srvPool[i].isImplicit { + continue + } + url := nc.srvPool[i].url + servers = append(servers, fmt.Sprintf("%s://%s", url.Scheme, url.Host)) + } + return servers +} + +// Servers returns the list of known server urls, including additional +// servers discovered after a connection has been established. If +// authentication is enabled, use UserInfo or Token when connecting with +// these urls. +func (nc *Conn) Servers() []string { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.getServers(false) +} + +// DiscoveredServers returns only the server urls that have been discovered +// after a connection has been established. If authentication is enabled, +// use UserInfo or Token when connecting with these urls. +func (nc *Conn) DiscoveredServers() []string { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.getServers(true) +} + +// Status returns the current state of the connection. +func (nc *Conn) Status() Status { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.status +} + +// Test if Conn has been closed Lock is assumed held. +func (nc *Conn) isClosed() bool { + return nc.status == CLOSED +} + +// Test if Conn is in the process of connecting +func (nc *Conn) isConnecting() bool { + return nc.status == CONNECTING +} + +// Test if Conn is being reconnected. +func (nc *Conn) isReconnecting() bool { + return nc.status == RECONNECTING +} + +// Test if Conn is connected or connecting. +func (nc *Conn) isConnected() bool { + return nc.status == CONNECTED || nc.isDraining() +} + +// Test if Conn is in the draining state. 
+func (nc *Conn) isDraining() bool { + return nc.status == DRAINING_SUBS || nc.status == DRAINING_PUBS +} + +// Test if Conn is in the draining state for pubs. +func (nc *Conn) isDrainingPubs() bool { + return nc.status == DRAINING_PUBS +} + +// Stats will return a race safe copy of the Statistics section for the connection. +func (nc *Conn) Stats() Statistics { + // Stats are updated either under connection's mu or subsMu mutexes. + // Lock both to safely get them. + nc.mu.Lock() + nc.subsMu.RLock() + stats := Statistics{ + InMsgs: nc.InMsgs, + InBytes: nc.InBytes, + OutMsgs: nc.OutMsgs, + OutBytes: nc.OutBytes, + Reconnects: nc.Reconnects, + } + nc.subsMu.RUnlock() + nc.mu.Unlock() + return stats +} + +// MaxPayload returns the size limit that a message payload can have. +// This is set by the server configuration and delivered to the client +// upon connect. +func (nc *Conn) MaxPayload() int64 { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.info.MaxPayload +} + +// AuthRequired will return if the connected server requires authorization. +func (nc *Conn) AuthRequired() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.info.AuthRequired +} + +// TLSRequired will return if the connected server requires TLS connections. +func (nc *Conn) TLSRequired() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.info.TLSRequired +} + +// Barrier schedules the given function `f` to all registered asynchronous +// subscriptions. +// Only the last subscription to see this barrier will invoke the function. +// If no subscription is registered at the time of this call, `f()` is invoked +// right away. +// ErrConnectionClosed is returned if the connection is closed prior to +// the call. 
+func (nc *Conn) Barrier(f func()) error { + nc.mu.Lock() + if nc.isClosed() { + nc.mu.Unlock() + return ErrConnectionClosed + } + nc.subsMu.Lock() + // Need to figure out how many non chan subscriptions there are + numSubs := 0 + for _, sub := range nc.subs { + if sub.typ == AsyncSubscription { + numSubs++ + } + } + if numSubs == 0 { + nc.subsMu.Unlock() + nc.mu.Unlock() + f() + return nil + } + barrier := &barrierInfo{refs: int64(numSubs), f: f} + for _, sub := range nc.subs { + sub.mu.Lock() + if sub.mch == nil { + msg := &Msg{barrier: barrier} + // Push onto the async pList + if sub.pTail != nil { + sub.pTail.next = msg + } else { + sub.pHead = msg + sub.pCond.Signal() + } + sub.pTail = msg + } + sub.mu.Unlock() + } + nc.subsMu.Unlock() + nc.mu.Unlock() + return nil +} diff --git a/vendor/github.com/nats-io/go-nats/netchan.go b/vendor/github.com/nats-io/go-nats/netchan.go new file mode 100644 index 00000000000..add3cba5294 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/netchan.go @@ -0,0 +1,111 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nats + +import ( + "errors" + "reflect" +) + +// This allows the functionality for network channels by binding send and receive Go chans +// to subjects and optionally queue groups. +// Data will be encoded and decoded via the EncodedConn and its associated encoders. + +// BindSendChan binds a channel for send operations to NATS. 
+func (c *EncodedConn) BindSendChan(subject string, channel interface{}) error { + chVal := reflect.ValueOf(channel) + if chVal.Kind() != reflect.Chan { + return ErrChanArg + } + go chPublish(c, chVal, subject) + return nil +} + +// Publish all values that arrive on the channel until it is closed or we +// encounter an error. +func chPublish(c *EncodedConn, chVal reflect.Value, subject string) { + for { + val, ok := chVal.Recv() + if !ok { + // Channel has most likely been closed. + return + } + if e := c.Publish(subject, val.Interface()); e != nil { + // Do this under lock. + c.Conn.mu.Lock() + defer c.Conn.mu.Unlock() + + if c.Conn.Opts.AsyncErrorCB != nil { + // FIXME(dlc) - Not sure this is the right thing to do. + // FIXME(ivan) - If the connection is not yet closed, try to schedule the callback + if c.Conn.isClosed() { + go c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) + } else { + c.Conn.ach.push(func() { c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) }) + } + } + return + } + } +} + +// BindRecvChan binds a channel for receive operations from NATS. +func (c *EncodedConn) BindRecvChan(subject string, channel interface{}) (*Subscription, error) { + return c.bindRecvChan(subject, _EMPTY_, channel) +} + +// BindRecvQueueChan binds a channel for queue-based receive operations from NATS. +func (c *EncodedConn) BindRecvQueueChan(subject, queue string, channel interface{}) (*Subscription, error) { + return c.bindRecvChan(subject, queue, channel) +} + +// Internal function to bind receive operations for a channel. 
+func (c *EncodedConn) bindRecvChan(subject, queue string, channel interface{}) (*Subscription, error) { + chVal := reflect.ValueOf(channel) + if chVal.Kind() != reflect.Chan { + return nil, ErrChanArg + } + argType := chVal.Type().Elem() + + cb := func(m *Msg) { + var oPtr reflect.Value + if argType.Kind() != reflect.Ptr { + oPtr = reflect.New(argType) + } else { + oPtr = reflect.New(argType.Elem()) + } + if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil { + c.Conn.err = errors.New("nats: Got an error trying to unmarshal: " + err.Error()) + if c.Conn.Opts.AsyncErrorCB != nil { + c.Conn.ach.push(func() { c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, c.Conn.err) }) + } + return + } + if argType.Kind() != reflect.Ptr { + oPtr = reflect.Indirect(oPtr) + } + // This is a bit hacky, but in this instance we may be trying to send to a closed channel. + // and the user does not know when it is safe to close the channel. + defer func() { + // If we have panicked, recover and close the subscription. + if r := recover(); r != nil { + m.Sub.Unsubscribe() + } + }() + // Actually do the send to the channel. + chVal.Send(oPtr) + } + + return c.Conn.subscribe(subject, queue, cb, nil) +} diff --git a/vendor/github.com/nats-io/go-nats/parser.go b/vendor/github.com/nats-io/go-nats/parser.go new file mode 100644 index 00000000000..a4b3ea0ea75 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/parser.go @@ -0,0 +1,481 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package nats + +import ( + "fmt" +) + +type msgArg struct { + subject []byte + reply []byte + sid int64 + size int +} + +const MAX_CONTROL_LINE_SIZE = 1024 + +type parseState struct { + state int + as int + drop int + ma msgArg + argBuf []byte + msgBuf []byte + scratch [MAX_CONTROL_LINE_SIZE]byte +} + +const ( + OP_START = iota + OP_PLUS + OP_PLUS_O + OP_PLUS_OK + OP_MINUS + OP_MINUS_E + OP_MINUS_ER + OP_MINUS_ERR + OP_MINUS_ERR_SPC + MINUS_ERR_ARG + OP_M + OP_MS + OP_MSG + OP_MSG_SPC + MSG_ARG + MSG_PAYLOAD + MSG_END + OP_P + OP_PI + OP_PIN + OP_PING + OP_PO + OP_PON + OP_PONG + OP_I + OP_IN + OP_INF + OP_INFO + OP_INFO_SPC + INFO_ARG +) + +// parse is the fast protocol parser engine. +func (nc *Conn) parse(buf []byte) error { + var i int + var b byte + + // Move to loop instead of range syntax to allow jumping of i + for i = 0; i < len(buf); i++ { + b = buf[i] + + switch nc.ps.state { + case OP_START: + switch b { + case 'M', 'm': + nc.ps.state = OP_M + case 'P', 'p': + nc.ps.state = OP_P + case '+': + nc.ps.state = OP_PLUS + case '-': + nc.ps.state = OP_MINUS + case 'I', 'i': + nc.ps.state = OP_I + default: + goto parseErr + } + case OP_M: + switch b { + case 'S', 's': + nc.ps.state = OP_MS + default: + goto parseErr + } + case OP_MS: + switch b { + case 'G', 'g': + nc.ps.state = OP_MSG + default: + goto parseErr + } + case OP_MSG: + switch b { + case ' ', '\t': + nc.ps.state = OP_MSG_SPC + default: + goto parseErr + } + case OP_MSG_SPC: + switch b { + case ' ', '\t': + continue + default: + nc.ps.state = MSG_ARG + nc.ps.as = i + } + case MSG_ARG: + switch b { + case '\r': + nc.ps.drop = 1 + case '\n': + var arg []byte + if nc.ps.argBuf != nil { + arg = nc.ps.argBuf + } else { + arg = buf[nc.ps.as : i-nc.ps.drop] + } + if err := nc.processMsgArgs(arg); err != nil { + return err + } + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, MSG_PAYLOAD + + // jump 
ahead with the index. If this overruns + // what is left we fall out and process split + // buffer. + i = nc.ps.as + nc.ps.ma.size - 1 + default: + if nc.ps.argBuf != nil { + nc.ps.argBuf = append(nc.ps.argBuf, b) + } + } + case MSG_PAYLOAD: + if nc.ps.msgBuf != nil { + if len(nc.ps.msgBuf) >= nc.ps.ma.size { + nc.processMsg(nc.ps.msgBuf) + nc.ps.argBuf, nc.ps.msgBuf, nc.ps.state = nil, nil, MSG_END + } else { + // copy as much as we can to the buffer and skip ahead. + toCopy := nc.ps.ma.size - len(nc.ps.msgBuf) + avail := len(buf) - i + + if avail < toCopy { + toCopy = avail + } + + if toCopy > 0 { + start := len(nc.ps.msgBuf) + // This is needed for copy to work. + nc.ps.msgBuf = nc.ps.msgBuf[:start+toCopy] + copy(nc.ps.msgBuf[start:], buf[i:i+toCopy]) + // Update our index + i = (i + toCopy) - 1 + } else { + nc.ps.msgBuf = append(nc.ps.msgBuf, b) + } + } + } else if i-nc.ps.as >= nc.ps.ma.size { + nc.processMsg(buf[nc.ps.as:i]) + nc.ps.argBuf, nc.ps.msgBuf, nc.ps.state = nil, nil, MSG_END + } + case MSG_END: + switch b { + case '\n': + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, OP_START + default: + continue + } + case OP_PLUS: + switch b { + case 'O', 'o': + nc.ps.state = OP_PLUS_O + default: + goto parseErr + } + case OP_PLUS_O: + switch b { + case 'K', 'k': + nc.ps.state = OP_PLUS_OK + default: + goto parseErr + } + case OP_PLUS_OK: + switch b { + case '\n': + nc.processOK() + nc.ps.drop, nc.ps.state = 0, OP_START + } + case OP_MINUS: + switch b { + case 'E', 'e': + nc.ps.state = OP_MINUS_E + default: + goto parseErr + } + case OP_MINUS_E: + switch b { + case 'R', 'r': + nc.ps.state = OP_MINUS_ER + default: + goto parseErr + } + case OP_MINUS_ER: + switch b { + case 'R', 'r': + nc.ps.state = OP_MINUS_ERR + default: + goto parseErr + } + case OP_MINUS_ERR: + switch b { + case ' ', '\t': + nc.ps.state = OP_MINUS_ERR_SPC + default: + goto parseErr + } + case OP_MINUS_ERR_SPC: + switch b { + case ' ', '\t': + continue + default: + nc.ps.state = MINUS_ERR_ARG + 
nc.ps.as = i + } + case MINUS_ERR_ARG: + switch b { + case '\r': + nc.ps.drop = 1 + case '\n': + var arg []byte + if nc.ps.argBuf != nil { + arg = nc.ps.argBuf + nc.ps.argBuf = nil + } else { + arg = buf[nc.ps.as : i-nc.ps.drop] + } + nc.processErr(string(arg)) + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, OP_START + default: + if nc.ps.argBuf != nil { + nc.ps.argBuf = append(nc.ps.argBuf, b) + } + } + case OP_P: + switch b { + case 'I', 'i': + nc.ps.state = OP_PI + case 'O', 'o': + nc.ps.state = OP_PO + default: + goto parseErr + } + case OP_PO: + switch b { + case 'N', 'n': + nc.ps.state = OP_PON + default: + goto parseErr + } + case OP_PON: + switch b { + case 'G', 'g': + nc.ps.state = OP_PONG + default: + goto parseErr + } + case OP_PONG: + switch b { + case '\n': + nc.processPong() + nc.ps.drop, nc.ps.state = 0, OP_START + } + case OP_PI: + switch b { + case 'N', 'n': + nc.ps.state = OP_PIN + default: + goto parseErr + } + case OP_PIN: + switch b { + case 'G', 'g': + nc.ps.state = OP_PING + default: + goto parseErr + } + case OP_PING: + switch b { + case '\n': + nc.processPing() + nc.ps.drop, nc.ps.state = 0, OP_START + } + case OP_I: + switch b { + case 'N', 'n': + nc.ps.state = OP_IN + default: + goto parseErr + } + case OP_IN: + switch b { + case 'F', 'f': + nc.ps.state = OP_INF + default: + goto parseErr + } + case OP_INF: + switch b { + case 'O', 'o': + nc.ps.state = OP_INFO + default: + goto parseErr + } + case OP_INFO: + switch b { + case ' ', '\t': + nc.ps.state = OP_INFO_SPC + default: + goto parseErr + } + case OP_INFO_SPC: + switch b { + case ' ', '\t': + continue + default: + nc.ps.state = INFO_ARG + nc.ps.as = i + } + case INFO_ARG: + switch b { + case '\r': + nc.ps.drop = 1 + case '\n': + var arg []byte + if nc.ps.argBuf != nil { + arg = nc.ps.argBuf + nc.ps.argBuf = nil + } else { + arg = buf[nc.ps.as : i-nc.ps.drop] + } + nc.processAsyncInfo(arg) + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, OP_START + default: + if nc.ps.argBuf != nil { + 
nc.ps.argBuf = append(nc.ps.argBuf, b) + } + } + default: + goto parseErr + } + } + // Check for split buffer scenarios + if (nc.ps.state == MSG_ARG || nc.ps.state == MINUS_ERR_ARG || nc.ps.state == INFO_ARG) && nc.ps.argBuf == nil { + nc.ps.argBuf = nc.ps.scratch[:0] + nc.ps.argBuf = append(nc.ps.argBuf, buf[nc.ps.as:i-nc.ps.drop]...) + // FIXME, check max len + } + // Check for split msg + if nc.ps.state == MSG_PAYLOAD && nc.ps.msgBuf == nil { + // We need to clone the msgArg if it is still referencing the + // read buffer and we are not able to process the msg. + if nc.ps.argBuf == nil { + nc.cloneMsgArg() + } + + // If we will overflow the scratch buffer, just create a + // new buffer to hold the split message. + if nc.ps.ma.size > cap(nc.ps.scratch)-len(nc.ps.argBuf) { + lrem := len(buf[nc.ps.as:]) + + nc.ps.msgBuf = make([]byte, lrem, nc.ps.ma.size) + copy(nc.ps.msgBuf, buf[nc.ps.as:]) + } else { + nc.ps.msgBuf = nc.ps.scratch[len(nc.ps.argBuf):len(nc.ps.argBuf)] + nc.ps.msgBuf = append(nc.ps.msgBuf, (buf[nc.ps.as:])...) + } + } + + return nil + +parseErr: + return fmt.Errorf("nats: Parse Error [%d]: '%s'", nc.ps.state, buf[i:]) +} + +// cloneMsgArg is used when the split buffer scenario has the pubArg in the existing read buffer, but +// we need to hold onto it into the next read. +func (nc *Conn) cloneMsgArg() { + nc.ps.argBuf = nc.ps.scratch[:0] + nc.ps.argBuf = append(nc.ps.argBuf, nc.ps.ma.subject...) + nc.ps.argBuf = append(nc.ps.argBuf, nc.ps.ma.reply...) 
+ nc.ps.ma.subject = nc.ps.argBuf[:len(nc.ps.ma.subject)] + if nc.ps.ma.reply != nil { + nc.ps.ma.reply = nc.ps.argBuf[len(nc.ps.ma.subject):] + } +} + +const argsLenMax = 4 + +func (nc *Conn) processMsgArgs(arg []byte) error { + // Unroll splitArgs to avoid runtime/heap issues + a := [argsLenMax][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t', '\r', '\n': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + + switch len(args) { + case 3: + nc.ps.ma.subject = args[0] + nc.ps.ma.sid = parseInt64(args[1]) + nc.ps.ma.reply = nil + nc.ps.ma.size = int(parseInt64(args[2])) + case 4: + nc.ps.ma.subject = args[0] + nc.ps.ma.sid = parseInt64(args[1]) + nc.ps.ma.reply = args[2] + nc.ps.ma.size = int(parseInt64(args[3])) + default: + return fmt.Errorf("nats: processMsgArgs Parse Error: '%s'", arg) + } + if nc.ps.ma.sid < 0 { + return fmt.Errorf("nats: processMsgArgs Bad or Missing Sid: '%s'", arg) + } + if nc.ps.ma.size < 0 { + return fmt.Errorf("nats: processMsgArgs Bad or Missing Size: '%s'", arg) + } + return nil +} + +// Ascii numbers 0-9 +const ( + ascii_0 = 48 + ascii_9 = 57 +) + +// parseInt64 expects decimal positive numbers. We +// return -1 to signal error +func parseInt64(d []byte) (n int64) { + if len(d) == 0 { + return -1 + } + for _, dec := range d { + if dec < ascii_0 || dec > ascii_9 { + return -1 + } + n = n*10 + (int64(dec) - ascii_0) + } + return n +} diff --git a/vendor/github.com/nats-io/go-nats/timer.go b/vendor/github.com/nats-io/go-nats/timer.go new file mode 100644 index 00000000000..1216762d422 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/timer.go @@ -0,0 +1,56 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nats + +import ( + "sync" + "time" +) + +// global pool of *time.Timer's. can be used by multiple goroutines concurrently. +var globalTimerPool timerPool + +// timerPool provides GC-able pooling of *time.Timer's. +// can be used by multiple goroutines concurrently. +type timerPool struct { + p sync.Pool +} + +// Get returns a timer that completes after the given duration. +func (tp *timerPool) Get(d time.Duration) *time.Timer { + if t, _ := tp.p.Get().(*time.Timer); t != nil { + t.Reset(d) + return t + } + + return time.NewTimer(d) +} + +// Put pools the given timer. +// +// There is no need to call t.Stop() before calling Put. +// +// Put will try to stop the timer before pooling. If the +// given timer already expired, Put will read the unreceived +// value if there is one. +func (tp *timerPool) Put(t *time.Timer) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + + tp.p.Put(t) +} diff --git a/vendor/github.com/nats-io/go-nats/util/tls.go b/vendor/github.com/nats-io/go-nats/util/tls.go new file mode 100644 index 00000000000..53ff9aa2b48 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/util/tls.go @@ -0,0 +1,27 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package util + +import "crypto/tls" + +// CloneTLSConfig returns a copy of c. +func CloneTLSConfig(c *tls.Config) *tls.Config { + if c == nil { + return &tls.Config{} + } + + return c.Clone() +} diff --git a/vendor/github.com/nats-io/go-nats/util/tls_go17.go b/vendor/github.com/nats-io/go-nats/util/tls_go17.go new file mode 100644 index 00000000000..fd646d31b95 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/util/tls_go17.go @@ -0,0 +1,49 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.7,!go1.8 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. Only the exported fields are copied. +// This is temporary, until this is provided by the language. 
+// https://go-review.googlesource.com/#/c/28075/ +func CloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + } +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/LICENSE b/vendor/github.com/nats-io/nats-streaming-server/LICENSE new file mode 100644 index 00000000000..f49a4e16e68 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/vendor/github.com/nats-io/nats-streaming-server/logger/logger.go b/vendor/github.com/nats-io/nats-streaming-server/logger/logger.go new file mode 100644 index 00000000000..2090b05d046 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/logger/logger.go @@ -0,0 +1,146 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logger + +import ( + "io" + "sync" + + natsdLogger "github.com/nats-io/gnatsd/logger" + natsd "github.com/nats-io/gnatsd/server" +) + +// LogPrefix is prefixed to all NATS Streaming log messages +const LogPrefix = "STREAM: " + +// Logger interface for the Streaming project. +// This is an alias of the NATS Server's Logger interface. +type Logger natsd.Logger + +// StanLogger is the logger used in this project and implements +// the Logger interface. 
+type StanLogger struct { + mu sync.RWMutex + debug bool + trace bool + ltime bool + lfile string + log natsd.Logger +} + +// NewStanLogger returns an instance of StanLogger +func NewStanLogger() *StanLogger { + return &StanLogger{} +} + +// SetLogger sets the logger, debug and trace +func (s *StanLogger) SetLogger(log Logger, logtime, debug, trace bool, logfile string) { + s.mu.Lock() + s.log = log + s.ltime = logtime + s.debug = debug + s.trace = trace + s.lfile = logfile + s.mu.Unlock() +} + +// GetLogger returns the logger +func (s *StanLogger) GetLogger() Logger { + s.mu.RLock() + l := s.log + s.mu.RUnlock() + return l +} + +// ReopenLogFile closes and reopen the logfile. +// Does nothing if the logger is not a file based. +func (s *StanLogger) ReopenLogFile() { + s.mu.Lock() + if s.lfile == "" { + s.mu.Unlock() + s.Noticef("File log re-open ignored, not a file logger") + return + } + if l, ok := s.log.(io.Closer); ok { + if err := l.Close(); err != nil { + s.mu.Unlock() + s.Errorf("Unable to close logger: %v", err) + return + } + } + fileLog := natsdLogger.NewFileLogger(s.lfile, s.ltime, s.debug, s.trace, true) + s.log = fileLog + s.mu.Unlock() + s.Noticef("File log re-opened") +} + +// Close closes this logger, releasing possible held resources. +func (s *StanLogger) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if l, ok := s.log.(io.Closer); ok { + return l.Close() + } + return nil +} + +// Noticef logs a notice statement +func (s *StanLogger) Noticef(format string, v ...interface{}) { + s.executeLogCall(func(log Logger, format string, v ...interface{}) { + log.Noticef(format, v...) + }, format, v...) +} + +// Errorf logs an error +func (s *StanLogger) Errorf(format string, v ...interface{}) { + s.executeLogCall(func(log Logger, format string, v ...interface{}) { + log.Errorf(format, v...) + }, format, v...) 
+} + +// Fatalf logs a fatal error +func (s *StanLogger) Fatalf(format string, v ...interface{}) { + s.executeLogCall(func(log Logger, format string, v ...interface{}) { + log.Fatalf(format, v...) + }, format, v...) +} + +// Debugf logs a debug statement +func (s *StanLogger) Debugf(format string, v ...interface{}) { + s.executeLogCall(func(log Logger, format string, v ...interface{}) { + // This is running under the protection of StanLogging's lock + if s.debug { + log.Debugf(format, v...) + } + }, format, v...) +} + +// Tracef logs a trace statement +func (s *StanLogger) Tracef(format string, v ...interface{}) { + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + if s.trace { + logger.Tracef(format, v...) + } + }, format, v...) +} + +func (s *StanLogger) executeLogCall(f func(logger Logger, format string, v ...interface{}), format string, args ...interface{}) { + s.mu.Lock() + if s.log == nil { + s.mu.Unlock() + return + } + f(s.log, LogPrefix+format, args...) + s.mu.Unlock() +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/client.go b/vendor/github.com/nats-io/nats-streaming-server/server/client.go new file mode 100644 index 00000000000..add69cc30d1 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/client.go @@ -0,0 +1,292 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "sync" + "time" + + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/stores" +) + +// This is a proxy to the store interface. +type clientStore struct { + sync.RWMutex + clients map[string]*client + connIDs map[string]*client + waitOnRegister map[string]chan struct{} + store stores.Store +} + +// client has information needed by the server. A client is also +// stored in a stores.Client object (which contains ID and HbInbox). +type client struct { + sync.RWMutex + info *stores.Client + hbt *time.Timer + fhb int + subs []*subState +} + +// newClientStore creates a new clientStore instance using `store` as the backing storage. +func newClientStore(store stores.Store) *clientStore { + return &clientStore{ + clients: make(map[string]*client), + connIDs: make(map[string]*client), + store: store, + } +} + +// getSubsCopy returns a copy of the client's subscribers array. +// At least Read-lock must be held by the caller. +func (c *client) getSubsCopy() []*subState { + subs := make([]*subState, len(c.subs)) + copy(subs, c.subs) + return subs +} + +// Register a new client. Returns ErrInvalidClient if client is already registered. +func (cs *clientStore) register(info *spb.ClientInfo) (*client, error) { + cs.Lock() + defer cs.Unlock() + c := cs.clients[info.ID] + if c != nil { + return nil, ErrInvalidClient + } + sc, err := cs.store.AddClient(info) + if err != nil { + return nil, err + } + c = &client{info: sc, subs: make([]*subState, 0, 4)} + cs.clients[c.info.ID] = c + if len(c.info.ConnID) > 0 { + cs.connIDs[string(c.info.ConnID)] = c + } + if cs.waitOnRegister != nil { + ch := cs.waitOnRegister[c.info.ID] + if ch != nil { + ch <- struct{}{} + delete(cs.waitOnRegister, c.info.ID) + } + } + return c, nil +} + +// Unregister a client. 
+func (cs *clientStore) unregister(ID string) (*client, error) { + cs.Lock() + defer cs.Unlock() + c := cs.clients[ID] + if c == nil { + return nil, nil + } + c.Lock() + if c.hbt != nil { + c.hbt.Stop() + c.hbt = nil + } + connID := c.info.ConnID + c.Unlock() + delete(cs.clients, ID) + if len(connID) > 0 { + delete(cs.connIDs, string(connID)) + } + if cs.waitOnRegister != nil { + delete(cs.waitOnRegister, ID) + } + err := cs.store.DeleteClient(ID) + return c, err +} + +// IsValid returns true if the client is registered, false otherwise. +func (cs *clientStore) isValid(ID string, connID []byte) bool { + cs.RLock() + valid := cs.lookupByConnIDOrID(ID, connID) != nil + cs.RUnlock() + return valid +} + +// isValidWithTimeout will return true if the client is registered, +// false if not. +// When the client is not yet registered, this call sets up a go channel +// and waits up to `timeout` for the register() call to send the newly +// registered client to the channel. +// On timeout, this call return false to indicate that the client +// has still not registered. +func (cs *clientStore) isValidWithTimeout(ID string, connID []byte, timeout time.Duration) bool { + cs.Lock() + c := cs.lookupByConnIDOrID(ID, connID) + if c != nil { + cs.Unlock() + return true + } + if cs.waitOnRegister == nil { + cs.waitOnRegister = make(map[string]chan struct{}) + } + ch := make(chan struct{}, 1) + cs.waitOnRegister[ID] = ch + cs.Unlock() + select { + case <-ch: + return true + case <-time.After(timeout): + // We timed out, remove the entry in the map + cs.Lock() + delete(cs.waitOnRegister, ID) + cs.Unlock() + return false + } +} + +// Lookup client by ConnID if not nil, otherwise by clientID. +// Assume at least clientStore RLock is held on entry. 
+func (cs *clientStore) lookupByConnIDOrID(ID string, connID []byte) *client { + var c *client + if len(connID) > 0 { + c = cs.connIDs[string(connID)] + } else { + c = cs.clients[ID] + } + return c +} + +// Lookup a client +func (cs *clientStore) lookup(ID string) *client { + cs.RLock() + c := cs.clients[ID] + cs.RUnlock() + return c +} + +// Lookup a client by connection ID +func (cs *clientStore) lookupByConnID(connID []byte) *client { + cs.RLock() + c := cs.connIDs[string(connID)] + cs.RUnlock() + return c +} + +// GetSubs returns the list of subscriptions for the client identified by ID, +// or nil if such client is not found. +func (cs *clientStore) getSubs(ID string) []*subState { + cs.RLock() + defer cs.RUnlock() + c := cs.clients[ID] + if c == nil { + return nil + } + c.RLock() + subs := c.getSubsCopy() + c.RUnlock() + return subs +} + +// AddSub adds the subscription to the client identified by clientID +// and returns true only if the client has not been unregistered, +// otherwise returns false. +func (cs *clientStore) addSub(ID string, sub *subState) bool { + cs.RLock() + defer cs.RUnlock() + c := cs.clients[ID] + if c == nil { + return false + } + c.Lock() + c.subs = append(c.subs, sub) + c.Unlock() + return true +} + +// RemoveSub removes the subscription from the client identified by clientID +// and returns true only if the client has not been unregistered and that +// the subscription was found, otherwise returns false. +func (cs *clientStore) removeSub(ID string, sub *subState) bool { + cs.RLock() + defer cs.RUnlock() + c := cs.clients[ID] + if c == nil { + return false + } + c.Lock() + removed := false + c.subs, removed = sub.deleteFromList(c.subs) + c.Unlock() + return removed +} + +// recoverClients recreates the content of the client store based on clients +// information recovered from the Store. 
+func (cs *clientStore) recoverClients(clients []*stores.Client) { + cs.Lock() + for _, sc := range clients { + client := &client{info: sc, subs: make([]*subState, 0, 4)} + cs.clients[client.info.ID] = client + if len(client.info.ConnID) > 0 { + cs.connIDs[string(client.info.ConnID)] = client + } + } + cs.Unlock() +} + +// setClientHB will lookup the client `ID` and, if present, set the +// client's timer with the given interval and function. +func (cs *clientStore) setClientHB(ID string, interval time.Duration, f func()) { + cs.RLock() + defer cs.RUnlock() + c := cs.clients[ID] + if c == nil { + return + } + c.Lock() + if c.hbt == nil { + c.hbt = time.AfterFunc(interval, f) + } + c.Unlock() +} + +// removeClientHB will stop and remove the client's heartbeat timer, if +// present. +func (cs *clientStore) removeClientHB(c *client) { + if c == nil { + return + } + c.Lock() + if c.hbt != nil { + c.hbt.Stop() + c.hbt = nil + } + c.Unlock() +} + +// getClients returns a snapshot of the registered clients. +// The map itself is a copy (can be iterated safely), but +// the clients objects returned are the one stored in the clientStore. 
+func (cs *clientStore) getClients() map[string]*client { + cs.RLock() + defer cs.RUnlock() + clients := make(map[string]*client, len(cs.clients)) + for _, c := range cs.clients { + clients[c.info.ID] = c + } + return clients +} + +// count returns the number of registered clients +func (cs *clientStore) count() int { + cs.RLock() + total := len(cs.clients) + cs.RUnlock() + return total +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/clustering.go b/vendor/github.com/nats-io/nats-streaming-server/server/clustering.go new file mode 100644 index 00000000000..cdafb2b758d --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/clustering.go @@ -0,0 +1,487 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "time" + + "github.com/hashicorp/raft" + "github.com/nats-io/go-nats" + "github.com/nats-io/nats-streaming-server/spb" +) + +const ( + defaultJoinRaftGroupTimeout = time.Second +) + +var ( + runningInTests bool + joinRaftGroupTimeout = defaultJoinRaftGroupTimeout + testPauseAfterNewRaftCalled bool +) + +func clusterSetupForTest() { + runningInTests = true + lazyReplicationInterval = 250 * time.Millisecond + joinRaftGroupTimeout = 250 * time.Millisecond +} + +// ClusteringOptions contains STAN Server options related to clustering. 
+type ClusteringOptions struct { + Clustered bool // Run the server in a clustered configuration. + NodeID string // ID of the node within the cluster. + Bootstrap bool // Bootstrap the cluster as a seed node if there is no existing state. + Peers []string // List of cluster peer node IDs to bootstrap cluster state. + RaftLogPath string // Path to Raft log store directory. + LogCacheSize int // Number of Raft log entries to cache in memory to reduce disk IO. + LogSnapshots int // Number of Raft log snapshots to retain. + TrailingLogs int64 // Number of logs left after a snapshot. + Sync bool // Do a file sync after every write to the Raft log and message store. + RaftLogging bool // Enable logging of Raft library (disabled by default since really verbose). +} + +// raftNode is a handle to a member in a Raft consensus group. +type raftNode struct { + leader int64 + sync.Mutex + closed bool + *raft.Raft + store *raftLog + transport *raft.NetworkTransport + logInput io.WriteCloser + joinSub *nats.Subscription + notifyCh <-chan bool + fsm *raftFSM +} + +type replicatedSub struct { + sub *subState + err error +} + +type raftFSM struct { + sync.Mutex + snapshotsOnInit int + server *StanServer +} + +// shutdown attempts to stop the Raft node. +func (r *raftNode) shutdown() error { + r.Lock() + if r.closed { + r.Unlock() + return nil + } + r.closed = true + r.Unlock() + if r.Raft != nil { + if err := r.Raft.Shutdown().Error(); err != nil { + return err + } + } + if r.transport != nil { + if err := r.transport.Close(); err != nil { + return err + } + } + if r.store != nil { + if err := r.store.Close(); err != nil { + return err + } + } + if r.joinSub != nil { + if err := r.joinSub.Unsubscribe(); err != nil { + return err + } + } + if r.logInput != nil { + if err := r.logInput.Close(); err != nil { + return err + } + } + return nil +} + +// createRaftNode creates and starts a new Raft node. 
+func (s *StanServer) createServerRaftNode(hasStreamingState bool) error { + var ( + name = s.info.ClusterID + addr = s.getClusteringAddr(name) + existingState, err = s.createRaftNode(name) + ) + if err != nil { + return err + } + if !existingState && hasStreamingState { + return fmt.Errorf("streaming state was recovered but cluster log path %q is empty", s.opts.Clustering.RaftLogPath) + } + node := s.raft + + // Bootstrap if there is no previous state and we are starting this node as + // a seed or a cluster configuration is provided. + bootstrap := !existingState && (s.opts.Clustering.Bootstrap || len(s.opts.Clustering.Peers) > 0) + if bootstrap { + if err := s.bootstrapCluster(name, node.Raft); err != nil { + node.shutdown() + return err + } + } else if !existingState { + // Attempt to join the cluster if we're not bootstrapping. + req, err := (&spb.RaftJoinRequest{NodeID: s.opts.Clustering.NodeID, NodeAddr: addr}).Marshal() + if err != nil { + panic(err) + } + var ( + joined = false + resp = &spb.RaftJoinResponse{} + ) + s.log.Debugf("Joining Raft group %s", name) + // Attempt to join up to 5 times before giving up. + for i := 0; i < 5; i++ { + r, err := s.ncr.Request(fmt.Sprintf("%s.%s.join", defaultRaftPrefix, name), req, joinRaftGroupTimeout) + if err != nil { + time.Sleep(20 * time.Millisecond) + continue + } + if err := resp.Unmarshal(r.Data); err != nil { + time.Sleep(20 * time.Millisecond) + continue + } + if resp.Error != "" { + time.Sleep(20 * time.Millisecond) + continue + } + joined = true + break + } + if !joined { + node.shutdown() + return fmt.Errorf("failed to join Raft group %s", name) + } + } + if s.opts.Clustering.Bootstrap { + // If node is started with bootstrap, regardless if state exist or not, try to + // detect (and report) other nodes in same cluster started with bootstrap=true. 
+ s.wg.Add(1) + go func() { + s.detectBootstrapMisconfig(name) + s.wg.Done() + }() + } + return nil +} + +func (s *StanServer) detectBootstrapMisconfig(name string) { + srvID := []byte(s.serverID) + subj := fmt.Sprintf("%s.%s.bootstrap", defaultRaftPrefix, name) + s.ncr.Subscribe(subj, func(m *nats.Msg) { + if m.Data != nil && m.Reply != "" { + // Ignore message to ourself + if string(m.Data) != s.serverID { + s.ncr.Publish(m.Reply, srvID) + s.log.Fatalf("Server %s was also started with -cluster_bootstrap", string(m.Data)) + } + } + }) + inbox := nats.NewInbox() + s.ncr.Subscribe(inbox, func(m *nats.Msg) { + s.log.Fatalf("Server %s was also started with -cluster_bootstrap", string(m.Data)) + }) + if err := s.ncr.Flush(); err != nil { + s.log.Errorf("Error setting up bootstrap misconfiguration detection: %v", err) + return + } + ticker := time.NewTicker(time.Second) + for { + select { + case <-s.shutdownCh: + ticker.Stop() + return + case <-ticker.C: + s.ncr.PublishRequest(subj, inbox, srvID) + } + } +} + +type raftLogger struct { + *StanServer +} + +func (rl *raftLogger) Write(b []byte) (int, error) { + if !rl.raftLogging { + return len(b), nil + } + levelStart := bytes.IndexByte(b, '[') + if levelStart != -1 { + switch b[levelStart+1] { + case 'D': // [DEBUG] + rl.log.Tracef("%s", b[levelStart+8:]) + case 'I': // [INFO] + rl.log.Noticef("%s", b[levelStart+7:]) + case 'W': // [WARN] + rl.log.Noticef("%s", b[levelStart+7:]) + case 'E': // [ERR] + rl.log.Errorf("%s", b[levelStart+6:]) + default: + rl.log.Noticef("%s", b) + } + } + return len(b), nil +} + +func (rl *raftLogger) Close() error { return nil } + +// createRaftNode creates and starts a new Raft node with the given name and FSM. 
+func (s *StanServer) createRaftNode(name string) (bool, error) { + path := filepath.Join(s.opts.Clustering.RaftLogPath, name) + if _, err := os.Stat(path); os.IsNotExist(err) { + if err := os.MkdirAll(path, os.ModeDir+os.ModePerm); err != nil { + return false, err + } + } + + // We create s.raft early because once NewRaft() is called, the + // raft code may asynchronously invoke FSM.Apply() and FSM.Restore() + // So we want the object to exist so we can check on leader atomic, etc.. + s.raft = &raftNode{} + + raftLogFileName := filepath.Join(path, raftLogFile) + store, err := newRaftLog(s.log, raftLogFileName, s.opts.Clustering.Sync, int(s.opts.Clustering.TrailingLogs)) + if err != nil { + return false, err + } + cacheStore, err := raft.NewLogCache(s.opts.Clustering.LogCacheSize, store) + if err != nil { + store.Close() + return false, err + } + + addr := s.getClusteringAddr(name) + config := raft.DefaultConfig() + // For tests + if runningInTests { + config.ElectionTimeout = 100 * time.Millisecond + config.HeartbeatTimeout = 100 * time.Millisecond + config.LeaderLeaseTimeout = 50 * time.Millisecond + } + config.LocalID = raft.ServerID(s.opts.Clustering.NodeID) + config.TrailingLogs = uint64(s.opts.Clustering.TrailingLogs) + + logWriter := &raftLogger{s} + config.LogOutput = logWriter + + snapshotStore, err := raft.NewFileSnapshotStore(path, s.opts.Clustering.LogSnapshots, logWriter) + if err != nil { + store.Close() + return false, err + } + sl, err := snapshotStore.List() + if err != nil { + store.Close() + return false, err + } + + // TODO: using a single NATS conn for every channel might be a bottleneck. Maybe pool conns? + transport, err := newNATSTransport(addr, s.ncr, 2*time.Second, logWriter) + if err != nil { + store.Close() + return false, err + } + // Make the snapshot process never timeout... check (s *serverSnapshot).Persist() for details + transport.TimeoutScale = 1 + + // Set up a channel for reliable leader notifications. 
+ raftNotifyCh := make(chan bool, 1) + config.NotifyCh = raftNotifyCh + + fsm := &raftFSM{server: s} + fsm.Lock() + fsm.snapshotsOnInit = len(sl) + fsm.Unlock() + s.raft.fsm = fsm + node, err := raft.NewRaft(config, fsm, cacheStore, store, snapshotStore, transport) + if err != nil { + transport.Close() + store.Close() + return false, err + } + if testPauseAfterNewRaftCalled { + time.Sleep(time.Second) + } + existingState, err := raft.HasExistingState(cacheStore, store, snapshotStore) + if err != nil { + node.Shutdown() + transport.Close() + store.Close() + return false, err + } + + if existingState { + s.log.Debugf("Loaded existing state for Raft group %s", name) + } + + // Handle requests to join the cluster. + sub, err := s.ncr.Subscribe(fmt.Sprintf("%s.%s.join", defaultRaftPrefix, name), func(msg *nats.Msg) { + // Drop the request if we're not the leader. There's no race condition + // after this check because even if we proceed with the cluster add, it + // will fail if the node is not the leader as cluster changes go + // through the Raft log. + if node.State() != raft.Leader { + return + } + req := &spb.RaftJoinRequest{} + if err := req.Unmarshal(msg.Data); err != nil { + s.log.Errorf("Invalid join request for Raft group %s", name) + return + } + + // Add the node as a voter. This is idempotent. No-op if the request + // came from ourselves. + resp := &spb.RaftJoinResponse{} + if req.NodeID != s.opts.Clustering.NodeID { + future := node.AddVoter( + raft.ServerID(req.NodeID), + raft.ServerAddress(req.NodeAddr), 0, 0) + if err := future.Error(); err != nil { + resp.Error = err.Error() + } + } + + // Send the response. 
+ r, err := resp.Marshal() + if err != nil { + panic(err) + } + s.ncr.Publish(msg.Reply, r) + }) + if err != nil { + node.Shutdown() + transport.Close() + store.Close() + return false, err + } + s.raft.Raft = node + s.raft.store = store + s.raft.transport = transport + s.raft.logInput = logWriter + s.raft.notifyCh = raftNotifyCh + s.raft.joinSub = sub + return existingState, nil +} + +// bootstrapCluster bootstraps the node for the provided Raft group either as a +// seed node or with the given peer configuration, depending on configuration +// and with the latter taking precedence. +func (s *StanServer) bootstrapCluster(name string, node *raft.Raft) error { + var ( + addr = s.getClusteringAddr(name) + // Include ourself in the cluster. + servers = []raft.Server{raft.Server{ + ID: raft.ServerID(s.opts.Clustering.NodeID), + Address: raft.ServerAddress(addr), + }} + ) + if len(s.opts.Clustering.Peers) > 0 { + // Bootstrap using provided cluster configuration. + s.log.Debugf("Bootstrapping Raft group %s using provided configuration", name) + for _, peer := range s.opts.Clustering.Peers { + servers = append(servers, raft.Server{ + ID: raft.ServerID(peer), + Address: raft.ServerAddress(s.getClusteringPeerAddr(name, peer)), + }) + } + } else { + // Bootstrap as a seed node. + s.log.Debugf("Bootstrapping Raft group %s as seed node", name) + } + config := raft.Configuration{Servers: servers} + return node.BootstrapCluster(config).Error() +} + +func (s *StanServer) getClusteringAddr(raftName string) string { + return s.getClusteringPeerAddr(raftName, s.opts.Clustering.NodeID) +} + +func (s *StanServer) getClusteringPeerAddr(raftName, nodeID string) string { + return fmt.Sprintf("%s.%s.%s", s.opts.ID, nodeID, raftName) +} + +// Apply log is invoked once a log entry is committed. +// It returns a value which will be made available in the +// ApplyFuture returned by Raft.Apply method if that +// method was called on the same Raft node as the FSM. 
+func (r *raftFSM) Apply(l *raft.Log) interface{} { + s := r.server + op := &spb.RaftOperation{} + if err := op.Unmarshal(l.Data); err != nil { + panic(err) + } + switch op.OpType { + case spb.RaftOperation_Publish: + // Message replication. + var ( + c *channel + err error + ) + for _, msg := range op.PublishBatch.Messages { + // This is a batch for a given channel, so lookup channel once. + if c == nil { + c, err = s.lookupOrCreateChannel(msg.Subject) + } + if err == nil { + _, err = c.store.Msgs.Store(msg) + } + if err != nil { + panic(fmt.Errorf("failed to store replicated message %d on channel %s: %v", + msg.Sequence, msg.Subject, err)) + } + } + return nil + case spb.RaftOperation_Connect: + // Client connection create replication. + return s.processConnect(op.ClientConnect.Request, op.ClientConnect.Refresh) + case spb.RaftOperation_Disconnect: + // Client connection close replication. + return s.closeClient(op.ClientDisconnect.ClientID) + case spb.RaftOperation_Subscribe: + // Subscription replication. + sub, err := s.processSub(nil, op.Sub.Request, op.Sub.AckInbox) + return &replicatedSub{sub: sub, err: err} + case spb.RaftOperation_RemoveSubscription: + fallthrough + case spb.RaftOperation_CloseSubscription: + // Close/Unsub subscription replication. 
+ isSubClose := op.OpType == spb.RaftOperation_CloseSubscription + s.closeMu.Lock() + err := s.unsubscribe(op.Unsub, isSubClose) + s.closeMu.Unlock() + return err + case spb.RaftOperation_SendAndAck: + if !s.isLeader() { + s.processReplicatedSendAndAck(op.SubSentAck) + } + return nil + case spb.RaftOperation_DeleteChannel: + s.processDeleteChannel(op.Channel) + return nil + default: + panic(fmt.Sprintf("unknown op type %s", op.OpType)) + } +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/conf.go b/vendor/github.com/nats-io/nats-streaming-server/server/conf.go new file mode 100644 index 00000000000..55c94007153 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/conf.go @@ -0,0 +1,666 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "flag" + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/nats-io/gnatsd/conf" + natsd "github.com/nats-io/gnatsd/server" + "github.com/nats-io/nats-streaming-server/stores" + "github.com/nats-io/nats-streaming-server/util" +) + +// ProcessConfigFile parses the configuration file `configFile` and updates +// the given Streaming options `opts`. +func ProcessConfigFile(configFile string, opts *Options) error { + m, err := conf.ParseFile(configFile) + if err != nil { + return err + } + // Look for a "streaming" key. If so, use only the content of this + // map, otherwise, use all keys. 
+ for k, v := range m { + name := strings.ToLower(k) + if name == "streaming" { + content, ok := v.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected streaming section to be a map/struct, got %v", v) + } + // Override `m` with the content of the streaming map. + m = content + } + } + for k, v := range m { + name := strings.ToLower(k) + switch name { + case "id", "cid", "cluster_id": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.ID = v.(string) + case "discover_prefix": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.DiscoverPrefix = v.(string) + case "st", "store_type", "store", "storetype": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + st := strings.ToUpper(v.(string)) + switch st { + case stores.TypeFile, stores.TypeMemory, stores.TypeSQL: + opts.StoreType = st + default: + return fmt.Errorf("unknown store type: %v", v.(string)) + } + case "dir", "datastore": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.FilestoreDir = v.(string) + case "sd", "stan_debug": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.Debug = v.(bool) + case "sv", "stan_trace": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.Trace = v.(bool) + case "ns", "nats_server", "nats_server_url": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.NATSServerURL = v.(string) + case "secure": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.Secure = v.(bool) + case "tls": + if err := parseTLS(v, opts); err != nil { + return err + } + case "limits", "store_limits", "storelimits": + if err := parseStoreLimits(v, opts); err != nil { + return err + } + case "file", "file_options": + if err := parseFileOptions(v, opts); err != nil { + return err + } + case "sql", "sql_options": + if err := parseSQLOptions(v, opts); err != nil { + return 
err + } + case "hbi", "hb_interval", "server_to_client_hb_interval": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + dur, err := time.ParseDuration(v.(string)) + if err != nil { + return err + } + opts.ClientHBInterval = dur + case "hbt", "hb_timeout", "server_to_client_hb_timeout": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + dur, err := time.ParseDuration(v.(string)) + if err != nil { + return err + } + opts.ClientHBTimeout = dur + case "hbf", "hb_fail_count", "server_to_client_hb_fail_count": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.ClientHBFailCount = int(v.(int64)) + case "ft_group", "ft_group_name": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.FTGroupName = v.(string) + case "partitioning": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.Partitioning = v.(bool) + case "cluster": + if err := parseCluster(v, opts); err != nil { + return err + } + case "syslog_name": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.SyslogName = v.(string) + } + } + return nil +} + +// checkType returns a formatted error if `v` is not of the expected kind. 
+func checkType(name string, kind reflect.Kind, v interface{}) error { + actualKind := reflect.TypeOf(v).Kind() + if actualKind != kind { + return fmt.Errorf("parameter %q value is expected to be %v, got %v", + name, kind.String(), actualKind.String()) + } + return nil +} + +// parseTLS updates `opts` with TLS config +func parseTLS(itf interface{}, opts *Options) error { + m, ok := itf.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected TLS to be a map/struct, got %v", itf) + } + for k, v := range m { + name := strings.ToLower(k) + switch name { + case "client_cert": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.ClientCert = v.(string) + case "client_key": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.ClientKey = v.(string) + case "client_ca", "client_cacert": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.ClientCA = v.(string) + } + } + return nil +} + +// parseCluster updates `opts` with cluster config +func parseCluster(itf interface{}, opts *Options) error { + m, ok := itf.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected cluster to be a map/struct, got %v", itf) + } + opts.Clustering.Clustered = true + for k, v := range m { + name := strings.ToLower(k) + switch name { + case "node_id": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.Clustering.NodeID = v.(string) + case "bootstrap": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.Clustering.Bootstrap = v.(bool) + case "peers": + if err := checkType(k, reflect.Slice, v); err != nil { + return err + } + peers := make([]string, len(v.([]interface{}))) + for i, p := range v.([]interface{}) { + peers[i] = p.(string) + } + opts.Clustering.Peers = peers + case "log_path": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.Clustering.RaftLogPath = v.(string) + case "log_cache_size": + if 
err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.Clustering.LogCacheSize = int(v.(int64)) + case "log_snapshots": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.Clustering.LogSnapshots = int(v.(int64)) + case "trailing_logs": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.Clustering.TrailingLogs = v.(int64) + case "sync": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.Clustering.Sync = v.(bool) + case "raft_logging": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.Clustering.RaftLogging = v.(bool) + } + } + return nil +} + +// parseStoreLimits updates `opts` with store limits +func parseStoreLimits(itf interface{}, opts *Options) error { + m, ok := itf.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected store limits to be a map/struct, got %v", itf) + } + for k, v := range m { + name := strings.ToLower(k) + switch name { + case "mc", "max_channels", "maxchannels": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.MaxChannels = int(v.(int64)) + case "channels", "channels_limits", "channelslimits", "per_channel", "per_channel_limits": + if err := parsePerChannelLimits(v, opts); err != nil { + return err + } + default: + // Check for the global limits (MaxMsgs, MaxBytes, etc..) + if err := parseChannelLimits(&opts.ChannelLimits, k, name, v, true); err != nil { + return err + } + } + } + return nil +} + +// parseChannelLimits updates `cl` with channel limits. 
+func parseChannelLimits(cl *stores.ChannelLimits, k, name string, v interface{}, isGlobal bool) error { + switch name { + case "msu", "max_subs", "max_subscriptions", "maxsubscriptions": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + cl.MaxSubscriptions = int(v.(int64)) + if !isGlobal && cl.MaxSubscriptions == 0 { + cl.MaxSubscriptions = -1 + } + case "mm", "max_msgs", "maxmsgs", "max_count", "maxcount": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + cl.MaxMsgs = int(v.(int64)) + if !isGlobal && cl.MaxMsgs == 0 { + cl.MaxMsgs = -1 + } + case "mb", "max_bytes", "maxbytes": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + cl.MaxBytes = v.(int64) + if !isGlobal && cl.MaxBytes == 0 { + cl.MaxBytes = -1 + } + case "ma", "max_age", "maxage": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + dur, err := time.ParseDuration(v.(string)) + if err != nil { + return err + } + cl.MaxAge = dur + if !isGlobal && cl.MaxAge == 0 { + cl.MaxAge = -1 + } + case "mi", "max_inactivity", "maxinactivity": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + dur, err := time.ParseDuration(v.(string)) + if err != nil { + return err + } + cl.MaxInactivity = dur + if !isGlobal && cl.MaxInactivity == 0 { + cl.MaxInactivity = -1 + } + } + return nil +} + +// parsePerChannelLimits updates `opts` with per channel limits. 
+func parsePerChannelLimits(itf interface{}, opts *Options) error { + m, ok := itf.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected per channel limits to be a map/struct, got %v", itf) + } + for channelName, limits := range m { + limitsMap, ok := limits.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected channel limits to be a map/struct, got %v", limits) + } + if !util.IsChannelNameValid(channelName, true) { + return fmt.Errorf("invalid channel name %q", channelName) + } + cl := &stores.ChannelLimits{} + for k, v := range limitsMap { + name := strings.ToLower(k) + if err := parseChannelLimits(cl, k, name, v, false); err != nil { + return err + } + } + sl := &opts.StoreLimits + sl.AddPerChannel(channelName, cl) + } + return nil +} + +func parseFileOptions(itf interface{}, opts *Options) error { + m, ok := itf.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected file options to be a map/struct, got %v", itf) + } + for k, v := range m { + name := strings.ToLower(k) + switch name { + case "compact", "compact_enabled": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.FileStoreOpts.CompactEnabled = v.(bool) + case "compact_frag", "compact_fragmentation": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.CompactFragmentation = int(v.(int64)) + case "compact_interval": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.CompactInterval = int(v.(int64)) + case "compact_min_size": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.CompactMinFileSize = v.(int64) + case "buffer_size": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.BufferSize = int(v.(int64)) + case "crc", "do_crc": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.FileStoreOpts.DoCRC = v.(bool) + case "crc_poly": + if err := 
checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.CRCPolynomial = v.(int64) + case "sync", "do_sync", "sync_on_flush": + if err := checkType(k, reflect.Bool, v); err != nil { + return err + } + opts.FileStoreOpts.DoSync = v.(bool) + case "slice_max_msgs", "slice_max_count", "slice_msgs", "slice_count": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.SliceMaxMsgs = int(v.(int64)) + case "slice_max_bytes", "slice_max_size", "slice_bytes", "slice_size": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.SliceMaxBytes = v.(int64) + case "slice_max_age", "slice_age", "slice_max_time", "slice_time_limit": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + dur, err := time.ParseDuration(v.(string)) + if err != nil { + return err + } + opts.FileStoreOpts.SliceMaxAge = dur + case "slice_archive_script", "slice_archive", "slice_script": + if err := checkType(k, reflect.String, v); err != nil { + return err + } + opts.FileStoreOpts.SliceArchiveScript = v.(string) + case "file_descriptors_limit", "fds_limit": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.FileDescriptorsLimit = v.(int64) + case "parallel_recovery": + if err := checkType(k, reflect.Int64, v); err != nil { + return err + } + opts.FileStoreOpts.ParallelRecovery = int(v.(int64)) + } + } + return nil +} + +func parseSQLOptions(itf interface{}, opts *Options) error { + m, ok := itf.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected SQL options to be a map/struct, got %v", itf) + } + for k, v := range m { + name := strings.ToLower(k) + switch name { + case "driver": + if err := checkType(name, reflect.String, v); err != nil { + return err + } + opts.SQLStoreOpts.Driver = v.(string) + case "source": + if err := checkType(name, reflect.String, v); err != nil { + return err + } + opts.SQLStoreOpts.Source = 
v.(string) + case "no_caching": + if err := checkType(name, reflect.Bool, v); err != nil { + return err + } + opts.SQLStoreOpts.NoCaching = v.(bool) + case "max_open_conns", "max_conns": + if err := checkType(name, reflect.Int64, v); err != nil { + return err + } + opts.SQLStoreOpts.MaxOpenConns = int(v.(int64)) + } + } + return nil +} + +// ConfigureOptions accepts a flag set and augment it with NATS Streaming Server +// specific flags. It then invokes the corresponding function from NATS Server. +// On success, Streaming and NATS options structures are returned configured +// based on the selected flags and/or configuration files. +// The command line options take precedence to the ones in the configuration files. +func ConfigureOptions(fs *flag.FlagSet, args []string, printVersion, printHelp, printTLSHelp func()) (*Options, *natsd.Options, error) { + sopts := GetDefaultOptions() + + var ( + stanConfigFile string + natsConfigFile string + clusterPeers string + ) + + fs.StringVar(&sopts.ID, "cluster_id", DefaultClusterID, "stan.ID") + fs.StringVar(&sopts.ID, "cid", DefaultClusterID, "stan.ID") + fs.StringVar(&sopts.StoreType, "store", stores.TypeMemory, "stan.StoreType") + fs.StringVar(&sopts.StoreType, "st", stores.TypeMemory, "stan.StoreType") + fs.StringVar(&sopts.FilestoreDir, "dir", "", "stan.FilestoreDir") + fs.IntVar(&sopts.MaxChannels, "max_channels", stores.DefaultStoreLimits.MaxChannels, "stan.MaxChannels") + fs.IntVar(&sopts.MaxChannels, "mc", stores.DefaultStoreLimits.MaxChannels, "stan.MaxChannels") + fs.IntVar(&sopts.MaxSubscriptions, "max_subs", stores.DefaultStoreLimits.MaxSubscriptions, "stan.MaxSubscriptions") + fs.IntVar(&sopts.MaxSubscriptions, "msu", stores.DefaultStoreLimits.MaxSubscriptions, "stan.MaxSubscriptions") + fs.IntVar(&sopts.MaxMsgs, "max_msgs", stores.DefaultStoreLimits.MaxMsgs, "stan.MaxMsgs") + fs.IntVar(&sopts.MaxMsgs, "mm", stores.DefaultStoreLimits.MaxMsgs, "stan.MaxMsgs") + fs.String("max_bytes", fmt.Sprintf("%v", 
stores.DefaultStoreLimits.MaxBytes), "stan.MaxBytes") + fs.String("mb", fmt.Sprintf("%v", stores.DefaultStoreLimits.MaxBytes), "stan.MaxBytes") + fs.DurationVar(&sopts.MaxAge, "max_age", stores.DefaultStoreLimits.MaxAge, "stan.MaxAge") + fs.DurationVar(&sopts.MaxAge, "ma", stores.DefaultStoreLimits.MaxAge, "stan.MaxAge") + fs.DurationVar(&sopts.MaxInactivity, "max_inactivity", stores.DefaultStoreLimits.MaxInactivity, "Maximum inactivity (no new message, no subscription) after which a channel can be garbage collected") + fs.DurationVar(&sopts.MaxInactivity, "mi", stores.DefaultStoreLimits.MaxInactivity, "Maximum inactivity (no new message, no subscription) after which a channel can be garbage collected") + fs.DurationVar(&sopts.ClientHBInterval, "hbi", DefaultHeartBeatInterval, "stan.ClientHBInterval") + fs.DurationVar(&sopts.ClientHBInterval, "hb_interval", DefaultHeartBeatInterval, "stan.ClientHBInterval") + fs.DurationVar(&sopts.ClientHBTimeout, "hbt", DefaultClientHBTimeout, "stan.ClientHBTimeout") + fs.DurationVar(&sopts.ClientHBTimeout, "hb_timeout", DefaultClientHBTimeout, "stan.ClientHBTimeout") + fs.IntVar(&sopts.ClientHBFailCount, "hbf", DefaultMaxFailedHeartBeats, "stan.ClientHBFailCount") + fs.IntVar(&sopts.ClientHBFailCount, "hb_fail_count", DefaultMaxFailedHeartBeats, "stan.ClientHBFailCount") + fs.BoolVar(&sopts.Debug, "SD", false, "stan.Debug") + fs.BoolVar(&sopts.Debug, "stan_debug", false, "stan.Debug") + fs.BoolVar(&sopts.Trace, "SV", false, "stan.Trace") + fs.BoolVar(&sopts.Trace, "stan_trace", false, "stan.Trace") + fs.Bool("SDV", false, "") + fs.BoolVar(&sopts.Secure, "secure", false, "stan.Secure") + fs.StringVar(&sopts.ClientCert, "tls_client_cert", "", "stan.ClientCert") + fs.StringVar(&sopts.ClientKey, "tls_client_key", "", "stan.ClientKey") + fs.StringVar(&sopts.ClientCA, "tls_client_cacert", "", "stan.ClientCA") + fs.StringVar(&sopts.NATSServerURL, "nats_server", "", "stan.NATSServerURL") + fs.StringVar(&sopts.NATSServerURL, "ns", "", 
"stan.NATSServerURL") + fs.StringVar(&stanConfigFile, "sc", "", "") + fs.StringVar(&stanConfigFile, "stan_config", "", "") + fs.BoolVar(&sopts.FileStoreOpts.CompactEnabled, "file_compact_enabled", stores.DefaultFileStoreOptions.CompactEnabled, "stan.FileStoreOpts.CompactEnabled") + fs.IntVar(&sopts.FileStoreOpts.CompactFragmentation, "file_compact_frag", stores.DefaultFileStoreOptions.CompactFragmentation, "stan.FileStoreOpts.CompactFragmentation") + fs.IntVar(&sopts.FileStoreOpts.CompactInterval, "file_compact_interval", stores.DefaultFileStoreOptions.CompactInterval, "stan.FileStoreOpts.CompactInterval") + fs.String("file_compact_min_size", fmt.Sprintf("%v", stores.DefaultFileStoreOptions.CompactMinFileSize), "stan.FileStoreOpts.CompactMinFileSize") + fs.String("file_buffer_size", fmt.Sprintf("%v", stores.DefaultFileStoreOptions.BufferSize), "stan.FileStoreOpts.BufferSize") + fs.BoolVar(&sopts.FileStoreOpts.DoCRC, "file_crc", stores.DefaultFileStoreOptions.DoCRC, "stan.FileStoreOpts.DoCRC") + fs.Int64Var(&sopts.FileStoreOpts.CRCPolynomial, "file_crc_poly", stores.DefaultFileStoreOptions.CRCPolynomial, "stan.FileStoreOpts.CRCPolynomial") + fs.BoolVar(&sopts.FileStoreOpts.DoSync, "file_sync", stores.DefaultFileStoreOptions.DoSync, "stan.FileStoreOpts.DoSync") + fs.IntVar(&sopts.FileStoreOpts.SliceMaxMsgs, "file_slice_max_msgs", stores.DefaultFileStoreOptions.SliceMaxMsgs, "stan.FileStoreOpts.SliceMaxMsgs") + fs.String("file_slice_max_bytes", fmt.Sprintf("%v", stores.DefaultFileStoreOptions.SliceMaxBytes), "stan.FileStoreOpts.SliceMaxBytes") + fs.DurationVar(&sopts.FileStoreOpts.SliceMaxAge, "file_slice_max_age", stores.DefaultFileStoreOptions.SliceMaxAge, "stan.FileStoreOpts.SliceMaxAge") + fs.StringVar(&sopts.FileStoreOpts.SliceArchiveScript, "file_slice_archive_script", "", "stan.FileStoreOpts.SliceArchiveScript") + fs.Int64Var(&sopts.FileStoreOpts.FileDescriptorsLimit, "file_fds_limit", stores.DefaultFileStoreOptions.FileDescriptorsLimit, 
"stan.FileStoreOpts.FileDescriptorsLimit") + fs.IntVar(&sopts.FileStoreOpts.ParallelRecovery, "file_parallel_recovery", stores.DefaultFileStoreOptions.ParallelRecovery, "stan.FileStoreOpts.ParallelRecovery") + fs.BoolVar(&sopts.FileStoreOpts.TruncateUnexpectedEOF, "file_truncate_bad_eof", stores.DefaultFileStoreOptions.TruncateUnexpectedEOF, "Truncate files for which there is an unexpected EOF on recovery, dataloss may occur") + fs.IntVar(&sopts.IOBatchSize, "io_batch_size", DefaultIOBatchSize, "stan.IOBatchSize") + fs.Int64Var(&sopts.IOSleepTime, "io_sleep_time", DefaultIOSleepTime, "stan.IOSleepTime") + fs.StringVar(&sopts.FTGroupName, "ft_group", "", "stan.FTGroupName") + fs.BoolVar(&sopts.Clustering.Clustered, "clustered", false, "stan.Clustering.Clustered") + fs.StringVar(&sopts.Clustering.NodeID, "cluster_node_id", "", "stan.Clustering.NodeID") + fs.BoolVar(&sopts.Clustering.Bootstrap, "cluster_bootstrap", false, "stan.Clustering.Bootstrap") + fs.StringVar(&clusterPeers, "cluster_peers", "", "stan.Clustering.Peers") + fs.StringVar(&sopts.Clustering.RaftLogPath, "cluster_log_path", "", "stan.Clustering.RaftLogPath") + fs.IntVar(&sopts.Clustering.LogCacheSize, "cluster_log_cache_size", DefaultLogCacheSize, "stan.Clustering.LogCacheSize") + fs.IntVar(&sopts.Clustering.LogSnapshots, "cluster_log_snapshots", DefaultLogSnapshots, "stan.Clustering.LogSnapshots") + fs.Int64Var(&sopts.Clustering.TrailingLogs, "cluster_trailing_logs", DefaultTrailingLogs, "stan.Clustering.TrailingLogs") + fs.BoolVar(&sopts.Clustering.Sync, "cluster_sync", false, "stan.Clustering.Sync") + fs.BoolVar(&sopts.Clustering.RaftLogging, "cluster_raft_logging", false, "") + fs.StringVar(&sopts.SQLStoreOpts.Driver, "sql_driver", "", "SQL Driver") + fs.StringVar(&sopts.SQLStoreOpts.Source, "sql_source", "", "SQL Data Source") + defSQLOpts := stores.DefaultSQLStoreOptions() + fs.BoolVar(&sopts.SQLStoreOpts.NoCaching, "sql_no_caching", defSQLOpts.NoCaching, "Enable/Disable caching") + 
fs.IntVar(&sopts.SQLStoreOpts.MaxOpenConns, "sql_max_open_conns", defSQLOpts.MaxOpenConns, "Max opened connections to the database") + fs.StringVar(&sopts.SyslogName, "syslog_name", "", "Syslog Name") + + // First, we need to call NATS's ConfigureOptions() with above flag set. + // It will be augmented with NATS specific flags and call fs.Parse(args) for us. + nopts, err := natsd.ConfigureOptions(fs, args, printVersion, printHelp, printTLSHelp) + if err != nil { + return nil, nil, err + } + // At this point, if NATS config file was specified in the command line (-c of -config) + // nopts.ConfigFile will not be empty. + natsConfigFile = nopts.ConfigFile + + if clusterPeers != "" { + sopts.Clustering.Peers = []string{} + for _, p := range strings.Split(clusterPeers, ",") { + if p = strings.TrimSpace(p); p != sopts.Clustering.NodeID { + sopts.Clustering.Peers = append(sopts.Clustering.Peers, p) + } + } + } + + // If both nats and streaming configuration files are used, then + // we only use the config file for the corresponding module. + // However, if only one command line parameter was specified, + // we use the same config file for both modules. + if stanConfigFile != "" || natsConfigFile != "" { + // If NATS config file was not specified, but streaming was, use + // streaming config file for NATS too. + if natsConfigFile == "" { + if err := nopts.ProcessConfigFile(stanConfigFile); err != nil { + return nil, nil, err + } + } + // If NATS config file was specified, but not the streaming one, + // use nats config file for streaming too. + if stanConfigFile == "" { + stanConfigFile = natsConfigFile + } + if err := ProcessConfigFile(stanConfigFile, sopts); err != nil { + return nil, nil, err + } + // Need to call Parse() again to override with command line params. 
+ // No need to check for errors since this has already been called + // in natsd.ConfigureOptions() + fs.Parse(args) + } + + // Special handling for some command line params + var flagErr error + fs.Visit(func(f *flag.Flag) { + if flagErr != nil { + return + } + switch f.Name { + case "SDV": + // Check value to support -SDV=false + boolValue, _ := strconv.ParseBool(f.Value.String()) + sopts.Trace, sopts.Debug = boolValue, boolValue + case "max_bytes", "mb": + sopts.MaxBytes, flagErr = getBytes(f) + case "file_compact_min_size": + sopts.FileStoreOpts.CompactMinFileSize, flagErr = getBytes(f) + case "file_buffer_size": + var i64 int64 + i64, flagErr = getBytes(f) + sopts.FileStoreOpts.BufferSize = int(i64) + } + }) + if flagErr != nil { + return nil, nil, flagErr + } + return sopts, nopts, nil +} + +// getBytes returns the number of bytes from the flag's String size. +// For instance, 1KB would return 1024. +func getBytes(f *flag.Flag) (int64, error) { + var res map[string]interface{} + // Use NATS parser to do the conversion for us. + res, err := conf.Parse(fmt.Sprintf("bytes: %v", f.Value.String())) + if err != nil { + return 0, err + } + resVal := res["bytes"] + if resVal == nil || reflect.TypeOf(resVal).Kind() != reflect.Int64 { + return 0, fmt.Errorf("%v should be a size, got '%v'", f.Name, resVal) + } + return resVal.(int64), nil +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/ft.go b/vendor/github.com/nats-io/nats-streaming-server/server/ft.go new file mode 100644 index 00000000000..baad068d8f1 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/ft.go @@ -0,0 +1,236 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "math/rand" + "time" + + "github.com/nats-io/go-nats" + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/stores" + "github.com/nats-io/nats-streaming-server/util" +) + +// FT constants +const ( + ftDefaultHBInterval = time.Second + ftDefaultHBMissedInterval = 1250 * time.Millisecond +) + +var ( + // Some go-routine will panic, which we can't recover in test. + // So the tests will set this to true to be able to test the + // correct behavior. + ftNoPanic bool + // For tests purposes, we may want to pause for the first + // attempt at getting the store lock so that test can + // switch store with a mocked one. + ftPauseBeforeFirstAttempt bool + ftPauseCh = make(chan struct{}) + // This can be changed for tests purposes. + ftHBInterval = ftDefaultHBInterval + ftHBMissedInterval = ftDefaultHBMissedInterval +) + +func ftReleasePause() { + ftPauseCh <- struct{}{} +} + +// ftStart will return only when this server has become active +// and was able to get the store's exclusive lock. +// This is running in a separate go-routine so if server state +// changes, take care of using the server's lock. 
+func (s *StanServer) ftStart() (retErr error) { + s.log.Noticef("Starting in standby mode") + // For tests purposes + if ftPauseBeforeFirstAttempt { + <-ftPauseCh + } + print, _ := util.NewBackoffTimeCheck(time.Second, 2, time.Minute) + for { + select { + case <-s.ftQuit: + // we are done + return nil + case <-s.ftHBCh: + // go back to the beginning of the for loop + continue + case <-time.After(s.ftHBMissedInterval): + // try to lock the store + } + locked, err := s.ftGetStoreLock() + if err != nil { + // Log the error, but go back and wait for the next interval and + // try again. It is possible that the error resolves (for instance + // the connection to the database is restored - for SQL stores). + s.log.Errorf("ft: error attempting to get the store lock: %v", err) + continue + } else if locked { + break + } + // Here, we did not get the lock, print and go back to standby. + // Use some backoff for the printing to not fill up the log + if print.Ok() { + s.log.Noticef("ft: unable to get store lock at this time, going back to standby") + } + } + // Capture the time this server activated. It will be used in case several + // servers claim to be active. Not bulletproof since there could be clock + // differences, etc... but when more than one server has acquired the store + // lock it means we are already in trouble, so just trying to minimize the + // possible store corruption... + activationTime := time.Now() + s.log.Noticef("Server is active") + s.startGoRoutine(func() { + s.ftSendHBLoop(activationTime) + }) + // Start the recovery process, etc.. + return s.start(FTActive) +} + +// ftGetStoreLock returns true if the server was able to get the +// exclusive store lock, false othewise, or if there was a fatal error doing so. +func (s *StanServer) ftGetStoreLock() (bool, error) { + // Normally, the store would be set early and is immutable, but some + // FT tests do set a mock store after the server is created, so use + // locking here to avoid race reports. 
+ s.mu.Lock() + store := s.store + s.mu.Unlock() + if ok, err := store.GetExclusiveLock(); !ok || err != nil { + // We got an error not related to locking (could be not supported, + // permissions error, file not reachable, etc..) + if err != nil { + return false, fmt.Errorf("ft: fatal error getting the store lock: %v", err) + } + // If ok is false, it means that we did not get the lock. + return false, nil + } + return true, nil +} + +// ftSendHBLoop is used by an active server to send HB to the FT subject. +// Standby servers receiving those HBs do not attempt to lock the store. +// When they miss HBs, they will. +func (s *StanServer) ftSendHBLoop(activationTime time.Time) { + // Release the wait group on exit + defer s.wg.Done() + + timeAsBytes, _ := activationTime.MarshalBinary() + ftHB := &spb.CtrlMsg{ + MsgType: spb.CtrlMsg_FTHeartbeat, + ServerID: s.serverID, + Data: timeAsBytes, + } + ftHBBytes, _ := ftHB.Marshal() + print, _ := util.NewBackoffTimeCheck(time.Second, 2, time.Minute) + for { + if err := s.ftnc.Publish(s.ftSubject, ftHBBytes); err != nil { + if print.Ok() { + s.log.Errorf("Unable to send FT heartbeat: %v", err) + } + } + startSelect: + select { + case m := <-s.ftHBCh: + hb := spb.CtrlMsg{} + if err := hb.Unmarshal(m.Data); err != nil { + goto startSelect + } + // Ignore our own message + if hb.MsgType != spb.CtrlMsg_FTHeartbeat || hb.ServerID == s.serverID { + goto startSelect + } + // Another server claims to be active + peerActivationTime := time.Time{} + if err := peerActivationTime.UnmarshalBinary(hb.Data); err != nil { + s.log.Errorf("Error decoding activation time: %v", err) + } else { + // Step down if the peer's activation time is earlier than ours. 
+ err := fmt.Errorf("ft: serverID %q claims to be active", hb.ServerID) + if peerActivationTime.Before(activationTime) { + err = fmt.Errorf("%s, aborting", err) + if ftNoPanic { + s.setLastError(err) + return + } + panic(err) + } else { + s.log.Errorf(err.Error()) + } + } + case <-time.After(s.ftHBInterval): + // We'll send the ping at the top of the for loop + case <-s.ftQuit: + return + } + } +} + +// ftSetup checks that all required FT parameters have been specified and +// create the channel required for shutdown. +// Note that FTGroupName has to be set before server invokes this function, +// so this parameter is not checked here. +func (s *StanServer) ftSetup() error { + // Check that store type is ok. So far only support for FileStore + if s.opts.StoreType != stores.TypeFile && s.opts.StoreType != stores.TypeSQL { + return fmt.Errorf("ft: only %v or %v stores supported in FT mode", stores.TypeFile, stores.TypeSQL) + } + // So far, those are not exposed to users, just used in tests. + // Still make sure that the missed HB interval is > than the HB + // interval. + if ftHBMissedInterval < time.Duration(float64(ftHBInterval)*1.1) { + return fmt.Errorf("ft: the missed heartbeat interval needs to be"+ + " at least 10%% of the heartbeat interval (hb=%v missed hb=%v", + ftHBInterval, ftHBMissedInterval) + } + // Set the HB and MissedHB intervals, using a bit of randomness + rand.Seed(time.Now().UnixNano()) + s.ftHBInterval = ftGetRandomInterval(ftHBInterval) + s.ftHBMissedInterval = ftGetRandomInterval(ftHBMissedInterval) + // Subscribe to FT subject + s.ftSubject = fmt.Sprintf("%s.%s.%s", ftHBPrefix, s.opts.ID, s.opts.FTGroupName) + s.ftHBCh = make(chan *nats.Msg) + sub, err := s.ftnc.Subscribe(s.ftSubject, func(m *nats.Msg) { + // Dropping incoming FT HBs is not crucial, we will then check for + // store lock. 
+ select { + case s.ftHBCh <- m: + default: + } + }) + if err != nil { + return fmt.Errorf("ft: unable to subscribe on ft subject: %v", err) + } + // We don't want to cause possible slow consumer error + sub.SetPendingLimits(-1, -1) + // Create channel to notify FT go routine to quit. + s.ftQuit = make(chan struct{}, 1) + // Set the state as standby initially + s.state = FTStandby + return nil +} + +// ftGetRandomInterval returns a random interval with at most +/- 10% +// of the given interval. +func ftGetRandomInterval(interval time.Duration) time.Duration { + tenPercent := int(float64(interval) * 0.10) + random := time.Duration(rand.Intn(tenPercent)) + if rand.Intn(2) == 1 { + return interval + random + } + return interval - random +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/monitor.go b/vendor/github.com/nats-io/nats-streaming-server/server/monitor.go new file mode 100644 index 00000000000..0fcd45d958c --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/monitor.go @@ -0,0 +1,520 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "runtime" + "sort" + "strconv" + "time" + + gnatsd "github.com/nats-io/gnatsd/server" + "github.com/nats-io/nats-streaming-server/stores" +) + +// Routes for the monitoring pages +const ( + RootPath = "/streaming" + ServerPath = RootPath + "/serverz" + StorePath = RootPath + "/storez" + ClientsPath = RootPath + "/clientsz" + ChannelsPath = RootPath + "/channelsz" + + defaultMonitorListLimit = 1024 +) + +// Serverz describes the NATS Streaming Server +type Serverz struct { + ClusterID string `json:"cluster_id"` + ServerID string `json:"server_id"` + Version string `json:"version"` + GoVersion string `json:"go"` + State string `json:"state"` + Now time.Time `json:"now"` + Start time.Time `json:"start_time"` + Uptime string `json:"uptime"` + Clients int `json:"clients"` + Subscriptions int `json:"subscriptions"` + Channels int `json:"channels"` + TotalMsgs int `json:"total_msgs"` + TotalBytes uint64 `json:"total_bytes"` +} + +// Storez describes the NATS Streaming Store +type Storez struct { + ClusterID string `json:"cluster_id"` + ServerID string `json:"server_id"` + Now time.Time `json:"now"` + Type string `json:"type"` + Limits stores.StoreLimits `json:"limits"` + TotalMsgs int `json:"total_msgs"` + TotalBytes uint64 `json:"total_bytes"` +} + +// Clientsz lists the client connections +type Clientsz struct { + ClusterID string `json:"cluster_id"` + ServerID string `json:"server_id"` + Now time.Time `json:"now"` + Offset int `json:"offset"` + Limit int `json:"limit"` + Count int `json:"count"` + Total int `json:"total"` + Clients []*Clientz `json:"clients"` +} + +// Clientz describes a NATS Streaming Client connection +type Clientz struct { + ID string `json:"id"` + HBInbox string `json:"hb_inbox"` + Subscriptions map[string][]*Subscriptionz `json:"subscriptions,omitempty"` +} + +// Channelsz lists the name of all NATS Streaming Channelsz +type Channelsz struct { + ClusterID string 
`json:"cluster_id"` + ServerID string `json:"server_id"` + Now time.Time `json:"now"` + Offset int `json:"offset"` + Limit int `json:"limit"` + Count int `json:"count"` + Total int `json:"total"` + Names []string `json:"names,omitempty"` + Channels []*Channelz `json:"channels,omitempty"` +} + +// Channelz describes a NATS Streaming Channel +type Channelz struct { + Name string `json:"name"` + Msgs int `json:"msgs"` + Bytes uint64 `json:"bytes"` + FirstSeq uint64 `json:"first_seq"` + LastSeq uint64 `json:"last_seq"` + Subscriptions []*Subscriptionz `json:"subscriptions,omitempty"` +} + +// Subscriptionz describes a NATS Streaming Subscription +type Subscriptionz struct { + ClientID string `json:"client_id"` + Inbox string `json:"inbox"` + AckInbox string `json:"ack_inbox"` + DurableName string `json:"durable_name,omitempty"` + QueueName string `json:"queue_name,omitempty"` + IsDurable bool `json:"is_durable"` + IsOffline bool `json:"is_offline"` + MaxInflight int `json:"max_inflight"` + AckWait int `json:"ack_wait"` + LastSent uint64 `json:"last_sent"` + PendingCount int `json:"pending_count"` + IsStalled bool `json:"is_stalled"` +} + +func (s *StanServer) startMonitoring(nOpts *gnatsd.Options) error { + var hh http.Handler + // If we are connecting to remote NATS Server, we start our own + // HTTP(s) server. 
+ if s.opts.NATSServerURL != "" { + s.natsServer = gnatsd.New(nOpts) + if err := s.natsServer.StartMonitoring(); err != nil { + return err + } + hh = s.natsServer.HTTPHandler() + } else { + hh = s.natsServer.HTTPHandler() + } + if hh == nil { + return errors.New("unable to start monitoring server") + } + + mux := hh.(*http.ServeMux) + mux.HandleFunc(RootPath, s.handleRootz) + mux.HandleFunc(ServerPath, s.handleServerz) + mux.HandleFunc(StorePath, s.handleStorez) + mux.HandleFunc(ClientsPath, s.handleClientsz) + mux.HandleFunc(ChannelsPath, s.handleChannelsz) + + return nil +} + +func (s *StanServer) handleRootz(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, ` + + + + + + NATS Streaming +
+ server
+ store
+ clients
+ channels
+
+ help + +`, ServerPath, StorePath, ClientsPath, ChannelsPath) +} + +func (s *StanServer) handleServerz(w http.ResponseWriter, r *http.Request) { + numChannels := s.channels.count() + count, bytes, err := s.channels.msgsState("") + if err != nil { + http.Error(w, fmt.Sprintf("Error getting information about channels state: %v", err), http.StatusInternalServerError) + return + } + s.mu.RLock() + state := s.state + s.mu.RUnlock() + s.monMu.RLock() + numSubs := s.numSubs + s.monMu.RUnlock() + now := time.Now() + serverz := &Serverz{ + ClusterID: s.info.ClusterID, + ServerID: s.serverID, + Version: VERSION, + GoVersion: runtime.Version(), + State: state.String(), + Now: now, + Start: s.startTime, + Uptime: myUptime(now.Sub(s.startTime)), + Clients: s.clients.count(), + Channels: numChannels, + Subscriptions: numSubs, + TotalMsgs: count, + TotalBytes: bytes, + } + s.sendResponse(w, r, serverz) +} + +func myUptime(d time.Duration) string { + // Just use total seconds for uptime, and display days / years + tsecs := d / time.Second + tmins := tsecs / 60 + thrs := tmins / 60 + tdays := thrs / 24 + tyrs := tdays / 365 + + if tyrs > 0 { + return fmt.Sprintf("%dy%dd%dh%dm%ds", tyrs, tdays%365, thrs%24, tmins%60, tsecs%60) + } + if tdays > 0 { + return fmt.Sprintf("%dd%dh%dm%ds", tdays, thrs%24, tmins%60, tsecs%60) + } + if thrs > 0 { + return fmt.Sprintf("%dh%dm%ds", thrs, tmins%60, tsecs%60) + } + if tmins > 0 { + return fmt.Sprintf("%dm%ds", tmins, tsecs%60) + } + return fmt.Sprintf("%ds", tsecs) +} + +func (s *StanServer) handleStorez(w http.ResponseWriter, r *http.Request) { + count, bytes, err := s.channels.msgsState("") + if err != nil { + http.Error(w, fmt.Sprintf("Error getting information about channels state: %v", err), http.StatusInternalServerError) + return + } + storez := &Storez{ + ClusterID: s.info.ClusterID, + ServerID: s.serverID, + Now: time.Now(), + Type: s.store.Name(), + Limits: s.opts.StoreLimits, + TotalMsgs: count, + TotalBytes: bytes, + } + 
s.sendResponse(w, r, storez) +} + +type byClientID []*Clientz + +func (c byClientID) Len() int { return len(c) } +func (c byClientID) Swap(i, j int) { c[i], c[j] = c[j], c[i] } +func (c byClientID) Less(i, j int) bool { return c[i].ID < c[j].ID } + +func (s *StanServer) handleClientsz(w http.ResponseWriter, r *http.Request) { + singleClient := r.URL.Query().Get("client") + subsOption, _ := strconv.Atoi(r.URL.Query().Get("subs")) + if singleClient != "" { + clientz := getMonitorClient(s, singleClient, subsOption) + if clientz == nil { + http.Error(w, fmt.Sprintf("Client %s not found", singleClient), http.StatusNotFound) + return + } + s.sendResponse(w, r, clientz) + } else { + offset, limit := getOffsetAndLimit(r) + clients := s.clients.getClients() + totalClients := len(clients) + carr := make([]*Clientz, 0, totalClients) + for cID := range clients { + cz := &Clientz{ID: cID} + carr = append(carr, cz) + } + sort.Sort(byClientID(carr)) + + minoff, maxoff := getMinMaxOffset(offset, limit, totalClients) + carr = carr[minoff:maxoff] + + // Since clients may be unregistered between the time we get the client IDs + // and the time we build carr array, lets count the number of elements + // actually intserted. 
+ carrSize := 0 + for _, c := range carr { + client := s.clients.lookup(c.ID) + if client != nil { + client.RLock() + c.HBInbox = client.info.HbInbox + if subsOption == 1 { + c.Subscriptions = getMonitorClientSubs(client) + } + client.RUnlock() + carrSize++ + } + } + carr = carr[0:carrSize] + clientsz := &Clientsz{ + ClusterID: s.info.ClusterID, + ServerID: s.serverID, + Now: time.Now(), + Offset: offset, + Limit: limit, + Total: totalClients, + Count: len(carr), + Clients: carr, + } + s.sendResponse(w, r, clientsz) + } +} + +func getMonitorClient(s *StanServer, clientID string, subsOption int) *Clientz { + cli := s.clients.lookup(clientID) + if cli == nil { + return nil + } + cli.RLock() + defer cli.RUnlock() + cz := &Clientz{ + HBInbox: cli.info.HbInbox, + ID: cli.info.ID, + } + if subsOption == 1 { + cz.Subscriptions = getMonitorClientSubs(cli) + } + return cz +} + +func getMonitorClientSubs(client *client) map[string][]*Subscriptionz { + subs := client.subs + var subsz map[string][]*Subscriptionz + for _, sub := range subs { + if subsz == nil { + subsz = make(map[string][]*Subscriptionz) + } + array := subsz[sub.subject] + newArray := append(array, createSubscriptionz(sub)) + if &newArray != &array { + subsz[sub.subject] = newArray + } + } + return subsz +} + +func getMonitorChannelSubs(ss *subStore) []*Subscriptionz { + ss.RLock() + defer ss.RUnlock() + subsz := make([]*Subscriptionz, 0) + for _, sub := range ss.psubs { + subsz = append(subsz, createSubscriptionz(sub)) + } + // Get only offline durables (the online also appear in ss.psubs) + for _, sub := range ss.durables { + if sub.ClientID == "" { + subsz = append(subsz, createSubscriptionz(sub)) + } + } + for _, qsub := range ss.qsubs { + qsub.RLock() + for _, sub := range qsub.subs { + subsz = append(subsz, createSubscriptionz(sub)) + } + // If this is a durable queue subscription and all members + // are offline, qsub.shadow will be not nil. Report this one. 
+ if qsub.shadow != nil { + subsz = append(subsz, createSubscriptionz(qsub.shadow)) + } + qsub.RUnlock() + } + return subsz +} + +func createSubscriptionz(sub *subState) *Subscriptionz { + sub.RLock() + subz := &Subscriptionz{ + ClientID: sub.ClientID, + Inbox: sub.Inbox, + AckInbox: sub.AckInbox, + DurableName: sub.DurableName, + QueueName: sub.QGroup, + IsDurable: sub.IsDurable, + IsOffline: (sub.ClientID == ""), + MaxInflight: int(sub.MaxInFlight), + AckWait: int(sub.AckWaitInSecs), + LastSent: sub.LastSent, + PendingCount: len(sub.acksPending), + IsStalled: sub.stalled, + } + // Case of offline durable (queue) subscriptions + if sub.ClientID == "" { + subz.ClientID = sub.savedClientID + } + sub.RUnlock() + return subz +} + +// When we support only Go 1.8+, replace sort with sort.Slice +type byName []string + +func (a byName) Len() int { return (len(a)) } +func (a byName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byName) Less(i, j int) bool { return a[i] < a[j] } + +type byChannelName []*Channelz + +func (a byChannelName) Len() int { return (len(a)) } +func (a byChannelName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byChannelName) Less(i, j int) bool { return a[i].Name < a[j].Name } + +func (s *StanServer) handleChannelsz(w http.ResponseWriter, r *http.Request) { + channelName := r.URL.Query().Get("channel") + subsOption, _ := strconv.Atoi(r.URL.Query().Get("subs")) + if channelName != "" { + s.handleOneChannel(w, r, channelName, subsOption) + } else { + offset, limit := getOffsetAndLimit(r) + channels := s.channels.getAll() + totalChannels := len(channels) + minoff, maxoff := getMinMaxOffset(offset, limit, totalChannels) + channelsz := &Channelsz{ + ClusterID: s.info.ClusterID, + ServerID: s.serverID, + Now: time.Now(), + Offset: offset, + Limit: limit, + Total: totalChannels, + } + if subsOption == 1 { + carr := make([]*Channelz, 0, totalChannels) + for cn := range channels { + cz := &Channelz{Name: cn} + carr = append(carr, cz) + } + 
sort.Sort(byChannelName(carr)) + carr = carr[minoff:maxoff] + for _, cz := range carr { + cs := channels[cz.Name] + if err := updateChannelz(cz, cs, subsOption); err != nil { + http.Error(w, fmt.Sprintf("Error getting information about channel %q: %v", channelName, err), http.StatusInternalServerError) + return + } + } + channelsz.Count = len(carr) + channelsz.Channels = carr + } else { + carr := make([]string, 0, totalChannels) + for cn := range channels { + carr = append(carr, cn) + } + sort.Sort(byName(carr)) + carr = carr[minoff:maxoff] + channelsz.Count = len(carr) + channelsz.Names = carr + } + s.sendResponse(w, r, channelsz) + } +} + +func (s *StanServer) handleOneChannel(w http.ResponseWriter, r *http.Request, name string, subsOption int) { + cs := s.channels.get(name) + if cs == nil { + http.Error(w, fmt.Sprintf("Channel %s not found", name), http.StatusNotFound) + return + } + channelz := &Channelz{Name: name} + if err := updateChannelz(channelz, cs, subsOption); err != nil { + http.Error(w, fmt.Sprintf("Error getting information about channel %q: %v", name, err), http.StatusInternalServerError) + return + } + s.sendResponse(w, r, channelz) +} + +func updateChannelz(cz *Channelz, c *channel, subsOption int) error { + msgs, bytes, err := c.store.Msgs.State() + if err != nil { + return fmt.Errorf("unable to get message state: %v", err) + } + fseq, lseq, err := c.store.Msgs.FirstAndLastSequence() + if err != nil { + return fmt.Errorf("unable to get first and last sequence: %v", err) + } + cz.Msgs = msgs + cz.Bytes = bytes + cz.FirstSeq = fseq + cz.LastSeq = lseq + if subsOption == 1 { + cz.Subscriptions = getMonitorChannelSubs(c.ss) + } + return nil +} + +func (s *StanServer) sendResponse(w http.ResponseWriter, r *http.Request, content interface{}) { + b, err := json.MarshalIndent(content, "", " ") + if err != nil { + s.log.Errorf("Error marshaling response to %q request: %v", r.URL, err) + } + gnatsd.ResponseHandler(w, r, b) +} + +func getOffsetAndLimit(r 
*http.Request) (int, int) { + offset, _ := strconv.Atoi(r.URL.Query().Get("offset")) + if offset < 0 { + offset = 0 + } + limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + if limit <= 0 { + limit = defaultMonitorListLimit + } + return offset, limit +} + +func getMinMaxOffset(offset, limit, total int) (int, int) { + minoff := offset + if minoff > total { + minoff = total + } + maxoff := offset + limit + if maxoff > total { + maxoff = total + } + return minoff, maxoff +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/partitions.go b/vendor/github.com/nats-io/nats-streaming-server/server/partitions.go new file mode 100644 index 00000000000..2d0ade484b3 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/partitions.go @@ -0,0 +1,263 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "fmt" + "sync" + "time" + + "github.com/nats-io/go-nats" + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/stores" + "github.com/nats-io/nats-streaming-server/util" +) + +// Constants related to partitioning +const ( + // Prefix of subject to send list of channels in this server's partition + partitionsPrefix = "_STAN.part" + // Default timeout for a server to wait for replies to its request + partitionsDefaultRequestTimeout = time.Second + // This is the value that is stored in the sublist for a given subject + channelInterest = 1 + // Default wait before checking for channels when notified + // that the NATS cluster topology has changed. This gives a chance + // for the new server joining the cluster to send its subscriptions + // list. + partitionsDefaultWaitOnTopologyChange = 500 * time.Millisecond +) + +// So that we can override in tests +var ( + partitionsRequestTimeout = partitionsDefaultRequestTimeout + partitionsNoPanic = false + partitionsWaitOnChange = partitionsDefaultWaitOnTopologyChange +) + +type partitions struct { + sync.Mutex + s *StanServer + channels []string + sl *util.Sublist + nc *nats.Conn + sendListSubject string + processChanSub *nats.Subscription + inboxSub *nats.Subscription + isShutdown bool +} + +// Initialize the channels partitions objects and issue the first +// request to check if other servers in the cluster incorrectly have +// any of the channel that this server is supposed to handle. +func (s *StanServer) initPartitions() error { + // The option says that the server should only use the pre-defined channels, + // but none was specified. Don't see the point in continuing... 
+ if len(s.opts.StoreLimits.PerChannel) == 0 { + return ErrNoChannel + } + nc, err := s.createNatsClientConn("pc") + if err != nil { + return err + } + p := &partitions{ + s: s, + nc: nc, + } + // Now that the connection is created, we need to set s.partitioning to cp + // so that server shutdown can properly close this connection. + s.partitions = p + p.createChannelsMapAndSublist(s.opts.StoreLimits.PerChannel) + p.sendListSubject = partitionsPrefix + "." + s.opts.ID + // Use the partitions' own connection for channels list requests + p.processChanSub, err = p.nc.Subscribe(p.sendListSubject, p.processChannelsListRequests) + if err != nil { + return fmt.Errorf("unable to subscribe: %v", err) + } + p.processChanSub.SetPendingLimits(-1, -1) + p.inboxSub, err = p.nc.SubscribeSync(nats.NewInbox()) + if err != nil { + return fmt.Errorf("unable to subscribe: %v", err) + } + p.Lock() + // Set this before the first attempt so we don't miss any notification + // of a change in topology. Since we hold the lock, and even if there + // was a notification happening now, the callback will execute only + // after we are done with the initial check. + nc.SetDiscoveredServersHandler(p.topologyChanged) + // Now send our list and check if any server is complaining + // about having one channel in common. + if err := p.checkChannelsUniqueInCluster(); err != nil { + p.Unlock() + return err + } + p.Unlock() + return nil +} + +// Creates the channels map based on the store's PerChannel map that was given. +func (p *partitions) createChannelsMapAndSublist(storeChannels map[string]*stores.ChannelLimits) { + p.channels = make([]string, 0, len(storeChannels)) + p.sl = util.NewSublist() + for c := range storeChannels { + p.channels = append(p.channels, c) + // When creating the store, we have already checked that channel names + // were valid. So this call cannot fail. + p.sl.Insert(c, channelInterest) + } +} + +// Topology changed. Sends the list of channels. 
+func (p *partitions) topologyChanged(_ *nats.Conn) { + p.Lock() + defer p.Unlock() + if p.isShutdown { + return + } + // Let's wait before checking (sending the list and waiting for a reply) + // so that the new NATS Server has a chance to send its local + // subscriptions to the rest of the cluster. That will reduce the risk + // of missing the reply from the new server. + time.Sleep(partitionsWaitOnChange) + if err := p.checkChannelsUniqueInCluster(); err != nil { + // If server is started from command line, the Fatalf + // call will cause the process to exit. If the server + // is run programmatically and no logger has been set + // we need to exit with the panic. + p.s.log.Fatalf("Partitioning error: %v", err) + // For tests + if partitionsNoPanic { + p.s.setLastError(err) + return + } + panic(err) + } +} + +// Create the internal subscriptions on the list of channels. +func (p *partitions) initSubscriptions() error { + // NOTE: Use the server's nc connection here, not the partitions' one. + for _, channelName := range p.channels { + pubSubject := fmt.Sprintf("%s.%s", p.s.info.Publish, channelName) + if _, err := p.s.nc.Subscribe(pubSubject, p.s.processClientPublish); err != nil { + return fmt.Errorf("could not subscribe to publish subject %q, %v", channelName, err) + } + } + return nil +} + +// Sends a request to the rest of the cluster and wait a bit for +// responses (we don't know if or how many servers there may be). +// No server lock used since this is called inside RunServerWithOpts(). +func (p *partitions) checkChannelsUniqueInCluster() error { + // We use the subscription on an inbox to get the replies. 
+ // Send our list + if err := util.SendChannelsList(p.channels, p.sendListSubject, p.inboxSub.Subject, p.nc, p.s.serverID); err != nil { + return fmt.Errorf("unable to send channels list: %v", err) + } + // Since we don't know how many servers are out there, keep + // calling NextMsg until we get a timeout + for { + reply, err := p.inboxSub.NextMsg(partitionsRequestTimeout) + if err == nats.ErrTimeout { + return nil + } + if err != nil { + return fmt.Errorf("unable to get partitioning reply: %v", err) + } + resp := spb.CtrlMsg{} + if err := resp.Unmarshal(reply.Data); err != nil { + return fmt.Errorf("unable to decode partitioning response: %v", err) + } + if len(resp.Data) > 0 { + return fmt.Errorf("channel %q causes conflict with channels on server %q", + string(resp.Data), resp.ServerID) + } + } +} + +// Decode the incoming partitioning protocol message. +// It can be an HB, in which case, if it is from a new server +// we send our list to the cluster, or it can be a request +// from another server. If so, we reply to the given inbox +// with either an empty Data field or the name of the first +// channel we have in common. +func (p *partitions) processChannelsListRequests(m *nats.Msg) { + // Message cannot be empty, we are supposed to receive + // a spb.CtrlMsg_Partitioning protocol. We should also + // have a reply subject + if len(m.Data) == 0 || m.Reply == "" { + return + } + req := spb.CtrlMsg{} + if err := req.Unmarshal(m.Data); err != nil { + p.s.log.Errorf("Error processing partitioning request: %v", err) + return + } + // If this is our own request, ignore + if req.ServerID == p.s.serverID { + return + } + channels, err := util.DecodeChannels(req.Data) + if err != nil { + p.s.log.Errorf("Error processing partitioning request: %v", err) + return + } + // Check that we don't have any of these channels defined. 
+ // If we do, send a reply with simply the name of the offending + // channel in reply.Data + reply := spb.CtrlMsg{ + ServerID: p.s.serverID, + MsgType: spb.CtrlMsg_Partitioning, + } + gotError := false + sl := util.NewSublist() + for _, c := range channels { + if r := p.sl.Match(c); len(r) > 0 { + reply.Data = []byte(c) + gotError = true + break + } + sl.Insert(c, channelInterest) + } + if !gotError { + // Go over our channels and check with the other server sublist + for _, c := range p.channels { + if r := sl.Match(c); len(r) > 0 { + reply.Data = []byte(c) + break + } + } + } + replyBytes, _ := reply.Marshal() + // If there is no duplicate, reply.Data will be empty, which means + // that there was no conflict. + if err := p.nc.Publish(m.Reply, replyBytes); err != nil { + p.s.log.Errorf("Error sending reply to partitioning request: %v", err) + } +} + +// Notifies all go-routines used by partitioning code that the +// server is shutting down and closes the internal NATS connection. +func (p *partitions) shutdown() { + p.Lock() + defer p.Unlock() + if p.isShutdown { + return + } + p.isShutdown = true + if p.nc != nil { + p.nc.Close() + } +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/raft_log.go b/vendor/github.com/nats-io/nats-streaming-server/server/raft_log.go new file mode 100644 index 00000000000..f22922f5fde --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/raft_log.go @@ -0,0 +1,452 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/binary" + "errors" + "fmt" + "os" + "sync" + "time" + + "github.com/boltdb/bolt" + "github.com/hashicorp/go-msgpack/codec" + "github.com/hashicorp/raft" + "github.com/nats-io/nats-streaming-server/logger" +) + +// Bucket names +var ( + logsBucket = []byte("logs") + confBucket = []byte("conf") +) + +// When a key is not found. Raft checks the error text, and it needs to be exactly "not found" +var errKeyNotFound = errors.New("not found") + +// raftLog implements both the raft LogStore and Stable interfaces. This is used +// by raft to store logs and configuration changes. +type raftLog struct { + sync.RWMutex + log logger.Logger + conn *bolt.DB + fileName string + noSync bool + trailingLogs int + ratioThreshold int + simpleDelThresholdHigh int + simpleDelThresholdLow int + codec *codec.MsgpackHandle + closed bool +} + +func newRaftLog(log logger.Logger, fileName string, sync bool, trailingLogs int) (*raftLog, error) { + r := &raftLog{ + log: log, + fileName: fileName, + noSync: !sync, + trailingLogs: trailingLogs, + ratioThreshold: 50, + simpleDelThresholdLow: 1000, + simpleDelThresholdHigh: 100000, + codec: &codec.MsgpackHandle{}, + } + conn, err := r.openAndSetOptions(fileName) + if err != nil { + return nil, err + } + r.conn = conn + if err := r.init(); err != nil { + r.conn.Close() + return nil, err + } + return r, nil +} + +func (r *raftLog) init() error { + tx, err := r.conn.Begin(true) + if err != nil { + return err + } + defer tx.Rollback() + + // Create the configuration and logs buckets + if _, err := tx.CreateBucketIfNotExists(confBucket); err != nil { + return err + } + if _, err := tx.CreateBucketIfNotExists(logsBucket); err != nil { + return err + } + return tx.Commit() +} + +func (r *raftLog) openAndSetOptions(fileName string) (*bolt.DB, error) { + db, err := bolt.Open(fileName, 0600, nil) + if err != nil 
{ + return nil, err + } + db.NoSync = r.noSync + return db, nil +} + +func (r *raftLog) encodeRaftLog(in *raft.Log) ([]byte, error) { + var buf []byte + enc := codec.NewEncoderBytes(&buf, r.codec) + err := enc.Encode(in) + return buf, err +} + +func (r *raftLog) decodeRaftLog(buf []byte, log *raft.Log) error { + dec := codec.NewDecoderBytes(buf, r.codec) + return dec.Decode(log) +} + +// Close implements the LogStore interface +func (r *raftLog) Close() error { + r.Lock() + if r.closed { + r.Unlock() + return nil + } + r.closed = true + err := r.conn.Close() + r.Unlock() + return err +} + +// FirstIndex implements the LogStore interface +func (r *raftLog) FirstIndex() (uint64, error) { + r.RLock() + idx, err := r.getIndex(true) + r.RUnlock() + return idx, err +} + +// LastIndex implements the LogStore interface +func (r *raftLog) LastIndex() (uint64, error) { + r.RLock() + idx, err := r.getIndex(false) + r.RUnlock() + return idx, err +} + +// returns either the first (if first is true) or the last +// index of the logs bucket. 
+func (r *raftLog) getIndex(first bool) (uint64, error) { + tx, err := r.conn.Begin(false) + if err != nil { + return 0, err + } + var ( + key []byte + idx uint64 + ) + curs := tx.Bucket(logsBucket).Cursor() + if first { + key, _ = curs.First() + } else { + key, _ = curs.Last() + } + if key != nil { + idx = binary.BigEndian.Uint64(key) + } + tx.Rollback() + return idx, nil +} + +// GetLog implements the LogStore interface +func (r *raftLog) GetLog(idx uint64, log *raft.Log) error { + r.RLock() + tx, err := r.conn.Begin(false) + if err != nil { + r.RUnlock() + return err + } + var key [8]byte + binary.BigEndian.PutUint64(key[:], idx) + bucket := tx.Bucket(logsBucket) + val := bucket.Get(key[:]) + if val == nil { + err = raft.ErrLogNotFound + } else { + err = r.decodeRaftLog(val, log) + } + tx.Rollback() + r.RUnlock() + return err +} + +// StoreLog implements the LogStore interface +func (r *raftLog) StoreLog(log *raft.Log) error { + return r.StoreLogs([]*raft.Log{log}) +} + +// StoreLogs implements the LogStore interface +func (r *raftLog) StoreLogs(logs []*raft.Log) error { + r.RLock() + tx, err := r.conn.Begin(true) + if err != nil { + r.RUnlock() + return err + } + for _, log := range logs { + var ( + key [8]byte + val []byte + ) + binary.BigEndian.PutUint64(key[:], log.Index) + val, err = r.encodeRaftLog(log) + if err != nil { + break + } + bucket := tx.Bucket(logsBucket) + err = bucket.Put(key[:], val) + if err != nil { + break + } + } + if err != nil { + tx.Rollback() + } else { + err = tx.Commit() + } + r.RUnlock() + return err +} + +// DeleteRange implements the LogStore interface +func (r *raftLog) DeleteRange(min, max uint64) (retErr error) { + r.Lock() + defer r.Unlock() + + start := time.Now() + r.log.Noticef("Deleting raft logs from %v to %v", min, max) + defer func() { + dur := time.Since(start) + durTxt := fmt.Sprintf("Deletion took %v", dur) + if dur > 2*time.Second { + r.log.Errorf(fmt.Sprintf("%s. 
This is too long, consider lowering TrailingLogs value currently set to %v", + durTxt, r.trailingLogs)) + } else { + r.log.Noticef(durTxt) + } + }() + + // We know that RAFT is calling DeleteRange leaving at least + // trailingLogs number of logs at the end. + + // If the selected number of trailingLogs (value set when RAFT is created) + // is too big, perform a simple delete range that removes logs from the DB. + if r.trailingLogs > r.simpleDelThresholdHigh { + return r.simpleDeleteRange(min, max) + } + // If the number of logs to delete is small, remove in place. + toRemove := int(max-min) + 1 + if toRemove <= r.simpleDelThresholdLow { + return r.simpleDeleteRange(min, max) + } + r.log.Noticef("Compaction in progress...") + + newfileName := r.fileName + ".new" + newdb, err := r.openAndSetOptions(newfileName) + if err != nil { + return err + } + removeNewFile := true + defer func() { + if removeNewFile { + newdb.Close() + os.Remove(newfileName) + } + }() + + curDBConn := r.conn + newDBConn := newdb + + // First, transfer all confLogs, there should not be that many + if err := r.transferLogs(curDBConn, newDBConn, confBucket, 1); err != nil { + return err + } + + if err := r.transferLogs(curDBConn, newDBConn, logsBucket, max+1); err != nil { + return err + } + + // Close new db + if err := newdb.Close(); err != nil { + return err + } + // Close current db + if err := curDBConn.Close(); err != nil { + // We got an error, try to reopen + db, cerr := r.openAndSetOptions(r.fileName) + if cerr != nil { + // At this point, panic... + panic(fmt.Errorf("error closing bolt db after compaction: %v - error re-opening: %v", err, cerr)) + } + r.conn = db + return err + } + // Rename new db file to the name of old one + os.Rename(newfileName, r.fileName) + // Reopen the compacted db file + db, err := r.openAndSetOptions(r.fileName) + if err != nil { + // At this point, panic... 
+ panic(fmt.Errorf("error compacting bolt db: %v", err)) + } + // We now point to the compact db file + r.conn = db + // Success, skip cleanup code + removeNewFile = false + return nil +} + +// Delete logs from the "logs" bucket starting at the min index +// and up to max index (included). +// Lock is held on entry +func (r *raftLog) simpleDeleteRange(min, max uint64) error { + var key [8]byte + binary.BigEndian.PutUint64(key[:], min) + tx, err := r.conn.Begin(true) + if err != nil { + return err + } + defer tx.Rollback() + curs := tx.Bucket(logsBucket).Cursor() + for k, _ := curs.Seek(key[:]); k != nil; k, _ = curs.Next() { + // If we reach the max, we are done + if binary.BigEndian.Uint64(k) > max { + break + } + if err := curs.Delete(); err != nil { + return err + } + } + return tx.Commit() +} + +func (r *raftLog) transferLogs(curDB, newDB *bolt.DB, bucketName []byte, startKey uint64) error { + readTX, err := curDB.Begin(false) + if err != nil { + return err + } + // Read transactions must be rollback (not committed) + defer readTX.Rollback() + + var ( + key [8]byte + count int + limit = 1000 + writeTX *bolt.Tx + writeBucket *bolt.Bucket + ) + binary.BigEndian.PutUint64(key[:], startKey) + + curs := readTX.Bucket(bucketName).Cursor() + + for k, v := curs.Seek(key[:]); k != nil; k, v = curs.Next() { + if count == 0 { + writeTX, err = newDB.Begin(true) + if err != nil { + return err + } + writeBucket = writeTX.Bucket(bucketName) + if writeBucket == nil { + b, err := writeTX.CreateBucket(bucketName) + if err != nil { + writeTX.Rollback() + return err + } + writeBucket = b + } + } + if err := writeBucket.Put(k, v); err != nil { + writeTX.Rollback() + return err + } + count++ + if count == limit { + count = 0 + if err := writeTX.Commit(); err != nil { + return err + } + writeTX = nil + } + } + if writeTX != nil { + return writeTX.Commit() + } + return nil +} + +// Set implements the Stable interface +func (r *raftLog) Set(k, v []byte) error { + r.RLock() + tx, err := 
r.conn.Begin(true) + if err != nil { + r.RUnlock() + return err + } + bucket := tx.Bucket(confBucket) + err = bucket.Put(k, v) + if err != nil { + tx.Rollback() + } else { + err = tx.Commit() + } + r.RUnlock() + return err +} + +// Get implements the Stable interface +func (r *raftLog) Get(k []byte) ([]byte, error) { + r.RLock() + tx, err := r.conn.Begin(false) + if err != nil { + r.RUnlock() + return nil, err + } + var v []byte + bucket := tx.Bucket(confBucket) + val := bucket.Get(k) + if val == nil { + err = errKeyNotFound + } else { + // Make a copy + v = append([]byte(nil), val...) + } + tx.Rollback() + r.RUnlock() + return v, err +} + +// SetUint64 implements the Stable interface +func (r *raftLog) SetUint64(k []byte, v uint64) error { + var vbytes [8]byte + binary.BigEndian.PutUint64(vbytes[:], v) + err := r.Set(k, vbytes[:]) + return err +} + +// GetUint64 implements the Stable interface +func (r *raftLog) GetUint64(k []byte) (uint64, error) { + var v uint64 + vbytes, err := r.Get(k) + if err == nil { + v = binary.BigEndian.Uint64(vbytes) + } + return v, err +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/raft_transport.go b/vendor/github.com/nats-io/nats-streaming-server/server/raft_transport.go new file mode 100644 index 00000000000..31942fe26dc --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/raft_transport.go @@ -0,0 +1,393 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// RAFT Transport implementation using NATS + +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "time" + + "github.com/hashicorp/raft" + "github.com/nats-io/go-nats" +) + +const ( + natsConnectInbox = "raft.%s.accept" + natsRequestInbox = "raft.%s.request.%s" +) + +// natsAddr implements the net.Addr interface. An address for the NATS +// transport is simply a node id, which is then used to construct an inbox. +type natsAddr string + +func (n natsAddr) Network() string { + return "nats" +} + +func (n natsAddr) String() string { + return string(n) +} + +type connectRequestProto struct { + ID string `json:"id"` + Inbox string `json:"inbox"` +} + +type connectResponseProto struct { + Inbox string `json:"inbox"` +} + +// natsConn implements the net.Conn interface by simulating a stream-oriented +// connection between two peers. It does this by establishing a unique inbox at +// each endpoint which the peers use to stream data to each other. +type natsConn struct { + conn *nats.Conn + localAddr natsAddr + remoteAddr natsAddr + sub *nats.Subscription + outbox string + mu sync.RWMutex + closed bool + reader *timeoutReader + writer io.WriteCloser + parent *natsStreamLayer +} + +func (n *natsConn) Read(b []byte) (int, error) { + n.mu.RLock() + closed := n.closed + n.mu.RUnlock() + if closed { + return 0, errors.New("read from closed conn") + } + return n.reader.Read(b) +} + +func (n *natsConn) Write(b []byte) (int, error) { + n.mu.RLock() + closed := n.closed + n.mu.RUnlock() + if closed { + return 0, errors.New("write to closed conn") + } + + if len(b) == 0 { + return 0, nil + } + + // Send data in chunks to avoid hitting max payload. 
+ for i := 0; i < len(b); { + chunkSize := min(int64(len(b[i:])), n.conn.MaxPayload()) + if err := n.conn.Publish(n.outbox, b[i:int64(i)+chunkSize]); err != nil { + return i, err + } + i += int(chunkSize) + } + + return len(b), nil +} + +func (n *natsConn) Close() error { + return n.close(true) +} + +func (n *natsConn) close(signalRemote bool) error { + n.mu.Lock() + defer n.mu.Unlock() + + if n.closed { + return nil + } + + if err := n.sub.Unsubscribe(); err != nil { + return err + } + + if signalRemote { + // Send empty message to signal EOF for a graceful disconnect. Not + // concerned with errors here as this is best effort. + n.conn.Publish(n.outbox, nil) + // Best effort, don't block for too long and don't check returned error. + n.conn.FlushTimeout(500 * time.Millisecond) + } + + n.closed = true + n.parent.mu.Lock() + delete(n.parent.conns, n) + n.parent.mu.Unlock() + n.writer.Close() + + return nil +} + +func (n *natsConn) LocalAddr() net.Addr { + return n.localAddr +} + +func (n *natsConn) RemoteAddr() net.Addr { + return n.remoteAddr +} + +func (n *natsConn) SetDeadline(t time.Time) error { + if err := n.SetReadDeadline(t); err != nil { + return err + } + return n.SetWriteDeadline(t) +} + +func (n *natsConn) SetReadDeadline(t time.Time) error { + n.reader.SetDeadline(t) + return nil +} + +func (n *natsConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (n *natsConn) msgHandler(msg *nats.Msg) { + // Check if remote peer disconnected. + if len(msg.Data) == 0 { + n.close(false) + return + } + + n.writer.Write(msg.Data) +} + +// natsStreamLayer implements the raft.StreamLayer interface. 
+type natsStreamLayer struct { + conn *nats.Conn + localAddr natsAddr + sub *nats.Subscription + logger *log.Logger + conns map[*natsConn]struct{} + mu sync.Mutex + timeout time.Duration +} + +func newNATSStreamLayer(id string, conn *nats.Conn, logger *log.Logger, timeout time.Duration) (*natsStreamLayer, error) { + n := &natsStreamLayer{ + localAddr: natsAddr(id), + conn: conn, + logger: logger, + conns: map[*natsConn]struct{}{}, + timeout: timeout, + } + sub, err := conn.SubscribeSync(fmt.Sprintf(natsConnectInbox, id)) + if err != nil { + return nil, err + } + sub.SetPendingLimits(-1, -1) + if err := conn.FlushTimeout(timeout); err != nil { + sub.Unsubscribe() + return nil, err + } + n.sub = sub + return n, nil +} + +func (n *natsStreamLayer) newNATSConn(address string) *natsConn { + // TODO: probably want a buffered pipe. + reader, writer := io.Pipe() + return &natsConn{ + conn: n.conn, + localAddr: n.localAddr, + remoteAddr: natsAddr(address), + reader: newTimeoutReader(reader), + writer: writer, + parent: n, + } +} + +// Dial creates a new net.Conn with the remote address. This is implemented by +// performing a handshake over NATS which establishes unique inboxes at each +// endpoint for streaming data. +func (n *natsStreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { + if !n.conn.IsConnected() { + return nil, errors.New("raft-nats: dial failed, not connected") + } + + // QUESTION: The Raft NetTransport does connection pooling, which is useful + // for TCP sockets. The NATS transport simulates a socket using a + // subscription at each endpoint, but everything goes over the same NATS + // socket. This means there is little advantage to pooling here currently. + // Should we actually Dial a new NATS connection here and rely on pooling? 
+ + connect := &connectRequestProto{ + ID: n.localAddr.String(), + Inbox: fmt.Sprintf(natsRequestInbox, n.localAddr.String(), nats.NewInbox()), + } + data, err := json.Marshal(connect) + if err != nil { + panic(err) + } + + peerConn := n.newNATSConn(string(address)) + + // Setup inbox. + sub, err := n.conn.Subscribe(connect.Inbox, peerConn.msgHandler) + if err != nil { + return nil, err + } + sub.SetPendingLimits(-1, -1) + if err := n.conn.FlushTimeout(n.timeout); err != nil { + sub.Unsubscribe() + return nil, err + } + + // Make connect request to peer. + msg, err := n.conn.Request(fmt.Sprintf(natsConnectInbox, address), data, timeout) + if err != nil { + sub.Unsubscribe() + return nil, err + } + var resp connectResponseProto + if err := json.Unmarshal(msg.Data, &resp); err != nil { + sub.Unsubscribe() + return nil, err + } + + peerConn.sub = sub + peerConn.outbox = resp.Inbox + n.mu.Lock() + n.conns[peerConn] = struct{}{} + n.mu.Unlock() + return peerConn, nil +} + +// Accept waits for and returns the next connection to the listener. +func (n *natsStreamLayer) Accept() (net.Conn, error) { + for { + msg, err := n.sub.NextMsgWithContext(context.TODO()) + if err != nil { + return nil, err + } + if msg.Reply == "" { + n.logger.Println("[ERR] raft-nats: Invalid connect message (missing reply inbox)") + continue + } + + var connect connectRequestProto + if err := json.Unmarshal(msg.Data, &connect); err != nil { + n.logger.Println("[ERR] raft-nats: Invalid connect message (invalid data)") + continue + } + + peerConn := n.newNATSConn(connect.ID) + peerConn.outbox = connect.Inbox + + // Setup inbox for peer. + inbox := fmt.Sprintf(natsRequestInbox, n.localAddr.String(), nats.NewInbox()) + sub, err := n.conn.Subscribe(inbox, peerConn.msgHandler) + if err != nil { + n.logger.Printf("[ERR] raft-nats: Failed to create inbox for remote peer: %v", err) + continue + } + sub.SetPendingLimits(-1, -1) + // Reply to peer. 
+ resp := &connectResponseProto{Inbox: inbox} + data, err := json.Marshal(resp) + if err != nil { + panic(err) + } + if err := n.conn.Publish(msg.Reply, data); err != nil { + n.logger.Printf("[ERR] raft-nats: Failed to send connect response to remote peer: %v", err) + sub.Unsubscribe() + continue + } + if err := n.conn.FlushTimeout(n.timeout); err != nil { + n.logger.Printf("[ERR] raft-nats: Failed to flush connect response to remote peer: %v", err) + sub.Unsubscribe() + continue + } + peerConn.sub = sub + n.mu.Lock() + n.conns[peerConn] = struct{}{} + n.mu.Unlock() + return peerConn, nil + } +} + +func (n *natsStreamLayer) Close() error { + n.mu.Lock() + conns := make(map[*natsConn]struct{}, len(n.conns)) + for conn, s := range n.conns { + conns[conn] = s + } + n.mu.Unlock() + for c := range conns { + c.Close() + } + return n.sub.Unsubscribe() +} + +func (n *natsStreamLayer) Addr() net.Addr { + return n.localAddr +} + +// newNATSTransport creates a new raft.NetworkTransport implemented with NATS +// as the transport layer. +func newNATSTransport(id string, conn *nats.Conn, timeout time.Duration, logOutput io.Writer) (*raft.NetworkTransport, error) { + if logOutput == nil { + logOutput = os.Stderr + } + return newNATSTransportWithLogger(id, conn, timeout, log.New(logOutput, "", log.LstdFlags)) +} + +// newNATSTransportWithLogger creates a new raft.NetworkTransport implemented +// with NATS as the transport layer using the provided Logger. +func newNATSTransportWithLogger(id string, conn *nats.Conn, timeout time.Duration, logger *log.Logger) (*raft.NetworkTransport, error) { + return createNATSTransport(id, conn, logger, timeout, func(stream raft.StreamLayer) *raft.NetworkTransport { + return raft.NewNetworkTransportWithLogger(stream, 3, timeout, logger) + }) +} + +// newNATSTransportWithConfig returns a raft.NetworkTransport implemented +// with NATS as the transport layer, using the given config struct. 
+func newNATSTransportWithConfig(id string, conn *nats.Conn, config *raft.NetworkTransportConfig) (*raft.NetworkTransport, error) { + if config.Timeout == 0 { + config.Timeout = 2 * time.Second + } + return createNATSTransport(id, conn, config.Logger, config.Timeout, func(stream raft.StreamLayer) *raft.NetworkTransport { + config.Stream = stream + return raft.NewNetworkTransportWithConfig(config) + }) +} + +func createNATSTransport(id string, conn *nats.Conn, logger *log.Logger, timeout time.Duration, + transportCreator func(stream raft.StreamLayer) *raft.NetworkTransport) (*raft.NetworkTransport, error) { + + stream, err := newNATSStreamLayer(id, conn, logger, timeout) + if err != nil { + return nil, err + } + + return transportCreator(stream), nil +} + +func min(x, y int64) int64 { + if x < y { + return x + } + return y +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/server.go b/vendor/github.com/nats-io/nats-streaming-server/server/server.go new file mode 100644 index 00000000000..5b80b5d5ec8 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/server.go @@ -0,0 +1,5009 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "errors" + "fmt" + "net" + "net/url" + "os" + "path/filepath" + "regexp" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/raft" + natsdLogger "github.com/nats-io/gnatsd/logger" + "github.com/nats-io/gnatsd/server" + "github.com/nats-io/go-nats" + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nuid" + + "github.com/nats-io/nats-streaming-server/logger" + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/stores" + "github.com/nats-io/nats-streaming-server/util" +) + +// A single NATS Streaming Server + +// Server defaults. +const ( + // VERSION is the current version for the NATS Streaming server. + VERSION = "0.11.0" + + DefaultClusterID = "test-cluster" + DefaultDiscoverPrefix = "_STAN.discover" + DefaultPubPrefix = "_STAN.pub" + DefaultSubPrefix = "_STAN.sub" + DefaultSubClosePrefix = "_STAN.subclose" + DefaultUnSubPrefix = "_STAN.unsub" + DefaultClosePrefix = "_STAN.close" + defaultAcksPrefix = "_STAN.ack" + defaultSnapshotPrefix = "_STAN.snap" + defaultRaftPrefix = "_STAN.raft" + DefaultStoreType = stores.TypeMemory + + // Prefix of subject active server is sending HBs to + ftHBPrefix = "_STAN.ft" + + // DefaultHeartBeatInterval is the interval at which server sends heartbeat to a client + DefaultHeartBeatInterval = 30 * time.Second + // DefaultClientHBTimeout is how long server waits for a heartbeat response + DefaultClientHBTimeout = 10 * time.Second + // DefaultMaxFailedHeartBeats is the number of failed heartbeats before server closes + // the client connection (total= (heartbeat interval + heartbeat timeout) * (fail count + 1) + DefaultMaxFailedHeartBeats = int((5 * time.Minute) / DefaultHeartBeatInterval) + + // Timeout used to ping the known client when processing a connection + // request for a duplicate client ID. 
+ defaultCheckDupCIDTimeout = 500 * time.Millisecond + + // DefaultIOBatchSize is the maximum number of messages to accumulate before flushing a store. + DefaultIOBatchSize = 1024 + + // DefaultIOSleepTime is the duration (in micro-seconds) the server waits for more messages + // before starting processing. Set to 0 (or negative) to disable the wait. + DefaultIOSleepTime = int64(0) + + // DefaultLogCacheSize is the number of Raft log entries to cache in memory + // to reduce disk IO. + DefaultLogCacheSize = 512 + + // DefaultLogSnapshots is the number of Raft log snapshots to retain. + DefaultLogSnapshots = 2 + + // DefaultTrailingLogs is the number of log entries to leave after a + // snapshot and compaction. + DefaultTrailingLogs = 10240 + + // Length of the channel used to schedule subscriptions start requests. + // Subscriptions requests are processed from the same NATS subscription. + // When a subscriber starts and it has pending messages, the server + // processes the new subscription request and sends avail messages out + // (up to MaxInflight). When done in place, this can cause other + // new subscriptions requests to timeout. Server uses a channel to schedule + // start (that is sending avail messages) of new subscriptions. This is + // the default length of that channel. + defaultSubStartChanLen = 2048 + + // Name of the file to store Raft log. + raftLogFile = "raft.log" + + // In partitioning mode, when a client connects, the connect request + // may reach several servers, but the first response the client gets + // allows it to proceed with either publish or subscribe. + // So it is possible for a server running in partitioning mode to + // receives a connection request followed by a message or subscription. + // Although the conn request would be first in the tcp connection, it + // is then possible that the PubMsg or SubscriptionRequest be processed + // first due to the use of different nats subscriptions. 
+ // To prevent that, when checking if a client exists, in this particular + // mode we will possibly wait to be notified when the client has been + // registered. This is the default duration for this wait. + defaultClientCheckTimeout = 4 * time.Second + + // Interval at which server goes through list of subscriptions with + // pending sent/ack operations that needs to be replicated. + defaultLazyReplicationInterval = time.Second +) + +// Constant to indicate that sendMsgToSub() should check number of acks pending +// against MaxInFlight to know if message should be sent out. +const ( + forceDelivery = true + honorMaxInFlight = false +) + +// Constants to indicate if we are replicating a Sent or an Ack +const ( + replicateSent = true + replicateAck = false +) + +const ( + // Client send connID in ConnectRequest and PubMsg, and server + // listens and responds to client PINGs. The validity of the + // connection (based on connID) is checked on incoming PINGs. + protocolOne = int32(1) +) + +// Errors. 
+var ( + ErrInvalidSubject = errors.New("stan: invalid subject") + ErrInvalidStart = errors.New("stan: invalid start position") + ErrInvalidSub = errors.New("stan: invalid subscription") + ErrInvalidClient = errors.New("stan: clientID already registered") + ErrMissingClient = errors.New("stan: clientID missing") + ErrInvalidClientID = errors.New("stan: invalid clientID: only alphanumeric and `-` or `_` characters allowed") + ErrInvalidAckWait = errors.New("stan: invalid ack wait time, should be >= 1s") + ErrInvalidMaxInflight = errors.New("stan: invalid MaxInflight, should be >= 1") + ErrInvalidConnReq = errors.New("stan: invalid connection request") + ErrInvalidPubReq = errors.New("stan: invalid publish request") + ErrInvalidSubReq = errors.New("stan: invalid subscription request") + ErrInvalidUnsubReq = errors.New("stan: invalid unsubscribe request") + ErrInvalidCloseReq = errors.New("stan: invalid close request") + ErrDupDurable = errors.New("stan: duplicate durable registration") + ErrInvalidDurName = errors.New("stan: durable name of a durable queue subscriber can't contain the character ':'") + ErrUnknownClient = errors.New("stan: unknown clientID") + ErrNoChannel = errors.New("stan: no configured channel") + ErrClusteredRestart = errors.New("stan: cannot restart server in clustered mode if it was not previously clustered") + ErrChanDelInProgress = errors.New("stan: channel is being deleted") +) + +// Shared regular expression to check clientID validity. +// No lock required since from doc: https://golang.org/pkg/regexp/ +// A Regexp is safe for concurrent use by multiple goroutines. 
+var clientIDRegEx *regexp.Regexp + +var ( + testAckWaitIsInMillisecond bool + clientCheckTimeout = defaultClientCheckTimeout + lazyReplicationInterval = defaultLazyReplicationInterval + testDeleteChannel bool +) + +func computeAckWait(wait int32) time.Duration { + unit := time.Second + if testAckWaitIsInMillisecond && wait < 0 { + wait = wait * -1 + unit = time.Millisecond + } + return time.Duration(wait) * unit +} + +func init() { + if re, err := regexp.Compile("^[a-zA-Z0-9_-]+$"); err != nil { + panic("Unable to compile regular expression") + } else { + clientIDRegEx = re + } +} + +// ioPendingMsg is a record that embeds the pointer to the incoming +// NATS Message, the PubMsg and PubAck structures so we reduce the +// number of memory allocations to 1 when processing a message from +// producer. +type ioPendingMsg struct { + m *nats.Msg + pm pb.PubMsg + pa pb.PubAck + c *channel + dc bool // if true, this is a request to delete this channel. + + // Use for synchronization between ioLoop and other routines + sc chan struct{} + sdc chan struct{} +} + +// Constant that defines the size of the channel that feeds the IO thread. +const ioChannelSize = 64 * 1024 + +// subStartInfo contains information used when a subscription request +// is successful and the start (sending avail messages) is scheduled. 
+type subStartInfo struct { + c *channel + sub *subState + qs *queueState + isDurable bool +} + +// State represents the possible server states +type State int8 + +// Possible server states +const ( + Standalone State = iota + FTActive + FTStandby + Failed + Shutdown + Clustered +) + +func (state State) String() string { + switch state { + case Standalone: + return "STANDALONE" + case FTActive: + return "FT_ACTIVE" + case FTStandby: + return "FT_STANDBY" + case Failed: + return "FAILED" + case Shutdown: + return "SHUTDOWN" + case Clustered: + return "CLUSTERED" + default: + return "UNKNOW STATE" + } +} + +type channelStore struct { + sync.RWMutex + delMu sync.Mutex + channels map[string]*channel + store stores.Store + stan *StanServer +} + +func newChannelStore(srv *StanServer, s stores.Store) *channelStore { + cs := &channelStore{ + channels: make(map[string]*channel), + store: s, + stan: srv, + } + return cs +} + +func (cs *channelStore) get(name string) *channel { + cs.RLock() + c := cs.channels[name] + cs.RUnlock() + return c +} + +func (cs *channelStore) createChannel(s *StanServer, name string) (*channel, error) { + cs.Lock() + defer cs.Unlock() + // It is possible that there were 2 concurrent calls to lookupOrCreateChannel + // which first uses `channelStore.get()` and if not found, calls this function. + // So we need to check now that we have the write lock that the channel has + // not already been created. 
+ c := cs.channels[name] + if c != nil { + return c, nil + } + sc, err := cs.store.CreateChannel(name) + if err != nil { + return nil, err + } + c, err = cs.create(s, name, sc) + if err != nil { + return nil, err + } + isStandaloneOrLeader := true + if s.isClustered { + if s.isLeader() { + if err := c.subToSnapshotRestoreRequests(); err != nil { + delete(cs.channels, name) + return nil, err + } + } else { + isStandaloneOrLeader = false + } + } + if isStandaloneOrLeader && c.activity != nil { + c.startDeleteTimer() + } + cs.stan.log.Noticef("Channel %q has been created", name) + return c, nil +} + +// low-level creation and storage in memory of a *channel +// Lock is held on entry or not needed. +func (cs *channelStore) create(s *StanServer, name string, sc *stores.Channel) (*channel, error) { + c := &channel{name: name, store: sc, ss: s.createSubStore(), stan: s} + lastSequence, err := c.store.Msgs.LastSequence() + if err != nil { + return nil, err + } + c.nextSequence = lastSequence + 1 + cs.channels[name] = c + cl := cs.store.GetChannelLimits(name) + if cl.MaxInactivity > 0 { + c.activity = &channelActivity{maxInactivity: cl.MaxInactivity} + } + return c, nil +} + +func (cs *channelStore) getAll() map[string]*channel { + cs.RLock() + m := make(map[string]*channel, len(cs.channels)) + for k, v := range cs.channels { + m[k] = v + } + cs.RUnlock() + return m +} + +func (cs *channelStore) msgsState(channelName string) (int, uint64, error) { + cs.RLock() + defer cs.RUnlock() + if channelName != "" { + c := cs.channels[channelName] + if c == nil { + return 0, 0, fmt.Errorf("channel %q not found", channelName) + } + return c.store.Msgs.State() + } + var ( + count int + bytes uint64 + ) + for _, c := range cs.channels { + m, b, err := c.store.Msgs.State() + if err != nil { + return 0, 0, err + } + count += m + bytes += b + } + return count, bytes, nil +} + +func (cs *channelStore) count() int { + cs.RLock() + count := len(cs.channels) + cs.RUnlock() + return count +} + 
+func (cs *channelStore) lockDelete() { + cs.delMu.Lock() +} + +func (cs *channelStore) unlockDelete() { + cs.delMu.Unlock() +} + +func (cs *channelStore) maybeStartChannelDeleteTimer(name string, c *channel) { + cs.delMu.Lock() + cs.RLock() + if c == nil { + c = cs.channels[name] + } + if c != nil && c.activity != nil && !c.activity.deleteInProgress && !c.ss.hasActiveSubs() { + c.startDeleteTimer() + } + cs.RUnlock() + cs.delMu.Unlock() +} + +type channel struct { + nextSequence uint64 + name string + store *stores.Channel + ss *subStore + lTimestamp int64 + snapshotSub *nats.Subscription + stan *StanServer + activity *channelActivity +} + +type channelActivity struct { + last time.Time + maxInactivity time.Duration + timer *time.Timer + deleteInProgress bool + timerSet bool +} + +// Starts the delete timer that when firing will post +// a channel delete request to the ioLoop. +// The channelStore's delMu mutex must be held on entry. +func (c *channel) startDeleteTimer() { + c.activity.last = time.Now() + c.resetDeleteTimer(c.activity.maxInactivity) +} + +// Stops the delete timer. +// The channelStore's delMu mutex must be held on entry. +func (c *channel) stopDeleteTimer() { + if c.activity.timer != nil { + c.activity.timer.Stop() + c.activity.timerSet = false + if c.stan.debug { + c.stan.log.Debugf("Channel %q delete timer stopped", c.name) + } + } +} + +// Resets the delete timer to the given duration. +// If the timer was not created, this call will create it. +// The channelStore's delMu mutex must be held on entry. 
+func (c *channel) resetDeleteTimer(newDuration time.Duration) { + a := c.activity + if a.timer == nil { + a.timer = time.AfterFunc(newDuration, func() { + c.stan.sendDeleteChannelRequest(c) + }) + } else { + a.timer.Reset(newDuration) + } + if c.stan.debug { + c.stan.log.Debugf("Channel %q delete timer set to fire in %v", c.name, newDuration) + } + a.timerSet = true +} + +// pubMsgToMsgProto converts a PubMsg to a MsgProto and assigns a timestamp +// which is monotonic with respect to the channel. +func (c *channel) pubMsgToMsgProto(pm *pb.PubMsg, seq uint64) *pb.MsgProto { + m := &pb.MsgProto{ + Sequence: seq, + Subject: pm.Subject, + Reply: pm.Reply, + Data: pm.Data, + Timestamp: time.Now().UnixNano(), + } + if c.lTimestamp > 0 && m.Timestamp < c.lTimestamp { + m.Timestamp = c.lTimestamp + } + c.lTimestamp = m.Timestamp + return m +} + +// Sets a subscription that will handle snapshot restore requests from followers. +func (c *channel) subToSnapshotRestoreRequests() error { + var ( + msgBuf []byte + buf []byte + snapshotRestoreSubj = fmt.Sprintf("%s.%s.%s", defaultSnapshotPrefix, c.stan.info.ClusterID, c.name) + ) + sub, err := c.stan.ncsr.Subscribe(snapshotRestoreSubj, func(m *nats.Msg) { + if len(m.Data) != 16 { + c.stan.log.Errorf("Invalid snapshot request, data len=%v", len(m.Data)) + return + } + start := util.ByteOrder.Uint64(m.Data[:8]) + end := util.ByteOrder.Uint64(m.Data[8:]) + + for seq := start; seq <= end; seq++ { + msg, err := c.store.Msgs.Lookup(seq) + if err != nil { + c.stan.log.Errorf("Snapshot restore request error for channel %q, error looking up message %v: %v", c.name, seq, err) + return + } + if msg == nil { + // We don't have this message because of channel limits. + // Return nil to caller to signal this state. 
+ buf = nil + } else { + msgBuf = util.EnsureBufBigEnough(msgBuf, msg.Size()) + n, err := msg.MarshalTo(msgBuf) + if err != nil { + panic(err) + } + buf = msgBuf[:n] + } + if err := c.stan.ncsr.Publish(m.Reply, buf); err != nil { + c.stan.log.Errorf("Snapshot restore request error for channel %q, unable to send response for seq %v: %v", c.name, seq, err) + } + if buf == nil { + return + } + select { + case <-c.stan.shutdownCh: + return + default: + } + } + }) + if err != nil { + return err + } + c.snapshotSub = sub + c.snapshotSub.SetPendingLimits(-1, -1) + return nil +} + +// StanServer structure represents the NATS Streaming Server +type StanServer struct { + // Keep all members for which we use atomic at the beginning of the + // struct and make sure they are all 64bits (or use padding if necessary). + // atomic.* functions crash on 32bit machines if operand is not aligned + // at 64bit. See https://github.com/golang/go/issues/599 + ioChannelStatsMaxBatchSize int64 // stats of the max number of messages than went into a single batch + + mu sync.RWMutex + shutdown bool + shutdownCh chan struct{} + serverID string + info spb.ServerInfo // Contains cluster ID and subjects + natsServer *server.Server + opts *Options + natsOpts *server.Options + startTime time.Time + + // For scalability, a dedicated connection is used to publish + // messages to subscribers and for replication. 
+ nc *nats.Conn // used for most protocol messages + ncs *nats.Conn // used for sending to subscribers and acking publishers + nca *nats.Conn // used to receive subscriptions acks + ncr *nats.Conn // used for raft messages + ncsr *nats.Conn // used for raft snapshot replication + + wg sync.WaitGroup // Wait on go routines during shutdown + + // Used when processing connect requests for client ID already registered + dupCIDTimeout time.Duration + + // Clients + clients *clientStore + cliDupCIDsMu sync.Mutex + cliDipCIDsMap map[string]struct{} + + // channels + channels *channelStore + + // Store + store stores.Store + + // Monitoring + monMu sync.RWMutex + numSubs int + + // IO Channel + ioChannel chan *ioPendingMsg + ioChannelQuit chan struct{} + ioChannelWG sync.WaitGroup + + // To protect some close related requests + closeMu sync.Mutex + + tmpBuf []byte // Used to marshal protocols (right now, only PubAck) + + subStartCh chan *subStartInfo + subStartQuit chan struct{} + + // For FT mode + ftnc *nats.Conn + ftSubject string + ftHBInterval time.Duration + ftHBMissedInterval time.Duration + ftHBCh chan *nats.Msg + ftQuit chan struct{} + + state State + // This is in cases where a fatal error occurs after the server was + // started. We call Fatalf, but for users starting the server + // programmatically, it is a way to report what the error was. + lastError error + + // Will be created only when running in partitioning mode. + partitions *partitions + + // Use these flags for Debug/Trace in places where speed matters. + // Normally, Debugf and Tracef will check an internal variable to + // figure out if the statement should be logged, however, the + // cost of calling Debugf/Tracef is still significant since there + // may be memory allocations to format the string passed to these + // calls. So in those situations, use these flags to surround the + // calls to Debugf/Tracef. 
+ trace bool + debug bool + log *logger.StanLogger + + // Specific to clustering + raft *raftNode + raftLogging bool + isClustered bool + lazyRepl *lazyReplication + + // Our internal subscriptions + connectSub *nats.Subscription + closeSub *nats.Subscription + pubSub *nats.Subscription + subSub *nats.Subscription + subCloseSub *nats.Subscription + subUnsubSub *nats.Subscription + cliPingSub *nats.Subscription + + // For sending responses to client PINGS. Used to be global but would + // cause races when running more than 1 server in a program or test. + pingResponseOKBytes []byte + pingResponseInvalidClientBytes []byte +} + +type lazyReplication struct { + sync.Mutex + subs map[*subState]struct{} +} + +func (s *StanServer) isLeader() bool { + return atomic.LoadInt64(&s.raft.leader) == 1 +} + +// subStore holds all known state for all subscriptions +type subStore struct { + sync.RWMutex + psubs []*subState // plain subscribers + qsubs map[string]*queueState // queue subscribers + durables map[string]*subState // durables lookup + acks map[string]*subState // ack inbox lookup + stan *StanServer // back link to the server +} + +// Holds all queue subsribers for a subject/group and +// tracks lastSent for the group. +type queueState struct { + sync.RWMutex + lastSent uint64 + subs []*subState + shadow *subState // For durable case, when last member leaves and group is not closed. + stalledSubCount int // number of stalled members + newOnHold bool +} + +// When doing message redelivery due to ack expiration, the function +// makeSortedPendingMsgs return an array of pendingMsg objects, +// ordered by their expiration date. +type pendingMsg struct { + seq uint64 + expire int64 +} + +// Holds Subscription state +type subState struct { + sync.RWMutex + spb.SubState // Embedded protobuf. Used for storage. 
+ subject string + qstate *queueState + ackWait time.Duration // SubState.AckWaitInSecs expressed as a time.Duration + ackTimer *time.Timer + ackSub *nats.Subscription + acksPending map[uint64]int64 // key is message sequence, value is expiration time. + store stores.SubStore // for easy access to the store interface + + savedClientID string // Used only for closed durables in Clustering mode and monitoring endpoints. + + replicate *subSentAndAck // Used in Clustering mode + + // So far, compacting these booleans into a byte flag would not save space. + // May change if we need to add more. + initialized bool // false until the subscription response has been sent to prevent data to be sent too early. + stalled bool + newOnHold bool // Prevents delivery of new msgs until old are redelivered (on restart) + hasFailedHB bool // This is set when server sends heartbeat to this subscriber's client. +} + +type subSentAndAck struct { + sent []uint64 + ack []uint64 + inFlusher bool +} + +func (sa *subSentAndAck) reset() { + sa.sent = sa.sent[:0] + sa.ack = sa.ack[:0] +} + +// Looks up, or create a new channel if it does not exist +func (s *StanServer) lookupOrCreateChannel(name string) (*channel, error) { + c := s.channels.get(name) + if c != nil { + if c.activity != nil && c.activity.deleteInProgress { + return nil, ErrChanDelInProgress + } + return c, nil + } + return s.channels.createChannel(s, name) +} + +// createSubStore creates a new instance of `subStore`. 
+func (s *StanServer) createSubStore() *subStore { + subs := &subStore{ + psubs: make([]*subState, 0, 4), + qsubs: make(map[string]*queueState), + durables: make(map[string]*subState), + acks: make(map[string]*subState), + stan: s, + } + return subs +} + +// Store adds this subscription to the server's `subStore` and also in storage +func (ss *subStore) Store(sub *subState) error { + if sub == nil { + return nil + } + // `sub` has just been created and can't be referenced anywhere else in + // the code, so we don't need locking. + + // Adds to storage. + // Use sub lock to avoid race with waitForAcks in some tests + sub.Lock() + err := sub.store.CreateSub(&sub.SubState) + sub.Unlock() + if err != nil { + ss.stan.log.Errorf("Unable to store subscription [%v:%v] on [%s]: %v", sub.ClientID, sub.Inbox, sub.subject, err) + return err + } + + ss.Lock() + ss.updateState(sub) + ss.Unlock() + + return nil +} + +// Updates the subStore state with this sub. +// The subStore is locked on entry (or does not need, as during server restart). +// However, `sub` does not need locking since it has just been created. +func (ss *subStore) updateState(sub *subState) { + // Store by type + if sub.isQueueSubscriber() { + // Queue subscriber. + qs := ss.qsubs[sub.QGroup] + if qs == nil { + qs = &queueState{ + subs: make([]*subState, 0, 4), + } + ss.qsubs[sub.QGroup] = qs + } + qs.Lock() + // The recovered shadow queue sub will have ClientID=="", + // keep a reference to it until a member re-joins the group. + if sub.ClientID == "" { + // There should be only one shadow queue subscriber, but + // we found in https://github.com/nats-io/nats-streaming-server/issues/322 + // that some datastore had 2 of those (not sure how this happened except + // maybe due to upgrades from much older releases that had bugs?). + // So don't panic and use as the shadow the one with the highest LastSent + // value. 
+ if qs.shadow == nil || sub.LastSent > qs.lastSent { + qs.shadow = sub + } + } else { + // Store by ackInbox for ack direct lookup + ss.acks[sub.AckInbox] = sub + + qs.subs = append(qs.subs, sub) + + // If the added sub has newOnHold it means that we are doing recovery and + // that this member had unacknowledged messages. Mark the queue group + // with newOnHold + if sub.newOnHold { + qs.newOnHold = true + } + // Update stalled (on recovery) + if sub.stalled { + qs.stalledSubCount++ + } + } + // Needed in the case of server restart, where + // the queue group's last sent needs to be updated + // based on the recovered subscriptions. + if sub.LastSent > qs.lastSent { + qs.lastSent = sub.LastSent + } + qs.Unlock() + sub.qstate = qs + } else { + // First store by ackInbox for ack direct lookup + ss.acks[sub.AckInbox] = sub + + // Plain subscriber. + ss.psubs = append(ss.psubs, sub) + + // Hold onto durables in special lookup. + if sub.isDurableSubscriber() { + ss.durables[sub.durableKey()] = sub + } + } +} + +// returns an array of all subscriptions (plain, online durables and queue members). +func (ss *subStore) getAllSubs() []*subState { + ss.RLock() + subs := make([]*subState, 0, len(ss.psubs)) + subs = append(subs, ss.psubs...) + for _, qs := range ss.qsubs { + qs.RLock() + subs = append(subs, qs.subs...) + qs.RUnlock() + } + ss.RUnlock() + return subs +} + +// hasSubs returns true if there is any active subscription for this subStore. +// That is, offline durable subscriptions are ignored. +func (ss *subStore) hasActiveSubs() bool { + ss.RLock() + defer ss.RUnlock() + if len(ss.psubs) > 0 { + return true + } + for _, qsub := range ss.qsubs { + // For a durable queue group, when the group is offline, + // qsub.shadow is not nil, but the qsub.subs array should be + // empty. + if len(qsub.subs) > 0 { + return true + } + } + return false +} + +// Remove a subscriber from the subscription store, leaving durable +// subscriptions unless `unsubscribe` is true. 
func (ss *subStore) Remove(c *channel, sub *subState, unsubscribe bool) {
	if sub == nil {
		return
	}

	var (
		log               logger.Logger
		queueGroupIsEmpty bool
	)

	ss.Lock()
	if ss.stan.debug {
		log = ss.stan.log
	}

	sub.Lock()
	subject := sub.subject
	clientID := sub.ClientID
	durableKey := ""
	// Do this before clearing the sub.ClientID since this is part of the key!!!
	if sub.isDurableSubscriber() {
		durableKey = sub.durableKey()
	}
	// This is needed when doing a snapshot in clustering mode or for monitoring endpoints
	sub.savedClientID = sub.ClientID
	// Clear the subscriptions clientID
	sub.ClientID = ""
	ackInbox := sub.AckInbox
	qs := sub.qstate
	isDurable := sub.IsDurable
	subid := sub.ID
	store := sub.store
	sub.stopAckSub()
	sub.Unlock()

	// Helper to report (but not propagate) storage deletion errors.
	reportError := func(err error) {
		ss.stan.log.Errorf("Error deleting subscription subid=%d, subject=%s, err=%v", subid, subject, err)
	}

	// Delete from storage non durable subscribers on either connection
	// close or call to Unsubscribe(), and durable subscribers only on
	// Unsubscribe(). Leave durable queue subs for now, they need to
	// be treated differently.
	if !isDurable || (unsubscribe && durableKey != "") {
		if err := store.DeleteSub(subid); err != nil {
			reportError(err)
		}
	}

	// Delete from ackInbox lookup.
	delete(ss.acks, ackInbox)

	// Delete from durable if needed
	if unsubscribe && durableKey != "" {
		delete(ss.durables, durableKey)
	}

	// Queue subscribers that received transferred pending messages and may
	// need an ack-expiration/redelivery check after the store lock is released.
	var qsubs map[uint64]*subState

	// Delete ourselves from the list
	if qs != nil {
		storageUpdate := false
		// For queue state, we need to lock specifically,
		// because qs.subs can be modified by findBestQueueSub,
		// for which we don't have substore lock held.
		qs.Lock()

		sub.Lock()
		sub.clearAckTimer()
		qgroup := sub.QGroup
		sub.Unlock()

		qs.subs, _ = sub.deleteFromList(qs.subs)
		if len(qs.subs) == 0 {
			queueGroupIsEmpty = true
			// If it was the last being removed, also remove the
			// queue group from the subStore map, but only if
			// non durable or explicit unsubscribe.
			if !isDurable || unsubscribe {
				delete(ss.qsubs, qgroup)
				// Delete from storage too.
				if err := store.DeleteSub(subid); err != nil {
					reportError(err)
				}
			} else {
				// Group is durable and last member just left the group,
				// but didn't call Unsubscribe(). Need to keep a reference
				// to this sub to maintain the state.
				qs.shadow = sub
				// Clear the number of stalled members
				qs.stalledSubCount = 0
				// Will need to update the LastSent and clear the ClientID
				// with a storage update.
				storageUpdate = true
			}
		} else {
			if sub.stalled && qs.stalledSubCount > 0 {
				qs.stalledSubCount--
			}
			now := time.Now().UnixNano()
			// If there are pending messages in this sub, they need to be
			// transferred to remaining queue subscribers.
			numQSubs := len(qs.subs)
			idx := 0
			sub.RLock()
			// Need to update if this member was the one with the last
			// message of the group.
			storageUpdate = sub.LastSent == qs.lastSent
			sortedPendingMsgs := makeSortedPendingMsgs(sub.acksPending)
			for _, pm := range sortedPendingMsgs {
				// Get one of the remaining queue subscribers.
				qsub := qs.subs[idx]
				qsub.Lock()
				// Store in storage
				if err := qsub.store.AddSeqPending(qsub.ID, pm.seq); err != nil {
					ss.stan.log.Errorf("[Client:%s] Unable to transfer message to subid=%d, subject=%s, seq=%d, err=%v",
						clientID, subid, subject, pm.seq, err)
					qsub.Unlock()
					continue
				}
				// We don't need to update if the sub's lastSent is transferred
				// to another queue subscriber.
				if storageUpdate && pm.seq == qs.lastSent {
					storageUpdate = false
				}
				// Update LastSent if applicable
				if pm.seq > qsub.LastSent {
					qsub.LastSent = pm.seq
				}
				// As of now, members can have different AckWait values.
				expirationTime := pm.expire
				// If the member the message is transferred to has a higher AckWait,
				// keep original expiration time, otherwise check that it is smaller
				// than the new AckWait.
				if sub.ackWait > qsub.ackWait && expirationTime-now > 0 {
					expirationTime = now + int64(qsub.ackWait)
				}
				// Store in ackPending.
				qsub.acksPending[pm.seq] = expirationTime
				// Keep track of this qsub
				if qsubs == nil {
					qsubs = make(map[uint64]*subState)
				}
				if _, tracked := qsubs[qsub.ID]; !tracked {
					qsubs[qsub.ID] = qsub
				}
				qsub.Unlock()
				// Move to the next queue subscriber, going back to first if needed.
				idx++
				if idx == numQSubs {
					idx = 0
				}
			}
			sub.RUnlock()
			// Even for durable queue subscribers, if this is not the last
			// member, we need to delete from storage (we did that higher in
			// that function for non durable case). Issue #215.
			if isDurable {
				if err := store.DeleteSub(subid); err != nil {
					reportError(err)
				}
			}
		}
		if storageUpdate {
			// If we have a shadow sub, use that one, otherwise any queue subscriber
			// will do, so use the first.
			qsub := qs.shadow
			if qsub == nil {
				qsub = qs.subs[0]
			}
			qsub.Lock()
			qsub.LastSent = qs.lastSent
			qsub.store.UpdateSub(&qsub.SubState)
			qsub.Unlock()
		}
		qs.Unlock()
	} else {

		sub.Lock()
		sub.clearAckTimer()
		sub.Unlock()

		ss.psubs, _ = sub.deleteFromList(ss.psubs)
		// When closing a durable subscription (calling sub.Close(), not sub.Unsubscribe()),
		// we need to update the record on store to prevent the server from adding
		// this durable to the list of active subscriptions. This is especially important
		// if the client closing this durable is itself not closed when the server is
		// restarted. The server would have no way to detect if the durable subscription
		// is offline or not.
		if isDurable && !unsubscribe {
			sub.Lock()
			// ClientID is required on store because this is used on recovery to
			// "compute" the durable key (clientID+subject+durable name).
			sub.ClientID = clientID
			sub.IsClosed = true
			store.UpdateSub(&sub.SubState)
			// After storage, clear the ClientID.
			sub.ClientID = ""
			sub.Unlock()
		}
	}
	ss.Unlock()

	if !ss.stan.isClustered || ss.stan.isLeader() {
		// Calling this will sort current pending messages and ensure
		// that the ackTimer is properly set. It does not necessarily
		// mean that messages are going to be redelivered on the spot.
		for _, qsub := range qsubs {
			ss.stan.performAckExpirationRedelivery(qsub, false)
		}
	}

	if log != nil {
		traceCtx := subStateTraceCtx{clientID: clientID, isRemove: true, isUnsubscribe: unsubscribe, isGroupEmpty: queueGroupIsEmpty}
		traceSubState(log, sub, &traceCtx)
	}
}

// Lookup by durable name.
func (ss *subStore) LookupByDurable(durableName string) *subState {
	ss.RLock()
	sub := ss.durables[durableName]
	ss.RUnlock()
	return sub
}

// Lookup by ackInbox name.
func (ss *subStore) LookupByAckInbox(ackInbox string) *subState {
	ss.RLock()
	sub := ss.acks[ackInbox]
	ss.RUnlock()
	return sub
}

// Options for NATS Streaming Server
type Options struct {
	ID                 string
	DiscoverPrefix     string
	StoreType          string
	FilestoreDir       string
	FileStoreOpts      stores.FileStoreOptions
	SQLStoreOpts       stores.SQLStoreOptions
	stores.StoreLimits               // Store limits (MaxChannels, etc..)
	EnableLogging      bool          // Enables logging
	CustomLogger       logger.Logger // Server will start with the provided logger
	Trace              bool          // Verbose trace
	Debug              bool          // Debug trace
	HandleSignals      bool          // Should the server setup a signal handler (for Ctrl+C, etc...)
	Secure             bool          // Create a TLS enabled connection w/o server verification
	ClientCert         string        // Client Certificate for TLS
	ClientKey          string        // Client Key for TLS
	ClientCA           string        // Client CAs for TLS
	IOBatchSize        int           // Maximum number of messages collected from clients before starting their processing.
	IOSleepTime        int64         // Duration (in micro-seconds) the server waits for more message to fill up a batch.
	NATSServerURL      string        // URL for external NATS Server to connect to. If empty, NATS Server is embedded.
	ClientHBInterval   time.Duration // Interval at which server sends heartbeat to a client.
	ClientHBTimeout    time.Duration // How long server waits for a heartbeat response.
	ClientHBFailCount  int           // Number of failed heartbeats before server closes client connection.
	FTGroupName        string        // Name of the FT Group. A group can be 2 or more servers with a single active server and all sharing the same datastore.
	Partitioning       bool          // Specify if server only accepts messages/subscriptions on channels defined in StoreLimits.
	SyslogName         string        // Optional name for the syslog (useful on Windows when running several servers as a service)
	Clustering         ClusteringOptions
}

// Clone returns a deep copy of the Options object.
func (o *Options) Clone() *Options {
	// A simple copy covers pretty much everything
	clone := *o
	// But we have the problem of the PerChannel map that needs
	// to be copied.
	clone.PerChannel = (&o.StoreLimits).ClonePerChannelMap()
	// Make a copy of the clustering peers
	if len(o.Clustering.Peers) > 0 {
		clone.Clustering.Peers = make([]string, 0, len(o.Clustering.Peers))
		clone.Clustering.Peers = append(clone.Clustering.Peers, o.Clustering.Peers...)
	}
	return &clone
}

// DefaultOptions are default options for the NATS Streaming Server
var defaultOptions = Options{
	ID:                DefaultClusterID,
	DiscoverPrefix:    DefaultDiscoverPrefix,
	StoreType:         DefaultStoreType,
	FileStoreOpts:     stores.DefaultFileStoreOptions,
	IOBatchSize:       DefaultIOBatchSize,
	IOSleepTime:       DefaultIOSleepTime,
	NATSServerURL:     "",
	ClientHBInterval:  DefaultHeartBeatInterval,
	ClientHBTimeout:   DefaultClientHBTimeout,
	ClientHBFailCount: DefaultMaxFailedHeartBeats,
}

// GetDefaultOptions returns default options for the NATS Streaming Server
func GetDefaultOptions() (o *Options) {
	opts := defaultOptions
	opts.StoreLimits = stores.DefaultStoreLimits
	return &opts
}

// DefaultNatsServerOptions are default options for the NATS server
var DefaultNatsServerOptions = server.Options{
	Host:   "localhost",
	Port:   4222,
	NoLog:  true,
	NoSigs: true,
}

// stanDisconnectedHandler logs unexpected NATS connection disconnects.
func (s *StanServer) stanDisconnectedHandler(nc *nats.Conn) {
	if nc.LastError() != nil {
		s.log.Errorf("connection %q has been disconnected: %v",
			nc.Opts.Name, nc.LastError())
	}
}

// stanReconnectedHandler logs NATS connection reconnects.
func (s *StanServer) stanReconnectedHandler(nc *nats.Conn) {
	s.log.Noticef("connection %q reconnected to NATS Server at %q",
		nc.Opts.Name, nc.ConnectedUrl())
}

// stanClosedHandler logs NATS connection closures.
func (s *StanServer) stanClosedHandler(nc *nats.Conn) {
	s.log.Debugf("connection %q has been closed", nc.Opts.Name)
}

// stanErrorHandler logs asynchronous errors reported by the NATS connection.
func (s *StanServer) stanErrorHandler(nc *nats.Conn, sub *nats.Subscription, err error) {
	s.log.Errorf("Asynchronous error on connection %s, subject %s: %s",
		nc.Opts.Name, sub.Subject, err)
}

// buildServerURLs computes the list of NATS urls the streaming server's
// internal client connections should use, either from the configured
// external NATS url(s) or from the embedded server's host/port options.
func (s *StanServer) buildServerURLs() ([]string, error) {
	var hostport string
	natsURL := s.opts.NATSServerURL
	opts := s.natsOpts
	// If the URL to an external NATS is provided...
	if natsURL != "" {
		// If it has user/pwd info or is a list of urls...
		if strings.Contains(natsURL, "@") || strings.Contains(natsURL, ",") {
			// Return the array
			// NOTE(review): the loop variable s shadows the method receiver here.
			urls := strings.Split(natsURL, ",")
			for i, s := range urls {
				urls[i] = strings.Trim(s, " ")
			}
			return urls, nil
		}
		// Otherwise, prepare the host and port and continue to see
		// if user/pass needs to be added.

		// First trim the protocol.
		parts := strings.Split(natsURL, "://")
		if len(parts) != 2 {
			return nil, fmt.Errorf("malformed url: %v", natsURL)
		}
		natsURL = parts[1]
		host, port, err := net.SplitHostPort(natsURL)
		if err != nil {
			return nil, err
		}
		// Use net.Join to support IPV6 addresses.
		hostport = net.JoinHostPort(host, port)
	} else {
		// We embed the server, so it is local. If host is "any",
		// use 127.0.0.1 or ::1 for host address (important for
		// Windows since connect with 0.0.0.0 or :: fails).
		sport := strconv.Itoa(opts.Port)
		if opts.Host == "0.0.0.0" {
			hostport = net.JoinHostPort("127.0.0.1", sport)
		} else if opts.Host == "::" || opts.Host == "[::]" {
			hostport = net.JoinHostPort("::1", sport)
		} else {
			hostport = net.JoinHostPort(opts.Host, sport)
		}
	}
	var userpart string
	if opts.Authorization != "" {
		userpart = opts.Authorization
	} else if opts.Username != "" {
		userpart = fmt.Sprintf("%s:%s", opts.Username, opts.Password)
	}
	if userpart != "" {
		return []string{fmt.Sprintf("nats://%s@%s", userpart, hostport)}, nil
	}
	return []string{fmt.Sprintf("nats://%s", hostport)}, nil
}

// createNatsClientConn creates a connection to the NATS server, using
// TLS if configured. Pass in the NATS server options to derive a
// connection url, and for other future items (e.g. auth)
func (s *StanServer) createNatsClientConn(name string) (*nats.Conn, error) {
	var err error
	ncOpts := nats.DefaultOptions

	ncOpts.Servers, err = s.buildServerURLs()
	if err != nil {
		return nil, err
	}
	ncOpts.Name = fmt.Sprintf("_NSS-%s-%s", s.opts.ID, name)

	if err = nats.ErrorHandler(s.stanErrorHandler)(&ncOpts); err != nil {
		return nil, err
	}
	if err = nats.ReconnectHandler(s.stanReconnectedHandler)(&ncOpts); err != nil {
		return nil, err
	}
	if err = nats.ClosedHandler(s.stanClosedHandler)(&ncOpts); err != nil {
		return nil, err
	}
	if err = nats.DisconnectHandler(s.stanDisconnectedHandler)(&ncOpts); err != nil {
		return nil, err
	}
	if s.opts.Secure {
		if err = nats.Secure()(&ncOpts); err != nil {
			return nil, err
		}
	}
	if s.opts.ClientCA != "" {
		if err = nats.RootCAs(s.opts.ClientCA)(&ncOpts); err != nil {
			return nil, err
		}
	}
	if s.opts.ClientCert != "" {
		if err = nats.ClientCert(s.opts.ClientCert, s.opts.ClientKey)(&ncOpts); err != nil {
			return nil, err
		}
	}
	// Shorten the time we wait to try to reconnect.
	// Don't make it too often because it may exhaust the number of FDs.
	ncOpts.ReconnectWait = 250 * time.Millisecond
	// Make it try to reconnect for ever.
	ncOpts.MaxReconnect = -1
	// To avoid possible duplicate redeliveries, etc.., set the reconnect
	// buffer to -1 to avoid any buffering in the nats library and flush
	// on reconnect.
	ncOpts.ReconnectBufSize = -1

	s.log.Tracef(" NATS conn opts: %v", ncOpts)

	var nc *nats.Conn
	if nc, err = ncOpts.Connect(); err != nil {
		return nil, err
	}
	return nc, err
}

// createNatsConnections creates the set of internal NATS connections used by
// the server (send, general, acks, and optionally FT and raft connections),
// stopping at the first error.
func (s *StanServer) createNatsConnections() error {
	var err error
	s.ncs, err = s.createNatsClientConn("send")
	if err == nil {
		s.nc, err = s.createNatsClientConn("general")
	}
	if err == nil {
		s.nca, err = s.createNatsClientConn("acks")
	}
	if err == nil && s.opts.FTGroupName != "" {
		s.ftnc, err = s.createNatsClientConn("ft")
	}
	if err == nil && s.isClustered {
		s.ncr, err = s.createNatsClientConn("raft")
		if err == nil {
			s.ncsr, err = s.createNatsClientConn("raft_snap")
		}
	}
	return err
}

// NewNATSOptions returns a new instance of (NATS) Options.
// This is needed if one wants to configure specific NATS options
// before starting a NATS Streaming Server (with RunServerWithOpts()).
func NewNATSOptions() *server.Options {
	opts := server.Options{}
	return &opts
}

// RunServer will startup an embedded NATS Streaming Server and a nats-server to support it.
func RunServer(ID string) (*StanServer, error) {
	sOpts := GetDefaultOptions()
	sOpts.ID = ID
	nOpts := DefaultNatsServerOptions
	return RunServerWithOpts(sOpts, &nOpts)
}

// RunServerWithOpts allows you to run a NATS Streaming Server with full control
// on the Streaming and NATS Server configuration.
func RunServerWithOpts(stanOpts *Options, natsOpts *server.Options) (newServer *StanServer, returnedError error) {
	var sOpts *Options
	var nOpts *server.Options
	// Make a copy of the options so we own them.
	if stanOpts == nil {
		sOpts = GetDefaultOptions()
	} else {
		sOpts = stanOpts.Clone()
	}
	if natsOpts == nil {
		no := DefaultNatsServerOptions
		nOpts = &no
	} else {
		nOpts = natsOpts.Clone()
	}
	// For now, no support for partitioning and clustering at the same time
	if sOpts.Partitioning && sOpts.Clustering.Clustered {
		return nil, fmt.Errorf("stan: channels partitioning in clustering mode is not supported")
	}

	if sOpts.Clustering.Clustered {
		if sOpts.StoreType == stores.TypeMemory {
			return nil, fmt.Errorf("stan: clustering mode not supported with %s store type", stores.TypeMemory)
		}
		// Override store sync configuration with cluster sync.
		sOpts.FileStoreOpts.DoSync = sOpts.Clustering.Sync

		// Remove cluster's node ID (if present) from the list of peers.
		if len(sOpts.Clustering.Peers) > 0 && sOpts.Clustering.NodeID != "" {
			nodeID := sOpts.Clustering.NodeID
			peers := make([]string, 0, len(sOpts.Clustering.Peers))
			for _, p := range sOpts.Clustering.Peers {
				if p != nodeID {
					peers = append(peers, p)
				}
			}
			if len(peers) != len(sOpts.Clustering.Peers) {
				sOpts.Clustering.Peers = peers
			}
		}
	}

	s := StanServer{
		serverID:      nuid.Next(),
		opts:          sOpts,
		natsOpts:      nOpts,
		dupCIDTimeout: defaultCheckDupCIDTimeout,
		ioChannelQuit: make(chan struct{}),
		trace:         sOpts.Trace,
		debug:         sOpts.Debug,
		subStartCh:    make(chan *subStartInfo, defaultSubStartChanLen),
		subStartQuit:  make(chan struct{}, 1),
		startTime:     time.Now(),
		log:           logger.NewStanLogger(),
		shutdownCh:    make(chan struct{}),
		isClustered:   sOpts.Clustering.Clustered,
		raftLogging:   sOpts.Clustering.RaftLogging,
		cliDipCIDsMap: make(map[string]struct{}),
	}

	// If a custom logger is provided, use this one, otherwise, check
	// if we should configure the logger or not.
	if sOpts.CustomLogger != nil {
		s.log.SetLogger(sOpts.CustomLogger, nOpts.Logtime, sOpts.Debug, sOpts.Trace, "")
	} else if sOpts.EnableLogging {
		s.configureLogger()
	}

	s.log.Noticef("Starting nats-streaming-server[%s] version %s", sOpts.ID, VERSION)

	// ServerID is used to check that a broadcast protocol is not ours,
	// for instance with FT. Some err/warn messages may be printed
	// regarding other instance's ID, so print it on startup.
	s.log.Noticef("ServerID: %v", s.serverID)
	s.log.Noticef("Go version: %v", runtime.Version())

	// Ensure that we shutdown the server if there is a panic/error during startup.
	// This will ensure that stores are closed (which otherwise would cause
	// issues during testing) and that the NATS Server (if started) is also
	// properly shutdown. To do so, we recover from the panic in order to
	// call Shutdown, then issue the original panic.
	defer func() {
		// We used to issue panic for common errors but now return error
		// instead. Still we want to log the reason for the panic.
		if r := recover(); r != nil {
			s.Shutdown()
			s.log.Noticef("Failed to start: %v", r)
			panic(r)
		} else if returnedError != nil {
			s.Shutdown()
			// Log it as a fatal error, process will exit (if
			// running from executable or logger is configured).
			s.log.Fatalf("Failed to start: %v", returnedError)
		}
	}()

	storeLimits := &s.opts.StoreLimits

	var (
		err   error
		store stores.Store
	)

	// Ensure store type option is in upper-case
	sOpts.StoreType = strings.ToUpper(sOpts.StoreType)

	// Create the store.
	switch sOpts.StoreType {
	case stores.TypeFile:
		store, err = stores.NewFileStore(s.log, sOpts.FilestoreDir, storeLimits,
			stores.AllOptions(&sOpts.FileStoreOpts))
	case stores.TypeSQL:
		store, err = stores.NewSQLStore(s.log, sOpts.SQLStoreOpts.Driver, sOpts.SQLStoreOpts.Source,
			storeLimits, stores.SQLAllOptions(&sOpts.SQLStoreOpts))
	case stores.TypeMemory:
		store, err = stores.NewMemoryStore(s.log, storeLimits)
	default:
		err = fmt.Errorf("unsupported store type: %v", sOpts.StoreType)
	}
	if err != nil {
		return nil, err
	}
	// StanServer.store (s.store here) is of type stores.Store, which is an
	// interface. If we assign s.store in the call of the constructor and there
	// is an error, although the call returns "nil" for the store, we can no
	// longer have a test such as "if s.store != nil" (as we do in shutdown).
	// This is because the constructors return a store implementation.
	// We would need to use reflection such as reflect.ValueOf(s.store).IsNil().
	// So to not do that, we simply delay the setting of s.store when we know
	// that it was successful.
	if s.isClustered {
		// Wrap our store with a RaftStore instance that avoids persisting
		// data that we don't need because they are handled by the actual
		// raft logs.
		store = stores.NewRaftStore(store)
	}
	s.store = store

	// Start the IO Loop before creating the channel store since the
	// go routine watching for channel inactivity may schedule events
	// to the IO loop.
	s.startIOLoop()

	s.clients = newClientStore(s.store)
	s.channels = newChannelStore(&s, s.store)

	// If no NATS server url is provided, it means that we embed the NATS Server
	if sOpts.NATSServerURL == "" {
		if err := s.startNATSServer(); err != nil {
			return nil, err
		}
	}
	// Check for monitoring
	if nOpts.HTTPPort != 0 || nOpts.HTTPSPort != 0 {
		if err := s.startMonitoring(nOpts); err != nil {
			return nil, err
		}
	}
	// Create our connections
	if err := s.createNatsConnections(); err != nil {
		return nil, err
	}

	// In FT mode, server cannot recover the store until it is elected leader.
	if s.opts.FTGroupName != "" {
		if err := s.ftSetup(); err != nil {
			return nil, err
		}
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			if err := s.ftStart(); err != nil {
				s.setLastError(err)
			}
		}()
	} else {
		state := Standalone
		if s.isClustered {
			state = Clustered
		}
		if err := s.start(state); err != nil {
			return nil, err
		}
	}
	if s.opts.HandleSignals {
		s.handleSignals()
	}
	return &s, nil
}

// Logging in STAN
//
// The STAN logger is an instance of a NATS logger, (basically duplicated
// from the NATS server code), and is passed into the NATS server.
//
// A note on Debugf and Tracef: These will be enabled within the log if
// either STAN or the NATS server enables them. However, STAN will only
// trace/debug if the local STAN debug/trace flags are set. NATS will do
// the same with its logger flags. This enables us to use the same logger,
// but differentiate between STAN and NATS debug/trace.
func (s *StanServer) configureLogger() {
	var newLogger logger.Logger

	sOpts := s.opts
	nOpts := s.natsOpts

	enableDebug := nOpts.Debug || sOpts.Debug
	enableTrace := nOpts.Trace || sOpts.Trace

	syslog := nOpts.Syslog
	// Enable syslog if no log file is specified and we're running as a
	// Windows service so that logs are written to the Windows event log.
	if isWindowsService() && nOpts.LogFile == "" {
		syslog = true
	}
	// If we have a syslog name specified, make sure we will use this name.
	// This is for syslog and remote syslogs running on Windows.
	if sOpts.SyslogName != "" {
		natsdLogger.SetSyslogName(sOpts.SyslogName)
	}

	if nOpts.LogFile != "" {
		newLogger = natsdLogger.NewFileLogger(nOpts.LogFile, nOpts.Logtime, enableDebug, enableTrace, true)
	} else if nOpts.RemoteSyslog != "" {
		newLogger = natsdLogger.NewRemoteSysLogger(nOpts.RemoteSyslog, enableDebug, enableTrace)
	} else if syslog {
		newLogger = natsdLogger.NewSysLogger(enableDebug, enableTrace)
	} else {
		colors := true
		// Check to see if stderr is being redirected and if so turn off color
		// Also turn off colors if we're running on Windows where os.Stderr.Stat() returns an invalid handle-error
		stat, err := os.Stderr.Stat()
		if err != nil || (stat.Mode()&os.ModeCharDevice) == 0 {
			colors = false
		}
		newLogger = natsdLogger.NewStdLogger(nOpts.Logtime, enableDebug, enableTrace, colors, true)
	}

	s.log.SetLogger(newLogger, nOpts.Logtime, sOpts.Debug, sOpts.Trace, nOpts.LogFile)
}

// This is either running inside RunServerWithOpts() and before any reference
// to the server is returned, so locking is not really an issue, or it is
// running from a go-routine when the server has been elected the FT active.
// Therefore, this function grabs the server lock for the duration of this
// call and so care must be taken to not invoke - directly or indirectly -
// code that would attempt to grab the server lock.
func (s *StanServer) start(runningState State) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.shutdown {
		return nil
	}

	// If using partitioning, send our list and start go routines handling
	// channels list requests.
	if s.opts.Partitioning {
		if err := s.initPartitions(); err != nil {
			return err
		}
	}

	s.state = runningState

	var (
		err            error
		recoveredState *stores.RecoveredState
		recoveredSubs  []*subState
		callStoreInit  bool
	)

	// Recover the state.
	s.log.Noticef("Recovering the state...")
	recoveredState, err = s.store.Recover()
	if err != nil {
		return err
	}
	if recoveredState != nil {
		s.log.Noticef("Recovered %v channel(s)", len(recoveredState.Channels))
	} else {
		s.log.Noticef("No recovered state")
	}
	subjID := s.opts.ID
	// In FT or with static channels (aka partitioning), we use the cluster ID
	// as part of the subjects prefix, not a NUID.
	if runningState == Standalone && s.partitions == nil {
		subjID = nuid.Next()
	}
	if recoveredState != nil {
		// Copy content
		s.info = *recoveredState.Info
		// Check cluster IDs match
		if s.opts.ID != s.info.ClusterID {
			return fmt.Errorf("cluster ID %q does not match recovered value of %q",
				s.opts.ID, s.info.ClusterID)
		}
		// Check to see if SubClose subject is present or not.
		// If not, it means we recovered from an older server, so
		// need to update.
		if s.info.SubClose == "" {
			s.info.SubClose = fmt.Sprintf("%s.%s", DefaultSubClosePrefix, subjID)
			// Update the store with the server info
			callStoreInit = true
		}

		// If clustering was enabled but we are recovering a server that was
		// previously not clustered, return an error. This is not allowed
		// because there is preexisting state that is not represented in the
		// Raft log.
		if s.isClustered && s.info.NodeID == "" {
			return ErrClusteredRestart
		}
		// Use recovered clustering node ID.
		s.opts.Clustering.NodeID = s.info.NodeID

		// Restore clients state
		s.processRecoveredClients(recoveredState.Clients)

		// Default Raft log path to ./<cluster ID>/<node ID> if not set. This
		// must be done here before recovering channels since that will
		// initialize Raft groups if clustered.
		if s.opts.Clustering.RaftLogPath == "" {
			s.opts.Clustering.RaftLogPath = filepath.Join(s.opts.ID, s.opts.Clustering.NodeID)
		}

		// Process recovered channels (if any).
		recoveredSubs, err = s.processRecoveredChannels(recoveredState.Channels)
		if err != nil {
			return err
		}
	} else {
		s.info.ClusterID = s.opts.ID

		// Generate Subjects
		s.info.Discovery = fmt.Sprintf("%s.%s", s.opts.DiscoverPrefix, s.info.ClusterID)
		s.info.Publish = fmt.Sprintf("%s.%s", DefaultPubPrefix, subjID)
		s.info.Subscribe = fmt.Sprintf("%s.%s", DefaultSubPrefix, subjID)
		s.info.SubClose = fmt.Sprintf("%s.%s", DefaultSubClosePrefix, subjID)
		s.info.Unsubscribe = fmt.Sprintf("%s.%s", DefaultUnSubPrefix, subjID)
		s.info.Close = fmt.Sprintf("%s.%s", DefaultClosePrefix, subjID)
		s.info.AcksSubs = fmt.Sprintf("%s.%s", defaultAcksPrefix, subjID)

		if s.opts.Clustering.Clustered {
			// If clustered, assign a random cluster node ID if not provided.
			if s.opts.Clustering.NodeID == "" {
				s.opts.Clustering.NodeID = nuid.Next()
			}
			s.info.NodeID = s.opts.Clustering.NodeID
		}

		callStoreInit = true
	}
	if callStoreInit {
		// Initialize the store with the server info
		if err := s.store.Init(&s.info); err != nil {
			return fmt.Errorf("unable to initialize the store: %v", err)
		}
	}

	// We don't do the check if we are running FT and/or if
	// static channels (partitioning) is in play.
	if runningState == Standalone && s.partitions == nil {
		if err := s.ensureRunningStandAlone(); err != nil {
			return err
		}
	}

	// If clustered, start Raft group.
	if s.isClustered {
		s.lazyRepl = &lazyReplication{subs: make(map[*subState]struct{})}
		s.wg.Add(1)
		go s.lazyReplicationOfSentAndAck()
		// Default Raft log path to ./<cluster ID>/<node ID> if not set.
		if s.opts.Clustering.RaftLogPath == "" {
			s.opts.Clustering.RaftLogPath = filepath.Join(s.opts.ID, s.opts.Clustering.NodeID)
		}
		s.log.Noticef("Cluster Node ID : %s", s.info.NodeID)
		s.log.Noticef("Cluster Log Path: %s", s.opts.Clustering.RaftLogPath)
		if err := s.startRaftNode(recoveredState != nil); err != nil {
			return err
		}
	}

	// Start the go-routine responsible to start sending messages to newly
	// started subscriptions. We do that before opening the gates in
	// s.initSubscriptions() (which is where the internal subscriptions
	// are created).
	s.wg.Add(1)
	go s.processSubscriptionsStart()

	if err := s.initSubscriptions(); err != nil {
		return err
	}

	if recoveredState != nil {
		// Do some post recovery processing (setup some timers, etc...)
		s.postRecoveryProcessing(recoveredState.Clients, recoveredSubs)
	}

	// Flush to make sure all subscriptions are processed before
	// we return control to the user.
	if err := s.nc.Flush(); err != nil {
		return fmt.Errorf("could not flush the subscriptions, %v", err)
	}

	s.log.Noticef("Message store is %s", s.store.Name())
	if s.opts.FilestoreDir != "" {
		s.log.Noticef("Store location: %v", s.opts.FilestoreDir)
	}
	// The store has a copy of the limits and the inheritance
	// was not applied to our limits. To have them displayed correctly,
	// call Build() on them (we know that this is not going to fail,
	// otherwise we would not have been able to create the store).
	s.opts.StoreLimits.Build()
	storeLimitsLines := (&s.opts.StoreLimits).Print()
	for _, l := range storeLimitsLines {
		s.log.Noticef(l)
	}

	// Execute (in a go routine) redelivery of unacknowledged messages,
	// and release newOnHold. We only do this if not clustered. If
	// clustered, the leader will handle redelivery upon election.
	if !s.isClustered {
		s.wg.Add(1)
		go func() {
			s.performRedeliveryOnStartup(recoveredSubs)
			s.wg.Done()
		}()
	}
	return nil
}

// startRaftNode creates and starts the Raft group.
// This should only be called if the server is running in clustered mode.
func (s *StanServer) startRaftNode(hasStreamingState bool) error {
	if err := s.createServerRaftNode(hasStreamingState); err != nil {
		return err
	}
	node := s.raft

	// leaderWait is signaled once the initial leadership question is settled,
	// so that this function does not return before the node knows its role.
	leaderWait := make(chan struct{}, 1)
	leaderReady := func() {
		select {
		case leaderWait <- struct{}{}:
		default:
		}
	}
	if node.State() != raft.Leader {
		leaderReady()
	}

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		for {
			select {
			case isLeader := <-node.notifyCh:
				if isLeader {
					err := s.leadershipAcquired()
					leaderReady()
					if err != nil {
						s.log.Errorf("Error on leadership acquired: %v", err)
						switch {
						case err == raft.ErrRaftShutdown:
							// Node shutdown, just return.
							return
						case err == raft.ErrLeadershipLost:
							// Node lost leadership, continue loop.
							continue
						default:
							// TODO: probably step down as leader?
							panic(err)
						}
					}
				} else {
					s.leadershipLost()
				}
			case <-s.shutdownCh:
				// Signal channel here to handle edge case where we might
				// otherwise block forever on the channel when shutdown.
				leaderReady()
				return
			}
		}
	}()

	<-leaderWait
	return nil
}

// sendSynchronziationRequest sends a special synchronization request to the
// ioLoop and returns two channels: the first is signaled when the ioLoop
// reaches the request, the second is to be closed by the caller to release
// the ioLoop.
func (s *StanServer) sendSynchronziationRequest() (chan struct{}, chan struct{}) {
	sc := make(chan struct{}, 1)
	sdc := make(chan struct{})
	iopm := &ioPendingMsg{sc: sc, sdc: sdc}
	s.ioChannel <- iopm
	return sc, sdc
}

// leadershipAcquired should be called when this node is elected leader.
// This should only be called when the server is running in clustered mode.
func (s *StanServer) leadershipAcquired() error {
	s.log.Noticef("server became leader, performing leader promotion actions")
	defer s.log.Noticef("finished leader promotion actions")

	// If we were not the leader, there should be nothing in the ioChannel
	// (processing of client publishes). However, since a node could go
	// from leader to follower to leader again, let's make sure that we
	// synchronize with the ioLoop before we touch the channels' nextSequence.
	sc, sdc := s.sendSynchronziationRequest()

	// Wait for the ioLoop to reach that special iopm and notify us (or
	// give up if server is shutting down).
	select {
	case <-sc:
	case <-s.ioChannelQuit:
		close(sdc)
		return nil
	}
	// Then, we will notify it back to unlock it when we are done here.
	defer close(sdc)

	// Use a barrier to ensure all preceding operations are applied to the FSM
	if err := s.raft.Barrier(0).Error(); err != nil {
		return err
	}

	channels := s.channels.getAll()
	for _, c := range channels {
		// Update next sequence to assign.
		lastSequence, err := c.store.Msgs.LastSequence()
		if err != nil {
			return err
		}
		c.nextSequence = lastSequence + 1
	}

	// Setup client heartbeats and subscribe to acks for each sub.
	for _, client := range s.clients.getClients() {
		client.RLock()
		cID := client.info.ID
		for _, sub := range client.subs {
			if err := sub.startAckSub(s.nca, s.processAckMsg); err != nil {
				client.RUnlock()
				return err
			}
		}
		client.RUnlock()
		s.clients.setClientHB(cID, s.opts.ClientHBInterval, func() {
			s.checkClientHealth(cID)
		})
	}

	// Start the internal subscriptions so we receive protocols from clients.
	if err := s.initInternalSubs(true); err != nil {
		return err
	}

	var allSubs []*subState
	for _, c := range channels {
		// Subscribe to channel snapshot restore subject
		if err := c.subToSnapshotRestoreRequests(); err != nil {
			return err
		}
		subs := c.ss.getAllSubs()
		if len(subs) > 0 {
			allSubs = append(allSubs, subs...)
		}
		if c.activity != nil {
			s.channels.maybeStartChannelDeleteTimer(c.name, c)
		}
	}
	if len(allSubs) > 0 {
		// Redeliver unacknowledged messages in a go routine; startGoRoutine
		// is paired with the wg.Done() inside the closure.
		s.startGoRoutine(func() {
			s.performRedeliveryOnStartup(allSubs)
			s.wg.Done()
		})
	}

	if err := s.nc.Flush(); err != nil {
		return err
	}
	if err := s.nca.Flush(); err != nil {
		return err
	}

	// Mark ourselves as leader only after all promotion actions succeeded.
	atomic.StoreInt64(&s.raft.leader, 1)
	return nil
}

// leadershipLost should be called when this node loses leadership.
// This should only be called when the server is running in clustered mode.
func (s *StanServer) leadershipLost() {
	s.log.Noticef("server lost leadership, performing leader stepdown actions")
	defer s.log.Noticef("finished leader stepdown actions")

	// Cancel outstanding client heartbeats. We aren't concerned about races
	// where new clients might be connecting because at this point, the server
	// will no longer accept new client connections, but even if it did, the
	// heartbeat would be automatically removed when it fires.
	for _, client := range s.clients.getClients() {
		s.clients.removeClientHB(client)
		// Ensure subs ackTimer is stopped
		subs := client.getSubsCopy()
		for _, sub := range subs {
			sub.Lock()
			sub.stopAckSub()
			sub.clearAckTimer()
			sub.Unlock()
		}
	}

	// Unsubscribe to the snapshot request per channel since we are no longer
	// leader.
	for _, c := range s.channels.getAll() {
		if c.snapshotSub != nil {
			c.snapshotSub.Unsubscribe()
			c.snapshotSub = nil
		}
		if c.activity != nil {
			s.channels.lockDelete()
			c.stopDeleteTimer()
			s.channels.unlockDelete()
		}
	}

	// Only the leader will receive protocols from clients
	s.unsubscribeInternalSubs()

	atomic.StoreInt64(&s.raft.leader, 0)
}

// TODO: Explore parameter passing in gnatsd. Keep separate for now.
func (s *StanServer) configureClusterOpts() error {
	opts := s.natsOpts
	// If we don't have cluster defined in the configuration
	// file and no cluster listen string override, but we do
	// have a routes override, we need to report misconfiguration.
	if opts.Cluster.ListenStr == "" && opts.Cluster.Host == "" &&
		opts.Cluster.Port == 0 {
		if opts.RoutesStr != "" {
			err := fmt.Errorf("solicited routes require cluster capabilities, e.g. --cluster")
			s.log.Fatalf(err.Error())
			// Also return error in case server is started from application
			// and no logger has been set.
			return err
		}
		return nil
	}

	// If cluster flag override, process it
	if opts.Cluster.ListenStr != "" {
		clusterURL, err := url.Parse(opts.Cluster.ListenStr)
		if err != nil {
			return err
		}
		h, p, err := net.SplitHostPort(clusterURL.Host)
		if err != nil {
			return err
		}
		opts.Cluster.Host = h
		_, err = fmt.Sscan(p, &opts.Cluster.Port)
		if err != nil {
			return err
		}

		if clusterURL.User != nil {
			pass, hasPassword := clusterURL.User.Password()
			if !hasPassword {
				return fmt.Errorf("expected cluster password to be set")
			}
			opts.Cluster.Password = pass

			user := clusterURL.User.Username()
			opts.Cluster.Username = user
		} else {
			// Since we override from flag and there is no user/pwd, make
			// sure we clear what we may have gotten from config file.
			opts.Cluster.Username = ""
			opts.Cluster.Password = ""
		}
	}

	// If we have routes but no config file, fill in here.
	if opts.RoutesStr != "" && opts.Routes == nil {
		opts.Routes = server.RoutesFromStr(opts.RoutesStr)
	}

	return nil
}

// configureNATSServerTLS sets up TLS for the NATS Server.
// Additional TLS parameters (e.g. cipher suites) will need to be placed
// in a configuration file specified through the -config parameter.
func (s *StanServer) configureNATSServerTLS() error {
	opts := s.natsOpts
	tlsSet := false
	tc := server.TLSConfigOpts{}
	if opts.TLSCert != "" {
		tc.CertFile = opts.TLSCert
		tlsSet = true
	}
	if opts.TLSKey != "" {
		tc.KeyFile = opts.TLSKey
		tlsSet = true
	}
	if opts.TLSCaCert != "" {
		tc.CaFile = opts.TLSCaCert
		tlsSet = true
	}

	if opts.TLSVerify {
		tc.Verify = true
		tlsSet = true
	}

	var err error
	if tlsSet {
		if opts.TLSConfig, err = server.GenTLSConfig(&tc); err != nil {
			// The connection will fail later if the problem is severe enough.
			return fmt.Errorf("unable to setup NATS Server TLS: %v", err)
		}
	}
	return nil
}

// startNATSServer starts the embedded NATS server, possibly updating
// the NATS Server's clustering and/or TLS options.
// NOTE(review): the remainder of this function is outside the visible
// chunk; only the visible portion is reproduced here.
func (s *StanServer) startNATSServer() error {
	if err := s.configureClusterOpts(); err != nil {
		return err
	}
	if err := s.configureNATSServerTLS(); err != nil {
		return err
	}
	opts := s.natsOpts
	s.natsServer = server.New(opts)
	if s.natsServer == nil {
		return fmt.Errorf("no NATS Server object returned")
	}
	if stanLogger := s.log.GetLogger(); stanLogger != nil {
		s.natsServer.SetLogger(stanLogger, opts.Debug, opts.Trace)
	}
	// Run server in Go routine.
+ go s.natsServer.Start() + // Wait for accept loop(s) to be started + if !s.natsServer.ReadyForConnections(10 * time.Second) { + return fmt.Errorf("unable to start a NATS Server on %s:%d", opts.Host, opts.Port) + } + return nil +} + +// ensureRunningStandAlone prevents this streaming server from starting +// if another is found using the same cluster ID - a possibility when +// routing is enabled. +// This runs under sever's lock so nothing should grab the server lock here. +func (s *StanServer) ensureRunningStandAlone() error { + clusterID := s.info.ClusterID + hbInbox := nats.NewInbox() + timeout := time.Millisecond * 250 + + // We cannot use the client's API here as it will create a dependency + // cycle in the streaming client, so build our request and see if we + // get a response. + req := &pb.ConnectRequest{ClientID: clusterID, HeartbeatInbox: hbInbox} + b, _ := req.Marshal() + reply, err := s.nc.Request(s.info.Discovery, b, timeout) + if err == nats.ErrTimeout { + s.log.Debugf("Did not detect another server instance") + return nil + } + if err != nil { + return fmt.Errorf("request error detecting another server instance: %v", err) + } + // See if the response is valid and can be unmarshalled. + cr := &pb.ConnectResponse{} + err = cr.Unmarshal(reply.Data) + if err != nil { + // Something other than a compatible streaming server responded. + // This may cause other problems in the long run, so better fail + // the startup early. + return fmt.Errorf("unmarshall error while detecting another server instance: %v", err) + } + // Another streaming server was found, cleanup then return error. + clreq := &pb.CloseRequest{ClientID: clusterID} + b, _ = clreq.Marshal() + s.nc.Request(cr.CloseRequests, b, timeout) + return fmt.Errorf("discovered another streaming server with cluster ID %q", clusterID) +} + +// Binds server's view of a client with stored Client objects. 
+func (s *StanServer) processRecoveredClients(clients []*stores.Client) { + if !s.isClustered { + s.clients.recoverClients(clients) + } +} + +// Reconstruct the subscription state on restart. +func (s *StanServer) processRecoveredChannels(channels map[string]*stores.RecoveredChannel) ([]*subState, error) { + allSubs := make([]*subState, 0, 16) + + for channelName, recoveredChannel := range channels { + channel, err := s.channels.create(s, channelName, recoveredChannel.Channel) + if err != nil { + return nil, err + } + if !s.isClustered { + // Get the recovered subscriptions for this channel. + for _, recSub := range recoveredChannel.Subscriptions { + sub := s.recoverOneSub(channel, recSub.Sub, recSub.Pending, nil) + if sub != nil { + // Subscribe to subscription ACKs + if err := sub.startAckSub(s.nca, s.processAckMsg); err != nil { + return nil, err + } + allSubs = append(allSubs, sub) + } + } + // Now that we have recovered possible subscriptions for this channel, + // check if we should start the delete timer. + if channel.activity != nil { + s.channels.maybeStartChannelDeleteTimer(channelName, channel) + } + } + } + return allSubs, nil +} + +func (s *StanServer) recoverOneSub(c *channel, recSub *spb.SubState, pendingAcksAsMap map[uint64]struct{}, + pendingAcksAsArray []uint64) *subState { + + // map, but nowhere else. + processOfflineSub := func(c *channel, sub *subState) { + c.ss.durables[sub.durableKey()] = sub + // Now that the key is computed, clear ClientID otherwise + // durable would not be able to be restarted. + sub.savedClientID = sub.ClientID + sub.ClientID = "" + } + + // Create a subState + sub := &subState{ + SubState: *recSub, + subject: c.name, + ackWait: computeAckWait(recSub.AckWaitInSecs), + store: c.store.Subs, + } + // Depending from where this function is called, we are given + // a map[uint64]struct{} or a []uint64. 
+ if len(pendingAcksAsMap) != 0 { + sub.acksPending = make(map[uint64]int64, len(pendingAcksAsMap)) + for seq := range pendingAcksAsMap { + sub.acksPending[seq] = 0 + } + } else { + sub.acksPending = make(map[uint64]int64, len(pendingAcksAsArray)) + for _, seq := range pendingAcksAsArray { + sub.acksPending[seq] = 0 + } + } + if len(sub.acksPending) > 0 { + // Prevent delivery of new messages until resent of old ones + sub.newOnHold = true + // We may not need to set this because this would be set + // during the initial redelivery attempt, but does not hurt. + if int32(len(sub.acksPending)) >= sub.MaxInFlight { + sub.stalled = true + } + } + // When recovering older stores, IsDurable may not exist for + // durable subscribers. Set it now. + durableSub := sub.isDurableSubscriber() // not a durable queue sub! + if durableSub { + sub.IsDurable = true + // Special handling if this is an offline durable subscriber. + // Note that durable subscribers have always ClientID on store. + // This is because we use ClientID+subject+durableName to construct + // the durable key used in the subStore's durables map. + // Note that even if the client connection is recovered, we should + // not attempt to add the offline durable back to the clients and + // regular state. We need to wait for the durable to be restarted. 
+ if sub.IsClosed { + processOfflineSub(c, sub) + return nil + } + } + // Add the subscription to the corresponding client + added := s.clients.addSub(sub.ClientID, sub) + if added || sub.IsDurable { + // Repair for issue https://github.com/nats-io/nats-streaming-server/issues/215 + // Do not recover a queue durable subscriber that still + // has ClientID but for which connection was closed (=>!added) + if !added && sub.isQueueDurableSubscriber() && !sub.isShadowQueueDurable() { + s.log.Noticef("WARN: Not recovering ghost durable queue subscriber: [%s]:[%s] subject=%s inbox=%s", sub.ClientID, sub.QGroup, sub.subject, sub.Inbox) + c.store.Subs.DeleteSub(sub.ID) + return nil + } + // Fix for older offline durable subscribers. Newer offline durable + // subscribers have IsClosed set to true and therefore are handled aboved. + if durableSub && !added { + processOfflineSub(c, sub) + } else { + // Add this subscription to subStore. + c.ss.updateState(sub) + // Add to the array, unless this is the shadow durable queue sub that + // was left in the store in order to maintain the group's state. + if !sub.isShadowQueueDurable() { + s.monMu.Lock() + s.numSubs++ + s.monMu.Unlock() + return sub + } + } + } + return nil +} + +// Do some final setup. Be minded of locking here since the server +// has started communication with NATS server/clients. +func (s *StanServer) postRecoveryProcessing(recoveredClients []*stores.Client, recoveredSubs []*subState) { + for _, sub := range recoveredSubs { + sub.Lock() + // Consider this subscription initialized. Note that it may + // still have newOnHold == true, which would prevent incoming + // messages to be delivered before we attempt to redeliver + // unacknowledged messages in performRedeliveryOnStartup. + sub.initialized = true + sub.Unlock() + } + // Go through the list of clients and ensure their Hb timer is set. Only do + // this for standalone mode. If clustered, timers will be setup on leader + // election. 
+ if !s.isClustered { + for _, sc := range recoveredClients { + // Because of the loop, we need to make copy for the closure + cID := sc.ID + s.clients.setClientHB(cID, s.opts.ClientHBInterval, func() { + s.checkClientHealth(cID) + }) + } + } +} + +// Redelivers unacknowledged messages, releases the hold for new messages delivery, +// and kicks delivery of available messages. +func (s *StanServer) performRedeliveryOnStartup(recoveredSubs []*subState) { + queues := make(map[*queueState]*channel) + + for _, sub := range recoveredSubs { + // Ignore subs that did not have any ack pendings on startup. + sub.Lock() + // Consider this subscription ready to receive messages + sub.initialized = true + // If this is a durable and it is offline, then skip the rest. + if sub.isOfflineDurableSubscriber() { + sub.newOnHold = false + sub.Unlock() + continue + } + // Unlock in order to call function below + sub.Unlock() + // Send old messages (lock is acquired in that function) + s.performAckExpirationRedelivery(sub, true) + // Regrab lock + sub.Lock() + // Allow new messages to be delivered + sub.newOnHold = false + subject := sub.subject + qs := sub.qstate + sub.Unlock() + c := s.channels.get(subject) + if c == nil { + continue + } + // Kick delivery of (possible) new messages + if qs != nil { + queues[qs] = c + } else { + s.sendAvailableMessages(c, sub) + } + } + // Kick delivery for queues that had members with newOnHold + for qs, c := range queues { + qs.Lock() + qs.newOnHold = false + qs.Unlock() + s.sendAvailableMessagesToQueue(c, qs) + } +} + +// initSubscriptions will setup initial subscriptions for discovery etc. +func (s *StanServer) initSubscriptions() error { + + // Do not create internal subscriptions in clustered mode, + // the leader will when it gets elected. + if !s.isClustered { + createSubOnClientPublish := true + + if s.partitions != nil { + // Receive published messages from clients, but only on the list + // of static channels. 
+ if err := s.partitions.initSubscriptions(); err != nil { + return err + } + // Since we create a subscription per channel, do not create + // the internal subscription on the > wildcard + createSubOnClientPublish = false + } + + if err := s.initInternalSubs(createSubOnClientPublish); err != nil { + return err + } + } + + s.log.Debugf("Discover subject: %s", s.info.Discovery) + // For partitions, we actually print the list of channels + // in the startup banner, so we don't need to repeat them here. + if s.partitions != nil { + s.log.Debugf("Publish subjects root: %s", s.info.Publish) + } else { + s.log.Debugf("Publish subject: %s.>", s.info.Publish) + } + s.log.Debugf("Subscribe subject: %s", s.info.Subscribe) + s.log.Debugf("Subscription Close subject: %s", s.info.SubClose) + s.log.Debugf("Unsubscribe subject: %s", s.info.Unsubscribe) + s.log.Debugf("Close subject: %s", s.info.Close) + return nil +} + +func (s *StanServer) initInternalSubs(createPub bool) error { + var err error + // Listen for connection requests. + s.connectSub, err = s.createSub(s.info.Discovery, s.connectCB, "discover") + if err != nil { + return err + } + if createPub { + // Receive published messages from clients. + pubSubject := fmt.Sprintf("%s.>", s.info.Publish) + s.pubSub, err = s.createSub(pubSubject, s.processClientPublish, "publish") + if err != nil { + return err + } + s.pubSub.SetPendingLimits(-1, -1) + } + // Receive subscription requests from clients. + s.subSub, err = s.createSub(s.info.Subscribe, s.processSubscriptionRequest, "subscribe request") + if err != nil { + return err + } + // Receive unsubscribe requests from clients. + s.subUnsubSub, err = s.createSub(s.info.Unsubscribe, s.processUnsubscribeRequest, "subscription unsubscribe") + if err != nil { + return err + } + // Receive subscription close requests from clients. 
+ s.subCloseSub, err = s.createSub(s.info.SubClose, s.processSubCloseRequest, "subscription close request") + if err != nil { + return err + } + // Receive close requests from clients. + s.closeSub, err = s.createSub(s.info.Close, s.processCloseRequest, "close request") + if err != nil { + return err + } + // Receive PINGs from clients. + s.cliPingSub, err = s.createSub(s.info.Discovery+".pings", s.processClientPings, "client pings") + return err +} + +func (s *StanServer) unsubscribeInternalSubs() { + if s.connectSub != nil { + s.connectSub.Unsubscribe() + s.connectSub = nil + } + if s.closeSub != nil { + s.closeSub.Unsubscribe() + s.closeSub = nil + } + if s.subSub != nil { + s.subSub.Unsubscribe() + s.subSub = nil + } + if s.subCloseSub != nil { + s.subCloseSub.Unsubscribe() + s.subCloseSub = nil + } + if s.subUnsubSub != nil { + s.subUnsubSub.Unsubscribe() + s.subUnsubSub = nil + } + if s.pubSub != nil { + s.pubSub.Unsubscribe() + s.pubSub = nil + } + if s.cliPingSub != nil { + s.cliPingSub.Unsubscribe() + s.cliPingSub = nil + } +} + +func (s *StanServer) createSub(subj string, f nats.MsgHandler, errTxt string) (*nats.Subscription, error) { + sub, err := s.nc.Subscribe(subj, f) + if err != nil { + return nil, fmt.Errorf("could not subscribe to %s subject: %v", errTxt, err) + } + return sub, nil +} + +// Process a client connect request +func (s *StanServer) connectCB(m *nats.Msg) { + req := &pb.ConnectRequest{} + err := req.Unmarshal(m.Data) + if err != nil || req.HeartbeatInbox == "" { + s.log.Errorf("[Client:?] 
Invalid conn request: ClientID=%s, Inbox=%s, err=%v", + req.ClientID, req.HeartbeatInbox, err) + s.sendConnectErr(m.Reply, ErrInvalidConnReq.Error()) + return + } + if !clientIDRegEx.MatchString(req.ClientID) { + s.log.Errorf("[Client:%s] Invalid ClientID, only alphanumeric and `-` or `_` characters allowed", req.ClientID) + s.sendConnectErr(m.Reply, ErrInvalidClientID.Error()) + return + } + + // If the client ID is already registered, check to see if it's the case + // that the client refreshed (e.g. it crashed and came back) or if the + // connection is a duplicate. If it refreshed, we will close the old + // client and open a new one. + client := s.clients.lookup(req.ClientID) + if client != nil { + // When detecting a duplicate, the processing of the connect request + // is going to be processed in a go-routine. We need however to keep + // track and fail another request on the same client ID until the + // current one has finished. + s.cliDupCIDsMu.Lock() + if _, exists := s.cliDipCIDsMap[req.ClientID]; exists { + s.cliDupCIDsMu.Unlock() + s.log.Debugf("[Client:%s] Connect failed; already connected", req.ClientID) + s.sendConnectErr(m.Reply, ErrInvalidClient.Error()) + return + } + s.cliDipCIDsMap[req.ClientID] = struct{}{} + s.cliDupCIDsMu.Unlock() + + s.startGoRoutine(func() { + defer s.wg.Done() + isDup := false + if s.isDuplicateConnect(client) { + s.log.Debugf("[Client:%s] Connect failed; already connected", req.ClientID) + s.sendConnectErr(m.Reply, ErrInvalidClient.Error()) + isDup = true + } + s.cliDupCIDsMu.Lock() + if !isDup { + s.handleConnect(req, m, true) + } + delete(s.cliDipCIDsMap, req.ClientID) + s.cliDupCIDsMu.Unlock() + }) + return + } + s.cliDupCIDsMu.Lock() + s.handleConnect(req, m, false) + s.cliDupCIDsMu.Unlock() +} + +func (s *StanServer) handleConnect(req *pb.ConnectRequest, m *nats.Msg, replaceOld bool) { + var err error + + // If clustered, thread operations through Raft. 
+ if s.isClustered { + err = s.replicateConnect(req, replaceOld) + } else { + err = s.processConnect(req, replaceOld) + } + + if err != nil { + // Error has already been logged. + s.sendConnectErr(m.Reply, err.Error()) + return + } + + // Send connect response to client and start heartbeat timer. + s.finishConnectRequest(req, m.Reply) +} + +// isDuplicateConnect determines if the given client ID is a duplicate +// connection by pinging the old client's heartbeat inbox and checking if it +// responds. If it does, it's a duplicate connection. +func (s *StanServer) isDuplicateConnect(client *client) bool { + client.RLock() + hbInbox := client.info.HbInbox + client.RUnlock() + + // This is the HbInbox from the "old" client. See if it is up and + // running by sending a ping to that inbox. + _, err := s.nc.Request(hbInbox, nil, s.dupCIDTimeout) + + // If err is nil, the currently registered client responded, so this is a + // duplicate. + return err == nil +} + +// When calling ApplyFuture.Error(), if we get an error it means +// that raft failed to commit to its log. +// But if we also want the result of FSM.Apply(), which in this +// case is StanServer.Apply(), we need to check the Response(). +// So we first check error from future.Error(). If nil, then we +// check the Response. +func waitForReplicationErrResponse(f raft.ApplyFuture) error { + err := f.Error() + if err == nil { + resp := f.Response() + // We call this function when we know that FutureApply's + // Response is an error object. + if resp != nil { + err = f.Response().(error) + } + } + return err +} + +// Leader invokes this to replicate the command to delete a channel. +func (s *StanServer) replicateDeleteChannel(channel string) { + op := &spb.RaftOperation{ + OpType: spb.RaftOperation_DeleteChannel, + Channel: channel, + } + data, err := op.Marshal() + if err != nil { + panic(err) + } + // Wait on result of replication. + s.raft.Apply(data, 0).Error() +} + +// Check if the channel can be deleted. 
If so, do it in place. +// This is called from the ioLoop by the leader or a standlone server. +func (s *StanServer) handleChannelDelete(c *channel) { + delete := false + cs := s.channels + cs.lockDelete() + cs.Lock() + a := c.activity + if a.deleteInProgress || c.ss.hasActiveSubs() { + if s.debug { + s.log.Debugf("Channel %q cannot be deleted: inProgress=%v hasActiveSubs=%v", + c.name, a.deleteInProgress, c.ss.hasActiveSubs()) + } + c.stopDeleteTimer() + } else { + elapsed := time.Since(a.last) + if elapsed >= a.maxInactivity { + if s.debug { + s.log.Debugf("Channel %q is being deleted", c.name) + } + c.stopDeleteTimer() + // Leave in map for now, but mark as deleted. If we removed before + // completion of the removal, a new lookup could re-create while + // in the process of deleting it. + a.deleteInProgress = true + delete = true + } else { + var next time.Duration + if elapsed < 0 { + next = a.maxInactivity + } else { + // elapsed < a.maxInactivity + next = a.maxInactivity - elapsed + } + if s.debug { + s.log.Debugf("Channel %q cannot be deleted now, reset timer to fire in %v", + c.name, next) + } + c.resetDeleteTimer(next) + } + } + cs.Unlock() + cs.unlockDelete() + if delete { + if testDeleteChannel { + time.Sleep(time.Second) + } + if s.isClustered { + s.replicateDeleteChannel(c.name) + } else { + s.processDeleteChannel(c.name) + } + } +} + +// Actual deletetion of the channel. 
+func (s *StanServer) processDeleteChannel(channel string) { + cs := s.channels + cs.lockDelete() + defer cs.unlockDelete() + cs.Lock() + defer cs.Unlock() + // Delete from store + if err := cs.store.DeleteChannel(channel); err != nil { + s.log.Errorf("Error deleting channel %q: %v", channel, err) + c := cs.channels[channel] + if c != nil && c.activity != nil { + c.activity.deleteInProgress = false + c.startDeleteTimer() + } + return + } + // If no error, remove channel + delete(s.channels.channels, channel) + s.log.Noticef("Channel %q has been deleted", channel) +} + +func (s *StanServer) replicateConnect(req *pb.ConnectRequest, refresh bool) error { + op := &spb.RaftOperation{ + OpType: spb.RaftOperation_Connect, + ClientConnect: &spb.AddClient{ + Request: req, + Refresh: refresh, + }, + } + data, err := op.Marshal() + if err != nil { + panic(err) + } + // Wait on result of replication. + return waitForReplicationErrResponse(s.raft.Apply(data, 0)) +} + +func (s *StanServer) processConnect(req *pb.ConnectRequest, replaceOld bool) error { + // If the client restarted, close the old one first. + if replaceOld { + s.closeClient(req.ClientID) + + // Because connections are processed under a common lock, it is not + // possible for a connection request for the same client ID to come in at + // the same time after unregistering the old client. 
+ } + + // Try to register + info := &spb.ClientInfo{ + ID: req.ClientID, + HbInbox: req.HeartbeatInbox, + ConnID: req.ConnID, + Protocol: req.Protocol, + PingInterval: req.PingInterval, + PingMaxOut: req.PingMaxOut, + } + _, err := s.clients.register(info) + if err != nil { + s.log.Errorf("[Client:%s] Error registering client: %v", req.ClientID, err) + return err + } + + if replaceOld { + s.log.Debugf("[Client:%s] Replaced old client (Inbox=%v)", req.ClientID, req.HeartbeatInbox) + } else { + s.log.Debugf("[Client:%s] Connected (Inbox=%v)", req.ClientID, req.HeartbeatInbox) + } + return nil +} + +func (s *StanServer) finishConnectRequest(req *pb.ConnectRequest, replyInbox string) { + clientID := req.ClientID + // Heartbeat timer. + s.clients.setClientHB(clientID, s.opts.ClientHBInterval, func() { s.checkClientHealth(clientID) }) + + cr := &pb.ConnectResponse{ + PubPrefix: s.info.Publish, + SubRequests: s.info.Subscribe, + UnsubRequests: s.info.Unsubscribe, + SubCloseRequests: s.info.SubClose, + CloseRequests: s.info.Close, + Protocol: protocolOne, + } + // We could set those unconditionally since even with + // older clients, the protobuf Unmarshal would simply not + // decode them. + if req.Protocol >= protocolOne { + cr.PingRequests = s.info.Discovery + ".pings" + // In the future, we may want to return different values + // than the one the client sent in the connect request. + // For now, return the values from the request. + // Note: Server and clients have possibly different HBs values. + cr.PingInterval = req.PingInterval + cr.PingMaxOut = req.PingMaxOut + } + b, _ := cr.Marshal() + s.nc.Publish(replyInbox, b) +} + +func (s *StanServer) sendConnectErr(replyInbox, err string) { + cr := &pb.ConnectResponse{Error: err} + b, _ := cr.Marshal() + s.nc.Publish(replyInbox, b) +} + +// Send a heartbeat call to the client. 
+func (s *StanServer) checkClientHealth(clientID string) { + client := s.clients.lookup(clientID) + if client == nil { + return + } + + // If clustered and we lost leadership, we should stop + // heartbeating as the new leader will take over. + if s.isClustered && !s.isLeader() { + s.clients.removeClientHB(client) + return + } + + client.RLock() + hbInbox := client.info.HbInbox + client.RUnlock() + + // Sends the HB request. This call blocks for ClientHBTimeout, + // do not hold the lock for that long! + _, err := s.nc.Request(hbInbox, nil, s.opts.ClientHBTimeout) + // Grab the lock now. + client.Lock() + // Client could have been unregistered, in which case + // client.hbt will be nil. + if client.hbt == nil { + client.Unlock() + return + } + hadFailed := client.fhb > 0 + // If we did not get the reply, increase the number of + // failed heartbeats. + if err != nil { + client.fhb++ + // If we have reached the max number of failures + if client.fhb > s.opts.ClientHBFailCount { + s.log.Debugf("[Client:%s] Timed out on heartbeats", clientID) + // close the client (connection). This locks the + // client object internally so unlock here. + client.Unlock() + // If clustered, thread operations through Raft. + if s.isClustered { + if err := s.replicateConnClose(&pb.CloseRequest{ClientID: clientID}); err != nil { + s.log.Errorf("[Client:%s] Failed to replicate disconnect on heartbeat expiration: %v", + clientID, err) + } + } else { + s.closeClient(clientID) + } + return + } + } else { + // We got the reply, reset the number of failed heartbeats. + client.fhb = 0 + } + // Reset the timer to fire again. 
+ client.hbt.Reset(s.opts.ClientHBInterval) + var ( + subs []*subState + hasFailedHB = client.fhb > 0 + ) + if (hadFailed && !hasFailedHB) || (!hadFailed && hasFailedHB) { + // Get a copy of subscribers and client.fhb while under lock + subs = client.getSubsCopy() + } + client.Unlock() + if len(subs) > 0 { + // Push the info about presence of failed heartbeats down to + // subscribers, so they have easier access to that info in + // the redelivery attempt code. + for _, sub := range subs { + sub.Lock() + sub.hasFailedHB = hasFailedHB + sub.Unlock() + } + } +} + +// Close a client +func (s *StanServer) closeClient(clientID string) error { + s.closeMu.Lock() + defer s.closeMu.Unlock() + // Remove from our clientStore. + client, err := s.clients.unregister(clientID) + // The above call may return an error (due to storage) but still return + // the client that is being unregistered. So log error an proceed. + if err != nil { + s.log.Errorf("Error unregistering client %q: %v", clientID, err) + } + // This would mean that the client was already unregistered or was never + // registered. + if client == nil { + s.log.Errorf("Unknown client %q in close request", clientID) + return ErrUnknownClient + } + + // Remove all non-durable subscribers. + s.removeAllNonDurableSubscribers(client) + + if s.debug { + client.RLock() + hbInbox := client.info.HbInbox + client.RUnlock() + s.log.Debugf("[Client:%s] Closed (Inbox=%v)", clientID, hbInbox) + } + return nil +} + +// processCloseRequest will process connection close requests from clients. +func (s *StanServer) processCloseRequest(m *nats.Msg) { + req := &pb.CloseRequest{} + err := req.Unmarshal(m.Data) + if err != nil { + s.log.Errorf("Received invalid close request, subject=%s", m.Subject) + s.sendCloseResponse(m.Reply, ErrInvalidCloseReq) + return + } + + s.barrier(func() { + var err error + // If clustered, thread operations through Raft. 
+ if s.isClustered { + err = s.replicateConnClose(req) + } else { + err = s.closeClient(req.ClientID) + } + // If there was an error, it has been already logged. + + // Send response, if err is nil, will be a success response. + s.sendCloseResponse(m.Reply, err) + }) +} + +func (s *StanServer) replicateConnClose(req *pb.CloseRequest) error { + // Go through the list of subscriptions and possibly + // flush the pending replication of sent/ack. + subs := s.clients.getSubs(req.ClientID) + for _, sub := range subs { + s.flushReplicatedSentAndAckSeqs(sub, true) + } + + op := &spb.RaftOperation{ + OpType: spb.RaftOperation_Disconnect, + ClientDisconnect: req, + } + data, err := op.Marshal() + if err != nil { + panic(err) + } + // Wait on result of replication. + return waitForReplicationErrResponse(s.raft.Apply(data, 0)) +} + +func (s *StanServer) sendCloseResponse(subj string, closeErr error) { + resp := &pb.CloseResponse{} + if closeErr != nil { + resp.Error = closeErr.Error() + } + if b, err := resp.Marshal(); err == nil { + s.nc.Publish(subj, b) + } +} + +// processClientPublish process inbound messages from clients. +func (s *StanServer) processClientPublish(m *nats.Msg) { + iopm := &ioPendingMsg{m: m} + pm := &iopm.pm + if pm.Unmarshal(m.Data) != nil { + if s.processCtrlMsg(m) { + return + } + // else we will report an error below... + } + + // Make sure we have a guid and valid channel name. + if pm.Guid == "" || !util.IsChannelNameValid(pm.Subject, false) { + s.log.Errorf("Received invalid client publish message %v", pm) + s.sendPublishErr(m.Reply, pm.Guid, ErrInvalidPubReq) + return + } + + // Check if the client is valid. We do this after the clustered check so + // that only the leader performs this check. + valid := false + if s.partitions != nil || s.isClustered { + // In partitioning or clustering it is possible that we get there + // before the connect request is processed. If so, make sure we wait + // for conn request to be processed first. 
Check clientCheckTimeout + // doc for details. + valid = s.clients.isValidWithTimeout(pm.ClientID, pm.ConnID, clientCheckTimeout) + } else { + valid = s.clients.isValid(pm.ClientID, pm.ConnID) + } + if !valid { + s.log.Errorf("Received invalid client publish message %v", pm) + s.sendPublishErr(m.Reply, pm.Guid, ErrInvalidPubReq) + return + } + + s.ioChannel <- iopm +} + +// processClientPings receives a PING from a client. The payload is the client's UID. +// If the client is present, a response with nil payload is sent back to indicate +// success, otherwise the payload contains an error message. +func (s *StanServer) processClientPings(m *nats.Msg) { + if len(m.Data) == 0 { + return + } + ping := &pb.Ping{} + if err := ping.Unmarshal(m.Data); err != nil { + return + } + var reply []byte + client := s.clients.lookupByConnID(ping.ConnID) + if client != nil { + // If the client has failed heartbeats and since the + // server just received a PING from the client, reset + // the server-to-client HB timer so that a PING is + // sent soon and the client's subscriptions failedHB + // is cleared. + client.RLock() + hasFailedHBs := client.fhb > 0 + client.RUnlock() + if hasFailedHBs { + client.Lock() + client.hbt.Reset(time.Millisecond) + client.Unlock() + } + if s.pingResponseOKBytes == nil { + s.pingResponseOKBytes, _ = (&pb.PingResponse{}).Marshal() + } + reply = s.pingResponseOKBytes + } else { + if s.pingResponseInvalidClientBytes == nil { + pingError := &pb.PingResponse{ + Error: "client has been replaced or is no longer registered", + } + s.pingResponseInvalidClientBytes, _ = pingError.Marshal() + } + reply = s.pingResponseInvalidClientBytes + } + s.ncs.Publish(m.Reply, reply) +} + +// CtrlMsg are no longer used to solve connection and subscription close/unsub +// ordering issues. However, a (newer) server may still receive those from +// older servers in the same NATS cluster. 
+// Since original behavior was to ignore control messages sent from a server +// other than itself, and since new server do not send those (in this context +// at least), this function simply make sure that if it is a properly formed +// CtrlMsg, we just ignore. +func (s *StanServer) processCtrlMsg(m *nats.Msg) bool { + cm := &spb.CtrlMsg{} + // Since we don't use CtrlMsg for connection/subscription close/unsub, + // simply return true if CtrlMsg is valid so that this message is ignored. + return cm.Unmarshal(m.Data) == nil +} + +func (s *StanServer) sendPublishErr(subj, guid string, err error) { + badMsgAck := &pb.PubAck{Guid: guid, Error: err.Error()} + if b, err := badMsgAck.Marshal(); err == nil { + s.ncs.Publish(subj, b) + } +} + +// FIXME(dlc) - place holder to pick sub that has least outstanding, should just sort, +// or use insertion sort, etc. +func findBestQueueSub(sl []*subState) *subState { + var ( + leastOutstanding = int(^uint(0) >> 1) + rsub *subState + ) + for _, sub := range sl { + + sub.RLock() + sOut := len(sub.acksPending) + sStalled := sub.stalled + sHasFailedHB := sub.hasFailedHB + sub.RUnlock() + + // Favor non stalled subscribers and clients that do not have failed heartbeats + if !sStalled && !sHasFailedHB { + if sOut < leastOutstanding { + leastOutstanding = sOut + rsub = sub + } + } + } + + len := len(sl) + if rsub == nil && len > 0 { + rsub = sl[0] + } + if len > 1 && rsub == sl[0] { + copy(sl, sl[1:len]) + sl[len-1] = rsub + } + + return rsub +} + +// Send a message to the queue group +// Assumes qs lock held for write +func (s *StanServer) sendMsgToQueueGroup(qs *queueState, m *pb.MsgProto, force bool) (*subState, bool, bool) { + sub := findBestQueueSub(qs.subs) + if sub == nil { + return nil, false, false + } + sub.Lock() + wasStalled := sub.stalled + didSend, sendMore := s.sendMsgToSub(sub, m, force) + // If this is not a redelivery and the sub was not stalled, but now is, + // bump the number of stalled members. 
+ if !force && !wasStalled && sub.stalled { + qs.stalledSubCount++ + } + if didSend && sub.LastSent > qs.lastSent { + qs.lastSent = sub.LastSent + } + sub.Unlock() + return sub, didSend, sendMore +} + +// processMsg will process a message, and possibly send to clients, etc. +func (s *StanServer) processMsg(c *channel) { + ss := c.ss + + // Since we iterate through them all. + ss.RLock() + // Walk the plain subscribers and deliver to each one + for _, sub := range ss.psubs { + s.sendAvailableMessages(c, sub) + } + + // Check the queue subscribers + for _, qs := range ss.qsubs { + s.sendAvailableMessagesToQueue(c, qs) + } + ss.RUnlock() +} + +// Used for sorting by sequence +type bySeq []uint64 + +func (a bySeq) Len() int { return (len(a)) } +func (a bySeq) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a bySeq) Less(i, j int) bool { return a[i] < a[j] } + +// Returns an array of message sequence numbers ordered by sequence. +func makeSortedSequences(sequences map[uint64]int64) []uint64 { + results := make([]uint64, 0, len(sequences)) + for seq := range sequences { + results = append(results, seq) + } + sort.Sort(bySeq(results)) + return results +} + +// Used for sorting by expiration time +type byExpire []*pendingMsg + +func (a byExpire) Len() int { return (len(a)) } +func (a byExpire) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byExpire) Less(i, j int) bool { + // If expire is 0, it means the server was restarted + // and we don't have the expiration time, which will + // be set later. Order by sequence instead. + if a[i].expire == 0 || a[j].expire == 0 { + return a[i].seq < a[j].seq + } + return a[i].expire < a[j].expire +} + +// Returns an array of pendingMsg ordered by expiration date, unless +// the expiration date in the pendingMsgs map is not set (0), which +// happens after a server restart. In this case, the array is ordered +// by message sequence numbers. 
func makeSortedPendingMsgs(pendingMsgs map[uint64]int64) []*pendingMsg {
	results := make([]*pendingMsg, 0, len(pendingMsgs))
	for seq, expire := range pendingMsgs {
		results = append(results, &pendingMsg{seq: seq, expire: expire})
	}
	sort.Sort(byExpire(results))
	return results
}

// Redeliver all outstanding messages to a durable subscriber, used on resubscribe.
func (s *StanServer) performDurableRedelivery(c *channel, sub *subState) {
	// Sort our messages outstanding from acksPending, grab some state and unlock.
	sub.RLock()
	sortedSeqs := makeSortedSequences(sub.acksPending)
	clientID := sub.ClientID
	newOnHold := sub.newOnHold
	subID := sub.ID
	sub.RUnlock()

	if s.debug && len(sortedSeqs) > 0 {
		sub.RLock()
		durName := sub.DurableName
		if durName == "" {
			// Durable queue subscribers carry the name in QGroup instead.
			durName = sub.QGroup
		}
		sub.RUnlock()
		s.log.Debugf("[Client:%s] Redelivering to subid=%d, durable=%s", clientID, subID, durName)
	}

	// If we don't find the client, we are done.
	if s.clients.lookup(clientID) != nil {
		// Go through all messages
		for _, seq := range sortedSeqs {
			m := s.getMsgForRedelivery(c, sub, seq)
			if m == nil {
				continue
			}

			if s.trace {
				s.log.Tracef("[Client:%s] Redelivering to subid=%d, seq=%d", clientID, subID, m.Sequence)
			}

			// Flag as redelivered.
			m.Redelivered = true

			sub.Lock()
			// Force delivery (bypasses the MaxInFlight/stalled check).
			s.sendMsgToSub(sub, m, forceDelivery)
			sub.Unlock()
		}
	}
	// Release newOnHold if needed.
	if newOnHold {
		sub.Lock()
		sub.newOnHold = false
		sub.Unlock()
	}
}

// Redeliver all outstanding messages that have expired.
func (s *StanServer) performAckExpirationRedelivery(sub *subState, isStartup bool) {
	// Sort our messages outstanding from acksPending, grab some state and unlock.
	sub.Lock()
	sortedPendingMsgs := makeSortedPendingMsgs(sub.acksPending)
	if len(sortedPendingMsgs) == 0 {
		// Nothing pending: stop the redelivery timer entirely.
		sub.clearAckTimer()
		sub.Unlock()
		return
	}
	expTime := int64(sub.ackWait)
	subject := sub.subject
	qs := sub.qstate
	clientID := sub.ClientID
	subID := sub.ID
	if sub.ackTimer == nil {
		s.setupAckTimer(sub, sub.ackWait)
	}
	if qs == nil {
		// If the client has some failed heartbeats, ignore this request.
		if sub.hasFailedHB {
			// Reset the timer
			sub.ackTimer.Reset(sub.ackWait)
			sub.Unlock()
			if s.debug {
				s.log.Debugf("[Client:%s] Skipping redelivery to subid=%d due to missed client heartbeat", clientID, subID)
			}
			return
		}
	}
	sub.Unlock()

	c := s.channels.get(subject)
	if c == nil {
		s.log.Errorf("[Client:%s] Aborting redelivery to subid=%d for non existing channel %s", clientID, subID, subject)
		sub.Lock()
		sub.clearAckTimer()
		sub.Unlock()
		return
	}

	// In cluster mode we will always redeliver to the same queue member.
	// This is to avoid to have to replicated sent/ack when a message would
	// be redelivered (removed from one member to be sent to another member)
	isClustered := s.isClustered

	now := time.Now().UnixNano()
	// limit is now plus a buffer of 15ms to avoid repeated timer callbacks.
	limit := now + int64(15*time.Millisecond)

	var (
		pick               *subState
		sent               bool
		tracePrinted       bool
		foundWithZero      bool
		nextExpirationTime int64
	)

	// We will move through acksPending(sorted) and see what needs redelivery.
	for _, pm := range sortedPendingMsgs {
		m := s.getMsgForRedelivery(c, sub, pm.seq)
		if m == nil {
			continue
		}
		// If we found any pm.expire with 0 in the array (due to a server restart),
		// ensure that all have now an expiration set, then reschedule right away.
		if foundWithZero || pm.expire == 0 {
			foundWithZero = true
			if pm.expire == 0 {
				sub.Lock()
				// Is message still pending?
				if _, present := sub.acksPending[pm.seq]; present {
					sub.acksPending[pm.seq] = m.Timestamp + expTime
				}
				sub.Unlock()
			}
			continue
		}

		// If this message has not yet expired, reset timer for next callback
		if pm.expire > limit {
			nextExpirationTime = pm.expire
			if !tracePrinted && s.trace {
				tracePrinted = true
				s.log.Tracef("[Client:%s] Redelivery for subid=%d, skipping seq=%d", clientID, subID, m.Sequence)
			}
			// Messages are sorted by expire time, so all remaining ones
			// are also not yet expired.
			break
		}

		// Flag as redelivered.
		m.Redelivered = true

		// Handle QueueSubscribers differently, since we will choose best subscriber
		// to redeliver to, not necessarily the same one.
		// However, on startup, resends only to member that had previously this message
		// otherwise this could cause a message to be redelivered to multiple members.
		if !isClustered && qs != nil && !isStartup {
			qs.Lock()
			pick, sent, _ = s.sendMsgToQueueGroup(qs, m, forceDelivery)
			qs.Unlock()
			if pick == nil {
				s.log.Errorf("[Client:%s] Unable to find queue subscriber for subid=%d", clientID, subID)
				break
			}
			// If the message is redelivered to a different queue subscriber,
			// we need to process an implicit ack for the original subscriber.
			// We do this only after confirmation that it was successfully added
			// as pending on the other queue subscriber.
			if pick != sub && sent {
				s.processAck(c, sub, m.Sequence)
			}
		} else {
			sub.Lock()
			s.sendMsgToSub(sub, m, forceDelivery)
			sub.Unlock()
		}
	}
	if foundWithZero {
		// Restart expiration now that ackPending map's expire values are properly
		// set. Note that messages may have been added/removed in the meantime.
		s.performAckExpirationRedelivery(sub, isStartup)
		return
	}

	// Adjust the timer
	sub.adjustAckTimer(nextExpirationTime)
}

// getMsgForRedelivery looks up the message from storage.
// (cont.) If not found -
// because it has been removed due to limit - processes an ACK for this
// sub/sequence number and returns nil, otherwise return a copy of the
// message (since it is going to be modified: m.Redelivered = true)
func (s *StanServer) getMsgForRedelivery(c *channel, sub *subState, seq uint64) *pb.MsgProto {
	m, err := c.store.Msgs.Lookup(seq)
	if m == nil || err != nil {
		if err != nil {
			s.log.Errorf("Error getting message for redelivery subid=%d, seq=%d, err=%v",
				sub.ID, seq, err)
		}
		// Ack it so that it does not reincarnate on restart
		s.processAck(c, sub, seq)
		return nil
	}
	// The store implementation does not return a copy, we need one
	mcopy := *m
	return &mcopy
}

// createSubSentAndAck returns an empty accumulator for a subscription's
// not-yet-replicated sent/ack sequence numbers.
func createSubSentAndAck() *subSentAndAck {
	return &subSentAndAck{
		sent: make([]uint64, 0),
		ack:  make([]uint64, 0),
	}
}

// Lazily replicates the fact that the server sent a message to a given subscription,
// or the subscription ack'ed a message. Once 100 operations have accumulated the
// batch is replicated immediately; otherwise the sub is registered with the lazy
// replication flusher.
// Caller holds the sub's Lock.
func (s *StanServer) replicateSentOrAck(sub *subState, sent bool, sequence uint64) {
	if sub.replicate == nil {
		sub.replicate = createSubSentAndAck()
	}
	repl := sub.replicate
	if sent {
		repl.sent = append(repl.sent, sequence)
	} else {
		repl.ack = append(repl.ack, sequence)
	}
	// 100 is the batching threshold before forcing an immediate replication.
	if len(repl.sent)+len(repl.ack) >= 100 {
		s.replicateSentAndAckSeqs(sub)
	} else if !repl.inFlusher {
		// Register with the lazy-replication goroutine exactly once.
		s.lazyRepl.Lock()
		s.lazyRepl.subs[sub] = struct{}{}
		repl.inFlusher = true
		s.lazyRepl.Unlock()
	}
}

// Replicates through raft
// Caller holds the sub's Lock.
func (s *StanServer) replicateSentAndAckSeqs(sub *subState) {
	op := &spb.RaftOperation{
		OpType: spb.RaftOperation_SendAndAck,
		SubSentAck: &spb.SubSentAndAck{
			Channel:  sub.subject,
			AckInbox: sub.AckInbox,
			Sent:     sub.replicate.sent,
			Ack:      sub.replicate.ack,
		},
	}
	data, err := op.Marshal()
	if err != nil {
		// Marshal of a RaftOperation should never fail; treat as programmer error.
		panic(err)
	}
	// Fire-and-forget: the returned future is intentionally not waited upon.
	s.raft.Apply(data, 0)
	sub.replicate.reset()
}

// Possibly issue a raft replication for the given subscription
// if there are pending sent/ack operations.
// This is called by a go-routine that does lazy replication, or
// when a subscription/connection is closed. In such case, we may
// skip the replication if the subscription is not durable for instance,
// because the subscription is going away anyway.
func (s *StanServer) flushReplicatedSentAndAckSeqs(sub *subState, onClose bool) {
	sub.Lock()
	if sub.replicate != nil {
		if len(sub.replicate.sent) > 0 || len(sub.replicate.ack) > 0 {
			// On close, only durables/queue subs survive and need the state.
			if !onClose || (sub.IsDurable || sub.qstate != nil) {
				s.replicateSentAndAckSeqs(sub)
			}
		}
		// When called from the lazy replication go-routine, onClose is false
		// and the sub has already been removed from the map.
		if onClose && sub.replicate.inFlusher {
			s.lazyRepl.Lock()
			delete(s.lazyRepl.subs, sub)
			s.lazyRepl.Unlock()
		}
		sub.replicate.inFlusher = false
	}
	sub.Unlock()
}

// long-lived go-routine that periodically replicate
// subscriptions' pending Sent and/or Ack operations.
func (s *StanServer) lazyReplicationOfSentAndAck() {
	defer s.wg.Done()

	s.mu.Lock()
	lr := s.lazyRepl
	s.mu.Unlock()

	// NOTE(review): ticker is never Stop()'ed; harmless for a
	// server-lifetime goroutine but worth confirming intent.
	ticker := time.NewTicker(lazyReplicationInterval)

	flush := func() {
		lr.Lock()
		for sub := range lr.subs {
			delete(lr.subs, sub)
			// Release the map lock while flushing this sub: flush takes
			// the sub's lock, and we must not hold both at once.
			lr.Unlock()
			s.flushReplicatedSentAndAckSeqs(sub, false)
			lr.Lock()
		}
		lr.Unlock()
	}

	for {
		select {
		case <-s.shutdownCh:
			// Try to flush outstanding before returning
			flush()
			return
		case <-ticker.C:
			flush()
		}
	}
}

// This is invoked from raft thread on a follower. It persists given
// sequence number to subscription of given AckInbox. It updates the
// sub (and queue state) LastSent value. It adds the sequence to the
// map of acksPending.
func (s *StanServer) processReplicatedSendAndAck(ssa *spb.SubSentAndAck) {
	c, err := s.lookupOrCreateChannel(ssa.Channel)
	if err != nil {
		return
	}
	sub := c.ss.LookupByAckInbox(ssa.AckInbox)
	if sub == nil {
		return
	}
	sub.Lock()
	defer sub.Unlock()

	// This is not optimized. The leader sent all accumulated sent and ack
	// sequences. For queue members, there is no much that can be done
	// because by nature seq will not be contiguous, but for non queue
	// subs, this could be optimized.
	for _, sequence := range ssa.Sent {
		// Update LastSent if applicable
		if sequence > sub.LastSent {
			sub.LastSent = sequence
		}
		// In case this is a queue member, update queue state's LastSent.
		if sub.qstate != nil && sequence > sub.qstate.lastSent {
			sub.qstate.lastSent = sequence
		}
		// Set 0 for expiration time. This will be computed
		// when the follower becomes leader and attempts to
		// redeliver messages.
		sub.acksPending[sequence] = 0
	}
	// Now remove the acks pending that we potentially just added ;-)
	for _, sequence := range ssa.Ack {
		delete(sub.acksPending, sequence)
	}
	// Don't set the sub.stalled here. Let that be done if the server
	// becomes leader and attempt the first deliveries.
}

// Sends the message to the subscriber
// Unless `force` is true, in which case message is always sent, if the number
// of acksPending is greater or equal to the sub's MaxInFlight limit, messages
// are not sent and subscriber is marked as stalled.
// Returns (didSend, sendMore): whether the message was sent, and whether the
// caller may keep sending to this sub.
// Sub lock should be held before calling.
func (s *StanServer) sendMsgToSub(sub *subState, m *pb.MsgProto, force bool) (bool, bool) {
	if sub == nil || m == nil || !sub.initialized || (sub.newOnHold && !m.Redelivered) {
		return false, false
	}

	// Don't send if we have too many outstanding already, unless forced to send.
	ap := int32(len(sub.acksPending))
	if !force && (ap >= sub.MaxInFlight) {
		sub.stalled = true
		return false, false
	}

	if s.trace {
		var action string
		if m.Redelivered {
			action = "Redelivering"
		} else {
			action = "Delivering"
		}
		s.log.Tracef("[Client:%s] %s msg to subid=%d, subject=%s, seq=%d",
			sub.ClientID, action, sub.ID, m.Subject, m.Sequence)
	}

	// Marshal of a pb.MsgProto cannot fail
	b, _ := m.Marshal()
	// but protect against a store implementation that may incorrectly
	// return an empty message.
	if len(b) == 0 {
		panic("store implementation returned an empty message")
	}
	if err := s.ncs.Publish(sub.Inbox, b); err != nil {
		s.log.Errorf("[Client:%s] Failed sending to subid=%d, subject=%s, seq=%d, err=%v",
			sub.ClientID, sub.ID, m.Subject, m.Sequence, err)
		return false, false
	}

	// Setup the ackTimer as needed now. I don't want to use defer in this
	// function, and want to make sure that if we exit before the end, the
	// timer is set. It will be adjusted/stopped as needed.
	if sub.ackTimer == nil {
		s.setupAckTimer(sub, sub.ackWait)
	}

	// If this message is already pending, do not add it again to the store.
	if expTime, present := sub.acksPending[m.Sequence]; present {
		// However, update the next expiration time.
		if expTime == 0 {
			// That can happen after a server restart, so need to use
			// the current time.
			expTime = time.Now().UnixNano()
		}
		// bump the next expiration time with the sub's ackWait.
		expTime += int64(sub.ackWait)
		sub.acksPending[m.Sequence] = expTime
		return true, true
	}

	// If in cluster mode, trigger replication (but leader does
	// not wait on quorum result).
	if s.isClustered {
		s.replicateSentOrAck(sub, replicateSent, m.Sequence)
	}

	// Store in storage
	if err := sub.store.AddSeqPending(sub.ID, m.Sequence); err != nil {
		s.log.Errorf("[Client:%s] Unable to add pending message to subid=%d, subject=%s, seq=%d, err=%v",
			sub.ClientID, sub.ID, sub.subject, m.Sequence, err)
		return false, false
	}

	// Update LastSent if applicable
	if m.Sequence > sub.LastSent {
		sub.LastSent = m.Sequence
	}

	// Store in ackPending.
	// Use current time to compute expiration time instead of m.Timestamp.
	// A message can be persisted in the log and send much later to a
	// new subscriber. Basing expiration time on m.Timestamp would
	// likely set the expiration time in the past!
	sub.acksPending[m.Sequence] = time.Now().UnixNano() + int64(sub.ackWait)

	// Now that we have added to acksPending, check again if we
	// have reached the max and tell the caller that it should not
	// be sending more at this time.
	if !force && (ap+1 == sub.MaxInFlight) {
		sub.stalled = true
		return true, false
	}

	return true, true
}

// Sets up the ackTimer to fire at the given duration.
// sub's lock held on entry.
func (s *StanServer) setupAckTimer(sub *subState, d time.Duration) {
	sub.ackTimer = time.AfterFunc(d, func() {
		s.performAckExpirationRedelivery(sub, false)
	})
}

// startIOLoop creates the IO channel and starts the IO loop goroutine,
// blocking until the loop is ready to receive.
func (s *StanServer) startIOLoop() {
	s.ioChannelWG.Add(1)
	s.ioChannel = make(chan *ioPendingMsg, ioChannelSize)
	// Use wait group to ensure that the loop is as ready as
	// possible before we setup the subscriptions and open the door
	// to incoming NATS messages.
	ready := &sync.WaitGroup{}
	ready.Add(1)
	go s.ioLoop(ready)
	ready.Wait()
}

// ioLoop batches incoming published messages, stores (and in clustered mode
// replicates) them, flushes the stores, delivers to subscribers, and finally
// acks the publishers. Runs until s.ioChannelQuit is closed.
func (s *StanServer) ioLoop(ready *sync.WaitGroup) {
	defer s.ioChannelWG.Done()

	storesToFlush := make(map[*channel]struct{}, 64)

	var (
		_pendingMsgs [ioChannelSize]*ioPendingMsg
		pendingMsgs  = _pendingMsgs[:0]
	)

	storeIOPendingMsgs := func(iopms []*ioPendingMsg) {
		var (
			futuresMap map[*channel]raft.Future
			err        error
		)
		if s.isClustered {
			futuresMap, err = s.replicate(iopms)
			// If replicate() returns an error, it means that no future
			// was applied, so we can fail all published messages.
			if err != nil {
				for _, iopm := range iopms {
					s.logErrAndSendPublishErr(iopm, err)
				}
			} else {
				for c, f := range futuresMap {
					// Wait for the replication result.
					// We know that we panic in StanServer.Apply() if storing
					// of messages fail. So the only reason f.Error() would
					// return an error (we are not using timeout in Apply())
					// is if raft fails to store its log, but it would have
					// then switched follower state. On leadership acquisition
					// we do reset nextSequence based on lastSequence on store.
					// Regardless, do reset here in case of error.

					// Note that each future contains a batch of messages for
					// a given channel. All futures in the map are for different
					// channels.
					if err := f.Error(); err != nil {
						lastSeq, lerr := c.store.Msgs.LastSequence()
						if lerr != nil {
							panic(fmt.Errorf("Error during message replication (%v), unable to get store last sequence: %v", err, lerr))
						}
						c.nextSequence = lastSeq + 1
					} else {
						storesToFlush[c] = struct{}{}
					}
				}
				// We have 1 future per channel. However, the array of iopms
				// may be from different channels. For each iopm we look
				// up its corresponding future and if there was an error
				// (same for all iopms of the same channel) we fail the
				// corresponding publishers.
				for _, iopm := range iopms {
					// We can call Error() again, this is not a problem.
					if err := futuresMap[iopm.c].Error(); err != nil {
						s.logErrAndSendPublishErr(iopm, err)
					} else {
						pendingMsgs = append(pendingMsgs, iopm)
					}
				}
			}
		} else {
			// Non-clustered: store each message directly.
			for _, iopm := range iopms {
				pm := &iopm.pm
				c, err := s.lookupOrCreateChannel(pm.Subject)
				if err == nil {
					msg := c.pubMsgToMsgProto(pm, c.nextSequence)
					_, err = c.store.Msgs.Store(msg)
				}
				if err != nil {
					s.logErrAndSendPublishErr(iopm, err)
				} else {
					c.nextSequence++
					pendingMsgs = append(pendingMsgs, iopm)
					storesToFlush[c] = struct{}{}
				}
			}
		}
	}

	var (
		batchSize = s.opts.IOBatchSize
		sleepTime = s.opts.IOSleepTime
		sleepDur  = time.Duration(sleepTime) * time.Microsecond
		max       = 0
		batch     = make([]*ioPendingMsg, 0, batchSize)
		dciopm    *ioPendingMsg
	)

	// Rendezvous with a caller that requested synchronization with the loop.
	synchronizationRequest := func(iopm *ioPendingMsg) {
		iopm.sc <- struct{}{}
		<-iopm.sdc
	}

	ready.Done()
	for {
		batch = batch[:0]
		select {
		case iopm := <-s.ioChannel:
			// Is this a request to delete a channel?
			if iopm.dc {
				s.handleChannelDelete(iopm.c)
				continue
			} else if iopm.sc != nil {
				synchronizationRequest(iopm)
				continue
			}
			batch = append(batch, iopm)

			remaining := batchSize - 1
		FILL_BATCH_LOOP:
			// fill the message batch slice with at most our batch size,
			// unless the channel is empty.
			for remaining > 0 {
				ioChanLen := len(s.ioChannel)

				// if we are empty, wait, check again, and break if nothing.
				// While this adds some latency, it optimizes batching.
				if ioChanLen == 0 {
					if sleepTime > 0 {
						time.Sleep(sleepDur)
						ioChanLen = len(s.ioChannel)
						if ioChanLen == 0 {
							break
						}
					} else {
						break
					}
				}

				// stick to our buffer size
				if ioChanLen > remaining {
					ioChanLen = remaining
				}

				for i := 0; i < ioChanLen; i++ {
					iopm = <-s.ioChannel
					if iopm.dc {
						// Defer the channel delete until after this batch.
						dciopm = iopm
						break FILL_BATCH_LOOP
					} else if iopm.sc != nil {
						synchronizationRequest(iopm)
					} else {
						batch = append(batch, iopm)
					}
				}
				// Keep track of max number of messages in a batch
				if ioChanLen > max {
					max = ioChanLen
					atomic.StoreInt64(&(s.ioChannelStatsMaxBatchSize), int64(max))
				}
				remaining -= ioChanLen
			}

			// If clustered, wait on the result of replication.
			storeIOPendingMsgs(batch)

			// flush all the stores with messages written to them...
			for c := range storesToFlush {
				if err := c.store.Msgs.Flush(); err != nil {
					// TODO: Attempt recovery, notify publishers of error.
					panic(fmt.Errorf("Unable to flush msg store: %v", err))
				}
				// Call this here, so messages are sent to subscribers,
				// which means that msg seq is added to subscription file
				s.processMsg(c)
				if err := c.store.Subs.Flush(); err != nil {
					panic(fmt.Errorf("Unable to flush sub store: %v", err))
				}
				// Remove entry from map (this is safe in Go)
				delete(storesToFlush, c)
				// When relevant, update the last activity
				if c.activity != nil {
					c.activity.last = time.Unix(0, c.lTimestamp)
				}
			}

			// Ack our messages back to the publisher
			for i := range pendingMsgs {
				iopm := pendingMsgs[i]
				s.ackPublisher(iopm)
				// Nil out the entry so it can be GC'ed even though the
				// backing array is reused.
				pendingMsgs[i] = nil
			}

			// clear out pending messages
			pendingMsgs = pendingMsgs[:0]

			// If there was a request to delete a channel, try now
			if dciopm != nil {
				s.handleChannelDelete(dciopm.c)
				dciopm = nil
			}

		case <-s.ioChannelQuit:
			return
		}
	}
}

// logErrAndSendPublishErr logs the publish failure and sends the error
// back to the publisher's reply subject.
func (s *StanServer) logErrAndSendPublishErr(iopm *ioPendingMsg, err error) {
	s.log.Errorf("[Client:%s] Error processing message for subject %q: %v",
		iopm.pm.ClientID, iopm.m.Subject, err)
	s.sendPublishErr(iopm.m.Reply, iopm.pm.Guid, err)
}

// Sends a special ioPendingMsg to indicate that we should attempt
// to delete the given channel.
func (s *StanServer) sendDeleteChannelRequest(c *channel) {
	iopm := &ioPendingMsg{c: c, dc: true}
	s.ioChannel <- iopm
}

// replicate will replicate the batch of messages to followers and return
// futures (one for each channel messages were replicated for) which, when
// waited upon, will indicate if the replication was successful or not. This
// should only be called if running in clustered mode.
func (s *StanServer) replicate(iopms []*ioPendingMsg) (map[*channel]raft.Future, error) {
	var (
		futures = make(map[*channel]raft.Future)
		batches = make(map[*channel]*spb.Batch)
	)
	for _, iopm := range iopms {
		pm := &iopm.pm
		c, err := s.lookupOrCreateChannel(pm.Subject)
		if err != nil {
			return nil, err
		}
		msg := c.pubMsgToMsgProto(pm, c.nextSequence)
		batch := batches[c]
		if batch == nil {
			batch = &spb.Batch{}
			batches[c] = batch
		}
		batch.Messages = append(batch.Messages, msg)
		iopm.c = c
		c.nextSequence++
	}
	for c, batch := range batches {
		op := &spb.RaftOperation{
			OpType:       spb.RaftOperation_Publish,
			PublishBatch: batch,
		}
		data, err := op.Marshal()
		if err != nil {
			panic(err)
		}
		futures[c] = s.raft.Apply(data, 0)
	}
	return futures, nil
}

// ackPublisher sends the ack for a message.
func (s *StanServer) ackPublisher(iopm *ioPendingMsg) {
	msgAck := &iopm.pa
	msgAck.Guid = iopm.pm.Guid
	needed := msgAck.Size()
	// Reuse the server's scratch buffer to avoid per-ack allocations.
	// Safe because acks are only produced from this single IO loop.
	s.tmpBuf = util.EnsureBufBigEnough(s.tmpBuf, needed)
	n, _ := msgAck.MarshalTo(s.tmpBuf)
	if s.trace {
		pm := &iopm.pm
		s.log.Tracef("[Client:%s] Acking Publisher subj=%s guid=%s", pm.ClientID, pm.Subject, pm.Guid)
	}
	s.ncs.Publish(iopm.m.Reply, s.tmpBuf[:n])
}

// Delete a sub from a given list.
// deleteFromList removes sub from sl (swap-with-last, order not preserved)
// and returns the possibly shrunk list plus whether sub was found.
func (sub *subState) deleteFromList(sl []*subState) ([]*subState, bool) {
	for i := 0; i < len(sl); i++ {
		if sl[i] == sub {
			sl[i] = sl[len(sl)-1]
			// Nil out the vacated tail slot so the sub can be GC'ed.
			sl[len(sl)-1] = nil
			sl = sl[:len(sl)-1]
			return shrinkSubListIfNeeded(sl), true
		}
	}
	return sl, false
}

// Checks if we need to do a resize. This is for very large growth then
// subsequent return to a more normal size.
func shrinkSubListIfNeeded(sl []*subState) []*subState {
	lsl := len(sl)
	csl := cap(sl)
	// Don't bother if list not too big
	if csl <= 8 {
		return sl
	}
	pFree := float32(csl-lsl) / float32(csl)
	// Reallocate when more than half the capacity is unused.
	if pFree > 0.50 {
		return append([]*subState(nil), sl...)
	}
	return sl
}

// removeAllNonDurableSubscribers will remove all non-durable subscribers for the client.
func (s *StanServer) removeAllNonDurableSubscribers(client *client) {
	// client has been unregistered and no other routine can add/remove
	// subscriptions, so it is safe to use the original.
	client.RLock()
	subs := client.subs
	client.RUnlock()
	for _, sub := range subs {
		sub.RLock()
		subject := sub.subject
		sub.RUnlock()
		// Get the channel
		c := s.channels.get(subject)
		if c == nil {
			continue
		}
		// Don't remove durables
		c.ss.Remove(c, sub, false)
	}
}

// processUnsubscribeRequest will process a unsubscribe request.
func (s *StanServer) processUnsubscribeRequest(m *nats.Msg) {
	req := &pb.UnsubscribeRequest{}
	err := req.Unmarshal(m.Data)
	if err != nil {
		s.log.Errorf("Invalid unsub request from %s", m.Subject)
		s.sendSubscriptionResponseErr(m.Reply, ErrInvalidUnsubReq)
		return
	}
	s.performmUnsubOrCloseSubscription(m, req, false)
}

// processSubCloseRequest will process a subscription close request.
func (s *StanServer) processSubCloseRequest(m *nats.Msg) {
	req := &pb.UnsubscribeRequest{}
	err := req.Unmarshal(m.Data)
	if err != nil {
		s.log.Errorf("Invalid sub close request from %s", m.Subject)
		s.sendSubscriptionResponseErr(m.Reply, ErrInvalidUnsubReq)
		return
	}
	s.performmUnsubOrCloseSubscription(m, req, true)
}

// Used when processing protocol messages to guarantee ordering.
// Since protocols handlers use different subscriptions, a client
// may send a message then close the connection, but those protocols
// are processed by different internal subscriptions in the server.
// Using nats's Conn.Barrier() we ensure that messages have been
// processed in their respective callbacks before invoking `f`.
// Since we also use a separate connection to handle acks, we
// also need to flush the connection used to process ack's and
// chained Barrier calls between s.nc and s.nca.
func (s *StanServer) barrier(f func()) {
	s.nc.Barrier(func() {
		// Ensure all pending acks are received by the connection
		s.nca.Flush()
		// Then ensure that all acks have been processed in processAckMsg callbacks
		// before executing the closing function.
		s.nca.Barrier(f)
	})
}

// performmUnsubOrCloseSubscription processes the unsub or close subscription
// request.
// NOTE(review): the doubled "m" in the name looks like a typo
// (performUnsub...), kept as-is since callers elsewhere use this name.
func (s *StanServer) performmUnsubOrCloseSubscription(m *nats.Msg, req *pb.UnsubscribeRequest, isSubClose bool) {
	// With partitioning, first verify that this server is handling this
	// channel. If not, do not return an error, since another server will
	// handle it. If no other server is, the client will get a timeout.
	if s.partitions != nil {
		if r := s.partitions.sl.Match(req.Subject); len(r) == 0 {
			return
		}
	}

	s.barrier(func() {
		var err error
		if s.isClustered {
			if isSubClose {
				err = s.replicateCloseSubscription(req)
			} else {
				err = s.replicateRemoveSubscription(req)
			}
		} else {
			s.closeMu.Lock()
			err = s.unsubscribe(req, isSubClose)
			s.closeMu.Unlock()
		}
		// If there was an error, it has been already logged.

		if err == nil {
			// This will check if the channel has MaxInactivity defined,
			// if so and there is no active subscription, it will start the
			// delete timer.
			s.channels.maybeStartChannelDeleteTimer(req.Subject, nil)
		}

		// If err is nil, it will be a non-error response
		s.sendSubscriptionResponseErr(m.Reply, err)
	})
}

// unsubscribe resolves the channel and subscription from the request and
// removes/closes the subscription. Returns an error if either is unknown.
func (s *StanServer) unsubscribe(req *pb.UnsubscribeRequest, isSubClose bool) error {
	// `action` is only used to distinguish the two cases in log messages.
	action := "unsub"
	if isSubClose {
		action = "sub close"
	}
	c := s.channels.get(req.Subject)
	if c == nil {
		s.log.Errorf("[Client:%s] %s request missing subject %s",
			req.ClientID, action, req.Subject)
		return ErrInvalidSub
	}
	sub := c.ss.LookupByAckInbox(req.Inbox)
	if sub == nil {
		s.log.Errorf("[Client:%s] %s request for missing inbox %s",
			req.ClientID, action, req.Inbox)
		return ErrInvalidSub
	}
	return s.unsubscribeSub(c, req.ClientID, action, sub, isSubClose)
}

// unsubscribeSub detaches sub from its client and channel. On a close
// (isSubClose true) the subscription state is preserved; on an unsubscribe
// it is fully removed.
func (s *StanServer) unsubscribeSub(c *channel, clientID, action string, sub *subState, isSubClose bool) error {
	// Remove from Client
	if !s.clients.removeSub(clientID, sub) {
		s.log.Errorf("[Client:%s] %s request for missing client", clientID, action)
		return ErrUnknownClient
	}
	// Remove the subscription
	unsubscribe := !isSubClose
	c.ss.Remove(c, sub, unsubscribe)
	s.monMu.Lock()
	s.numSubs--
	s.monMu.Unlock()
	return nil
}

// replicateRemoveSubscription replicates an unsubscribe through raft.
func (s *StanServer) replicateRemoveSubscription(req *pb.UnsubscribeRequest) error {
	return s.replicateUnsubscribe(req, spb.RaftOperation_RemoveSubscription)
}

// replicateCloseSubscription replicates a subscription close through raft.
func (s *StanServer) replicateCloseSubscription(req *pb.UnsubscribeRequest) error {
	// When closing a subscription, we need to possibly "flush" the
	// pending sent/ack that need to be replicated
	c := s.channels.get(req.Subject)
	if c != nil {
		sub := c.ss.LookupByAckInbox(req.Inbox)
		if sub != nil {
			s.flushReplicatedSentAndAckSeqs(sub, true)
		}
	}
	return s.replicateUnsubscribe(req, spb.RaftOperation_CloseSubscription)
}

// replicateUnsubscribe applies the unsub/close raft operation and waits
// for the replication result.
func (s *StanServer) replicateUnsubscribe(req *pb.UnsubscribeRequest, opType spb.RaftOperation_Type) error {
	op := &spb.RaftOperation{
		OpType: opType,
		Unsub:  req,
	}
	data, err := op.Marshal()
	if err != nil {
		panic(err)
	}
	// Wait on result of replication.
	return waitForReplicationErrResponse(s.raft.Apply(data, 0))
}

// sendSubscriptionResponseErr publishes a SubscriptionResponse to `reply`;
// a nil err produces a success (empty) response.
func (s *StanServer) sendSubscriptionResponseErr(reply string, err error) {
	resp := &pb.SubscriptionResponse{}
	if err != nil {
		resp.Error = err.Error()
	}
	b, _ := resp.Marshal()
	s.ncs.Publish(reply, b)
}

// Clear the ackTimer.
// sub Lock held in entry.
func (sub *subState) clearAckTimer() {
	if sub.ackTimer != nil {
		sub.ackTimer.Stop()
		sub.ackTimer = nil
	}
}

// adjustAckTimer adjusts the timer based on a given next
// expiration time.
// The timer will be stopped if there is no more pending ack.
// If there are pending acks, the timer will be reset to the
// default sub.ackWait value if the given expiration time is
// 0 or in the past. Otherwise, it is set to the remaining time
// between the given expiration time and now.
func (sub *subState) adjustAckTimer(nextExpirationTime int64) {
	sub.Lock()
	defer sub.Unlock()

	// Possible that the subscriber has been destroyed, and timer cleared
	if sub.ackTimer == nil {
		return
	}

	// Check if there are still pending acks
	if len(sub.acksPending) > 0 {
		// Capture time
		now := time.Now().UnixNano()

		// If the next expiration time is 0 or less than now,
		// use the default ackWait
		if nextExpirationTime <= now {
			sub.ackTimer.Reset(sub.ackWait)
		} else {
			// Compute the time the ackTimer should fire, based
			// on the given next expiration time and now.
			fireIn := (nextExpirationTime - now)
			sub.ackTimer.Reset(time.Duration(fireIn))
		}
	} else {
		// No more pending acks, clear the timer.
		sub.clearAckTimer()
	}
}

// Subscribes to the AckInbox subject in order to process subscription's acks
// if not already done.
// This function grabs and releases the sub's lock.
func (sub *subState) startAckSub(nc *nats.Conn, cb nats.MsgHandler) error {
	ackSub, err := nc.Subscribe(sub.AckInbox, cb)
	if err != nil {
		return err
	}
	sub.Lock()
	// Should not occur, but if it was already set,
	// unsubscribe old and replace.
	sub.stopAckSub()
	sub.ackSub = ackSub
	// Unlimited pending limits: never drop acks in the client library.
	sub.ackSub.SetPendingLimits(-1, -1)
	sub.Unlock()
	return nil
}

// Stops subscribing to AckInbox.
// Lock assumed held on entry.
func (sub *subState) stopAckSub() {
	if sub.ackSub != nil {
		sub.ackSub.Unsubscribe()
		sub.ackSub = nil
	}
}

// Used to generate durable key. This should not be called on non-durables.
+func (sub *subState) durableKey() string { + if sub.DurableName == "" { + return "" + } + return fmt.Sprintf("%s-%s-%s", sub.ClientID, sub.subject, sub.DurableName) +} + +// Returns true if this sub is a queue subscriber (durable or not) +func (sub *subState) isQueueSubscriber() bool { + return sub.QGroup != "" +} + +// Returns true if this sub is a durable queue subscriber +func (sub *subState) isQueueDurableSubscriber() bool { + return sub.QGroup != "" && sub.IsDurable +} + +// Returns true if this is a "shadow" durable queue subscriber +func (sub *subState) isShadowQueueDurable() bool { + return sub.IsDurable && sub.QGroup != "" && sub.ClientID == "" +} + +// Returns true if this sub is a durable subscriber (not a durable queue sub) +func (sub *subState) isDurableSubscriber() bool { + return sub.DurableName != "" +} + +// Returns true if this is an offline durable subscriber. +func (sub *subState) isOfflineDurableSubscriber() bool { + return sub.DurableName != "" && sub.ClientID == "" +} + +// Used to generate durable key. This should not be called on non-durables. +func durableKey(sr *pb.SubscriptionRequest) string { + if sr.DurableName == "" { + return "" + } + return fmt.Sprintf("%s-%s-%s", sr.ClientID, sr.Subject, sr.DurableName) +} + +// replicateSub replicates the SubscriptionRequest to nodes in the cluster via +// Raft. +func (s *StanServer) replicateSub(sr *pb.SubscriptionRequest, ackInbox string) (*subState, error) { + op := &spb.RaftOperation{ + OpType: spb.RaftOperation_Subscribe, + Sub: &spb.AddSubscription{ + Request: sr, + AckInbox: ackInbox, + }, + } + data, err := op.Marshal() + if err != nil { + panic(err) + } + // Replicate operation and wait on result. + future := s.raft.Apply(data, 0) + if err := future.Error(); err != nil { + return nil, err + } + rs := future.Response().(*replicatedSub) + return rs.sub, rs.err +} + +// addSubscription adds `sub` to the client and store. 
+func (s *StanServer) addSubscription(ss *subStore, sub *subState) error { + // Store in client + if !s.clients.addSub(sub.ClientID, sub) { + return fmt.Errorf("can't find clientID: %v", sub.ClientID) + } + // Store this subscription in subStore + if err := ss.Store(sub); err != nil { + s.clients.removeSub(sub.ClientID, sub) + return err + } + return nil +} + +// updateDurable adds back `sub` to the client and updates the store. +// No lock is needed for `sub` since it has just been created. +func (s *StanServer) updateDurable(ss *subStore, sub *subState) error { + // Reset the hasFailedHB boolean since it may have been set + // if the client previously crashed and server set this + // flag to its subs. + sub.hasFailedHB = false + // Store in the client + if !s.clients.addSub(sub.ClientID, sub) { + return fmt.Errorf("can't find clientID: %v", sub.ClientID) + } + // Update this subscription in the store + if err := sub.store.UpdateSub(&sub.SubState); err != nil { + return err + } + ss.Lock() + // Do this only for durable subscribers (not durable queue subscribers). + if sub.isDurableSubscriber() { + // Add back into plain subscribers + ss.psubs = append(ss.psubs, sub) + } + // And in ackInbox lookup map. + ss.acks[sub.AckInbox] = sub + ss.Unlock() + + return nil +} + +// processSub adds the subscription to the server. +func (s *StanServer) processSub(c *channel, sr *pb.SubscriptionRequest, ackInbox string) (*subState, error) { + // If channel not provided, we have to look it up + var err error + if c == nil { + c, err = s.lookupOrCreateChannel(sr.Subject) + if err != nil { + s.log.Errorf("Unable to create channel for subscription on %q", sr.Subject) + return nil, err + } + } + var ( + sub *subState + ss = c.ss + ) + // Will be true for durable queue subscribers and durable subscribers alike. + isDurable := false + // Will be set to false for en existing durable subscriber or existing + // queue group (durable or not). 
+ setStartPos := true + // Check for durable queue subscribers + if sr.QGroup != "" { + if sr.DurableName != "" { + // For queue subscribers, we prevent DurableName to contain + // the ':' character, since we use it for the compound name. + if strings.Contains(sr.DurableName, ":") { + s.log.Errorf("[Client:%s] Invalid DurableName (%q) for queue subscriber from %s", + sr.ClientID, sr.DurableName, sr.Subject) + return nil, ErrInvalidDurName + } + isDurable = true + // Make the queue group a compound name between durable name and q group. + sr.QGroup = fmt.Sprintf("%s:%s", sr.DurableName, sr.QGroup) + // Clear DurableName from this subscriber. + sr.DurableName = "" + } + // Lookup for an existing group. Only interested in situation where + // the group exist, but is empty and had a shadow subscriber. + ss.RLock() + qs := ss.qsubs[sr.QGroup] + if qs != nil { + qs.Lock() + if qs.shadow != nil { + sub = qs.shadow + qs.shadow = nil + qs.subs = append(qs.subs, sub) + } + qs.Unlock() + setStartPos = false + } + ss.RUnlock() + } else if sr.DurableName != "" { + // Check for DurableSubscriber status + if sub = ss.LookupByDurable(durableKey(sr)); sub != nil { + sub.RLock() + clientID := sub.ClientID + sub.RUnlock() + if clientID != "" { + s.log.Errorf("[Client:%s] Duplicate durable subscription registration", sr.ClientID) + return nil, ErrDupDurable + } + setStartPos = false + } + isDurable = true + } + var ( + subStartTrace string + subIsNew bool + ) + if sub != nil { + // ok we have a remembered subscription + sub.Lock() + // Set ClientID and new AckInbox but leave LastSent to the + // remembered value. 
+ sub.AckInbox = ackInbox + sub.ClientID = sr.ClientID + sub.Inbox = sr.Inbox + sub.IsDurable = true + // Use some of the new options, but ignore the ones regarding start position + sub.MaxInFlight = sr.MaxInFlight + sub.AckWaitInSecs = sr.AckWaitInSecs + sub.ackWait = computeAckWait(sr.AckWaitInSecs) + sub.stalled = false + if len(sub.acksPending) > 0 { + // We have a durable with pending messages, set newOnHold + // until we have performed the initial redelivery. + sub.newOnHold = true + if !s.isClustered || s.isLeader() { + s.setupAckTimer(sub, sub.ackWait) + } + } + // Clear the IsClosed flags that were set during a Close() + sub.IsClosed = false + sub.Unlock() + + // Case of restarted durable subscriber, or first durable queue + // subscriber re-joining a group that was left with pending messages. + err = s.updateDurable(ss, sub) + } else { + subIsNew = true + // Create sub here (can be plain, durable or queue subscriber) + sub = &subState{ + SubState: spb.SubState{ + ClientID: sr.ClientID, + QGroup: sr.QGroup, + Inbox: sr.Inbox, + AckInbox: ackInbox, + MaxInFlight: sr.MaxInFlight, + AckWaitInSecs: sr.AckWaitInSecs, + DurableName: sr.DurableName, + IsDurable: isDurable, + }, + subject: sr.Subject, + ackWait: computeAckWait(sr.AckWaitInSecs), + acksPending: make(map[uint64]int64), + store: c.store.Subs, + } + + if setStartPos { + // set the start sequence of the subscriber. + var lastSent uint64 + subStartTrace, lastSent, err = s.setSubStartSequence(c, sr) + if err == nil { + sub.LastSent = lastSent + } + } + + if err == nil { + // add the subscription to stan + err = s.addSubscription(ss, sub) + } + } + if err == nil && (!s.isClustered || s.isLeader()) { + err = sub.startAckSub(s.nca, s.processAckMsg) + if err == nil { + // Need tp make sure that this subscription is processed by + // NATS Server before sending response (since we use different + // connection to send the response) + s.nca.Flush() + } + } + if err != nil { + // Try to undo what has been done. 
+ s.closeMu.Lock() + ss.Remove(c, sub, false) + s.closeMu.Unlock() + s.log.Errorf("Unable to add subscription for %s: %v", sr.Subject, err) + return nil, err + } + if s.debug { + traceCtx := subStateTraceCtx{clientID: sr.ClientID, isNew: subIsNew, startTrace: subStartTrace} + traceSubState(s.log, sub, &traceCtx) + } + + s.monMu.Lock() + s.numSubs++ + s.monMu.Unlock() + + return sub, nil +} + +// processSubscriptionRequest will process a subscription request. +func (s *StanServer) processSubscriptionRequest(m *nats.Msg) { + sr := &pb.SubscriptionRequest{} + err := sr.Unmarshal(m.Data) + if err != nil { + s.log.Errorf("Invalid Subscription request from %s: %v", m.Subject, err) + s.sendSubscriptionResponseErr(m.Reply, ErrInvalidSubReq) + return + } + + // ClientID must not be empty. + if sr.ClientID == "" { + s.log.Errorf("Missing ClientID in subscription request from %s", m.Subject) + s.sendSubscriptionResponseErr(m.Reply, ErrMissingClient) + return + } + + // AckWait must be >= 1s (except in test mode where negative value means that + // duration should be interpreted as Milliseconds) + if !testAckWaitIsInMillisecond && sr.AckWaitInSecs <= 0 { + s.log.Errorf("[Client:%s] Invalid AckWait (%v) in subscription request from %s", + sr.ClientID, sr.AckWaitInSecs, m.Subject) + s.sendSubscriptionResponseErr(m.Reply, ErrInvalidAckWait) + return + } + + // MaxInflight must be >= 1 + if sr.MaxInFlight <= 0 { + s.log.Errorf("[Client:%s] Invalid MaxInflight (%v) in subscription request from %s", + sr.ClientID, sr.MaxInFlight, m.Subject) + s.sendSubscriptionResponseErr(m.Reply, ErrInvalidMaxInflight) + return + } + + // StartPosition between StartPosition_NewOnly and StartPosition_First + if sr.StartPosition < pb.StartPosition_NewOnly || sr.StartPosition > pb.StartPosition_First { + s.log.Errorf("[Client:%s] Invalid StartPosition (%v) in subscription request from %s", + sr.ClientID, int(sr.StartPosition), m.Subject) + s.sendSubscriptionResponseErr(m.Reply, ErrInvalidStart) + 
return + } + + // Make sure subject is valid + if !util.IsChannelNameValid(sr.Subject, false) { + s.log.Errorf("[Client:%s] Invalid channel %q in subscription request from %s", + sr.ClientID, sr.Subject, m.Subject) + s.sendSubscriptionResponseErr(m.Reply, ErrInvalidSubject) + return + } + + // In partitioning mode, do not fail the subscription request + // if this server does not have the channel. It could be that there + // is another server out there that will accept the subscription. + // If not, the client will get a subscription request timeout. + if s.partitions != nil { + if r := s.partitions.sl.Match(sr.Subject); len(r) == 0 { + return + } + // Also check that the connection request has already + // been processed. Check clientCheckTimeout doc for details. + if !s.clients.isValidWithTimeout(sr.ClientID, nil, clientCheckTimeout) { + s.log.Errorf("[Client:%s] Rejecting subscription on %q: connection not created yet", + sr.ClientID, sr.Subject) + s.sendSubscriptionResponseErr(m.Reply, ErrInvalidSubReq) + return + } + } + + var ( + sub *subState + ackInbox = nats.NewInbox() + ) + + c, err := s.lookupOrCreateChannel(sr.Subject) + if err == nil { + // Keep the channel delete mutex to ensure that channel cannot be + // deleted while we are about to add a subscription. + s.channels.lockDelete() + defer s.channels.unlockDelete() + + // If clustered, thread operations through Raft. + if s.isClustered { + // For start requests other than SequenceStart, we MUST convert the request + // to a SequenceStart, otherwise, during the replay on server restart, the + // subscription would be created with whatever is the seq at that time. 
+ // For instance, a request with new-only could originally be created with + // the current max seq of 100, but when the cluster is restarted and sub + // request is replayed, the channel's current max may be 200, which + // would cause the subscription to be created at start 200, which could cause + // subscription to miss all messages in between. + if sr.StartPosition != pb.StartPosition_SequenceStart { + // Figure out what the sequence should be based on orinal StartPosition + // request. + var seq uint64 + _, seq, err = s.setSubStartSequence(c, sr) + if err == nil { + // Convert to a SequenceStart start position with the proper sequence + // number. + sr.StartPosition = pb.StartPosition_SequenceStart + sr.StartSequence = seq + } + } + sub, err = s.replicateSub(sr, ackInbox) + } else { + sub, err = s.processSub(c, sr, ackInbox) + } + } + if err != nil { + s.sendSubscriptionResponseErr(m.Reply, err) + return + } + // If the channel has a MaxInactivity limit, stop the timer since we know that there + // is at least one active subscription. + if c.activity != nil { + // We are under the channelStore delete mutex + c.stopDeleteTimer() + } + + // In case this is a durable, sub already exists so we need to protect access + sub.Lock() + + // Create a non-error response + resp := &pb.SubscriptionResponse{AckInbox: sub.AckInbox} + b, _ := resp.Marshal() + s.ncs.Publish(m.Reply, b) + + // Capture under lock here + qs := sub.qstate + // Now that we have sent the response, we set the subscription to initialized, + // which allows messages to be sent to it - but not sooner (which could happen + // without this since the subscription is added to the system earlier and + // incoming messages to the channel would trigger delivery). 
+ sub.initialized = true + sub.Unlock() + + s.subStartCh <- &subStartInfo{c: c, sub: sub, qs: qs, isDurable: sub.IsDurable} +} + +type subStateTraceCtx struct { + clientID string + isRemove bool + isNew bool + isUnsubscribe bool + isGroupEmpty bool + startTrace string +} + +func traceSubState(log logger.Logger, sub *subState, ctx *subStateTraceCtx) { + sub.RLock() + defer sub.RUnlock() + var ( + action string + specific string + durable string + sending string + queue string + prefix string + ) + if sub.IsDurable { + durable = "durable " + } + if sub.QGroup != "" { + queue = "queue " + } + if ctx.isRemove { + if (ctx.isUnsubscribe || !sub.IsDurable) && (sub.QGroup == "" || ctx.isGroupEmpty) { + prefix = "Removed" + } else if sub.QGroup != "" && !ctx.isGroupEmpty { + prefix = "Removed member from" + } else { + prefix = "Suspended" + } + } else { + if ctx.startTrace != "" { + prefix = "Started new" + } else if sub.QGroup != "" && ctx.isNew { + prefix = "Added member to" + } else if sub.IsDurable { + prefix = "Resumed" + } + } + action = fmt.Sprintf("%s %s%s", prefix, durable, queue) + if sub.QGroup != "" { + specific = fmt.Sprintf(" queue=%s,", sub.QGroup) + } else if sub.IsDurable { + specific = fmt.Sprintf(" durable=%s,", sub.DurableName) + } + if !ctx.isRemove && ctx.startTrace != "" { + sending = ", sending " + ctx.startTrace + } + log.Debugf("[Client:%s] %ssubscription, subject=%s, inbox=%s,%s subid=%d%s", + ctx.clientID, action, sub.subject, sub.Inbox, specific, sub.ID, sending) +} + +func (s *StanServer) processSubscriptionsStart() { + defer s.wg.Done() + for { + select { + case subStart := <-s.subStartCh: + c := subStart.c + sub := subStart.sub + qs := subStart.qs + isDurable := subStart.isDurable + if isDurable { + // Redeliver any outstanding. 
+ s.performDurableRedelivery(c, sub) + } + // publish messages to this subscriber + if qs != nil { + s.sendAvailableMessagesToQueue(c, qs) + } else { + s.sendAvailableMessages(c, sub) + } + case <-s.subStartQuit: + return + } + } +} + +// processAckMsg processes inbound acks from clients for delivered messages. +func (s *StanServer) processAckMsg(m *nats.Msg) { + ack := &pb.Ack{} + if ack.Unmarshal(m.Data) != nil { + if s.processCtrlMsg(m) { + return + } + } + c := s.channels.get(ack.Subject) + if c == nil { + s.log.Errorf("Unable to process ack seq=%d, channel %s not found", ack.Sequence, ack.Subject) + return + } + sub := c.ss.LookupByAckInbox(m.Subject) + if sub == nil { + return + } + s.processAck(c, sub, ack.Sequence) +} + +// processAck processes an ack and if needed sends more messages. +func (s *StanServer) processAck(c *channel, sub *subState, sequence uint64) { + var stalled bool + + // This is immutable, so can grab outside of sub's lock. + // If we have a queue group, we want to grab queue's lock before + // sub's lock. + qs := sub.qstate + if qs != nil { + qs.Lock() + } + + sub.Lock() + + // If in cluster mode, replicate the ack but leader + // does not wait on quorum result. + if s.isClustered { + s.replicateSentOrAck(sub, replicateAck, sequence) + } + + if s.trace { + s.log.Tracef("[Client:%s] Processing ack for subid=%d, subject=%s, seq=%d", + sub.ClientID, sub.ID, sub.subject, sequence) + } + + if err := sub.store.AckSeqPending(sub.ID, sequence); err != nil { + s.log.Errorf("[Client:%s] Unable to persist ack for subid=%d, subject=%s, seq=%d, err=%v", + sub.ClientID, sub.ID, sub.subject, sequence, err) + sub.Unlock() + if qs != nil { + qs.Unlock() + } + return + } + + delete(sub.acksPending, sequence) + if sub.stalled && int32(len(sub.acksPending)) < sub.MaxInFlight { + // For queue, we must not check the queue stalled count here. 
The queue + // as a whole may not be stalled, yet, if this sub was stalled, it is + // not now since the pending acks is below MaxInflight. The server should + // try to send available messages. + // It works also if the queue *was* stalled (all members were stalled), + // then this member is no longer stalled, which release the queue. + + // Trigger send of available messages by setting this to true. + stalled = true + + // Clear the stalled flag from this sub + sub.stalled = false + // .. and update the queue's stalled members count if this is a queue sub. + if qs != nil && qs.stalledSubCount > 0 { + qs.stalledSubCount-- + } + } + sub.Unlock() + if qs != nil { + qs.Unlock() + } + + // Leave the reset/cancel of the ackTimer to the redelivery cb. + + if !stalled { + return + } + + if sub.qstate != nil { + s.sendAvailableMessagesToQueue(c, sub.qstate) + } else { + s.sendAvailableMessages(c, sub) + } +} + +// Send any messages that are ready to be sent that have been queued to the group. +func (s *StanServer) sendAvailableMessagesToQueue(c *channel, qs *queueState) { + if c == nil || qs == nil { + return + } + + qs.Lock() + // Short circuit if no active members + if len(qs.subs) == 0 { + qs.Unlock() + return + } + // If redelivery at startup in progress, don't attempt to deliver new messages + if qs.newOnHold { + qs.Unlock() + return + } + for nextSeq := qs.lastSent + 1; qs.stalledSubCount < len(qs.subs); nextSeq++ { + nextMsg := s.getNextMsg(c, &nextSeq, &qs.lastSent) + if nextMsg == nil { + break + } + if _, sent, sendMore := s.sendMsgToQueueGroup(qs, nextMsg, honorMaxInFlight); !sent || !sendMore { + break + } + } + qs.Unlock() +} + +// Send any messages that are ready to be sent that have been queued. 
+func (s *StanServer) sendAvailableMessages(c *channel, sub *subState) { + sub.Lock() + for nextSeq := sub.LastSent + 1; !sub.stalled; nextSeq++ { + nextMsg := s.getNextMsg(c, &nextSeq, &sub.LastSent) + if nextMsg == nil { + break + } + if sent, sendMore := s.sendMsgToSub(sub, nextMsg, honorMaxInFlight); !sent || !sendMore { + break + } + } + sub.Unlock() +} + +func (s *StanServer) getNextMsg(c *channel, nextSeq, lastSent *uint64) *pb.MsgProto { + for { + nextMsg, err := c.store.Msgs.Lookup(*nextSeq) + if err != nil { + s.log.Errorf("Error looking up message %v:%v (%v)", c.name, *nextSeq, err) + // TODO: This will stop delivery. Will revisit later to see if we + // should move to the next message (if avail) or not. + return nil + } + if nextMsg != nil { + return nextMsg + } + first, last, _ := c.store.Msgs.FirstAndLastSequence() + if *nextSeq < first { + *nextSeq = first + *lastSent = first - 1 + } else if *nextSeq >= last { + return nil + } else { + *nextSeq++ + *lastSent++ + } + + // Note that the next lookup could still fail because + // the first avail message may have been dropped in the + // meantime. + } +} + +// Setup the start position for the subscriber. +func (s *StanServer) setSubStartSequence(c *channel, sr *pb.SubscriptionRequest) (string, uint64, error) { + lastSent := uint64(0) + debugTrace := "" + + // In all start position cases, if there is no message, ensure + // lastSent stays at 0. 
+ + switch sr.StartPosition { + case pb.StartPosition_NewOnly: + var err error + lastSent, err = c.store.Msgs.LastSequence() + if err != nil { + return "", 0, err + } + if s.debug { + debugTrace = fmt.Sprintf("new-only, seq=%d", lastSent+1) + } + case pb.StartPosition_LastReceived: + lastSeq, err := c.store.Msgs.LastSequence() + if err != nil { + return "", 0, err + } + if lastSeq > 0 { + lastSent = lastSeq - 1 + } + if s.debug { + debugTrace = fmt.Sprintf("last message, seq=%d", lastSent+1) + } + case pb.StartPosition_TimeDeltaStart: + startTime := time.Now().UnixNano() - sr.StartTimeDelta + // If there is no message, seq will be 0. + seq, err := c.store.Msgs.GetSequenceFromTimestamp(startTime) + if err != nil { + return "", 0, err + } + if seq > 0 { + // If the time delta is in the future relative to the last + // message in the log, 'seq' will be equal to last sequence + 1, + // so this would translate to "new only" semantic. + lastSent = seq - 1 + } + if s.debug { + debugTrace = fmt.Sprintf("from time time='%v' seq=%d", time.Unix(0, startTime), lastSent+1) + } + case pb.StartPosition_SequenceStart: + // If there is no message, firstSeq and lastSeq will be equal to 0. + firstSeq, lastSeq, err := c.store.Msgs.FirstAndLastSequence() + if err != nil { + return "", 0, err + } + // StartSequence is an uint64, so can't be lower than 0. + if sr.StartSequence < firstSeq { + // That translates to sending the first message available. + lastSent = firstSeq - 1 + } else if sr.StartSequence > lastSeq { + // That translates to "new only" + lastSent = lastSeq + } else if sr.StartSequence > 0 { + // That translates to sending the message with StartSequence + // sequence number. 
+ lastSent = sr.StartSequence - 1 + } + if s.debug { + debugTrace = fmt.Sprintf("from sequence, asked_seq=%d actual_seq=%d", sr.StartSequence, lastSent+1) + } + case pb.StartPosition_First: + firstSeq, err := c.store.Msgs.FirstSequence() + if err != nil { + return "", 0, err + } + if firstSeq > 0 { + lastSent = firstSeq - 1 + } + if s.debug { + debugTrace = fmt.Sprintf("from beginning, seq=%d", lastSent+1) + } + } + return debugTrace, lastSent, nil +} + +// startGoRoutine starts the given function as a go routine if and only if +// the server was not shutdown at that time. This is required because +// we cannot increment the wait group after the shutdown process has started. +func (s *StanServer) startGoRoutine(f func()) { + s.mu.Lock() + if !s.shutdown { + s.wg.Add(1) + go f() + } + s.mu.Unlock() +} + +// ClusterID returns the NATS Streaming Server's ID. +func (s *StanServer) ClusterID() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.info.ClusterID +} + +// State returns the state of this server. +func (s *StanServer) State() State { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state +} + +// setLastError sets the last fatal error that occurred. This is +// used in case of an async error that cannot directly be reported +// to the user. +func (s *StanServer) setLastError(err error) { + s.mu.Lock() + s.lastError = err + s.state = Failed + s.mu.Unlock() + s.log.Fatalf("%v", err) +} + +// LastError returns the last fatal error the server experienced. +func (s *StanServer) LastError() error { + s.mu.RLock() + defer s.mu.RUnlock() + return s.lastError +} + +// Shutdown will close our NATS connection and shutdown any embedded NATS server. 
+func (s *StanServer) Shutdown() { + s.log.Noticef("Shutting down.") + + s.mu.Lock() + if s.shutdown { + s.mu.Unlock() + return + } + + close(s.shutdownCh) + + // Allows Shutdown() to be idempotent + s.shutdown = true + // Change the state too + s.state = Shutdown + + // We need to make sure that the storeIOLoop returns before + // closing the Store + waitForIOStoreLoop := true + + // Capture under lock + store := s.store + ns := s.natsServer + // Do not close and nil the connections here, they are used in many places + // without locking. Once closed, s.nc.xxx() calls will simply fail, but + // we won't panic. + ncs := s.ncs + ncr := s.ncr + ncsr := s.ncsr + nc := s.nc + ftnc := s.ftnc + nca := s.nca + + // Stop processing subscriptions start requests + s.subStartQuit <- struct{}{} + + if s.ioChannel != nil { + // Notify the IO channel that we are shutting down + close(s.ioChannelQuit) + } else { + waitForIOStoreLoop = false + } + // In case we are running in FT mode. + if s.ftQuit != nil { + s.ftQuit <- struct{}{} + } + // In case we are running in Partitioning mode + if s.partitions != nil { + s.partitions.shutdown() + } + s.mu.Unlock() + + // Make sure the StoreIOLoop returns before closing the Store + if waitForIOStoreLoop { + s.ioChannelWG.Wait() + } + + // Close Raft group before closing store. + if s.raft != nil { + if err := s.raft.shutdown(); err != nil { + s.log.Errorf("Failed to stop Raft node: %v", err) + } + } + + // Close/Shutdown resources. Note that unless one instantiates StanServer + // directly (instead of calling RunServer() and the like), these should + // not be nil. 
+ if store != nil { + store.Close() + } + if ncs != nil { + ncs.Close() + } + if ncr != nil { + ncr.Close() + } + if ncsr != nil { + ncsr.Close() + } + if nc != nil { + nc.Close() + } + if ftnc != nil { + ftnc.Close() + } + if nca != nil { + nca.Close() + } + if ns != nil { + ns.Shutdown() + } + + // Wait for go-routines to return + s.wg.Wait() +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/service.go b/vendor/github.com/nats-io/nats-streaming-server/server/service.go new file mode 100644 index 00000000000..ba7eaeb6d80 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/service.go @@ -0,0 +1,31 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +import ( + natsd "github.com/nats-io/gnatsd/server" +) + +// Run starts the NATS Streaming server. This wrapper function allows Windows to add a +// hook for running NATS Streaming as a service. +func Run(sOpts *Options, nOpts *natsd.Options) (*StanServer, error) { + return RunServerWithOpts(sOpts, nOpts) +} + +// isWindowsService indicates if NATS Streaming is running as a Windows service. 
+func isWindowsService() bool { + return false +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/service_windows.go b/vendor/github.com/nats-io/nats-streaming-server/server/service_windows.go new file mode 100644 index 00000000000..6cd1d5f6983 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/service_windows.go @@ -0,0 +1,185 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "os" + "strings" + "sync" + + natsdLogger "github.com/nats-io/gnatsd/logger" + natsd "github.com/nats-io/gnatsd/server" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" + "golang.org/x/sys/windows/svc/eventlog" +) + +const ( + serviceName = "nats-streaming-server" + reopenLogCode = 128 + reopenLogCmd = svc.Cmd(reopenLogCode) + acceptReopenLog = svc.Accepted(reopenLogCode) +) + +// winServiceWrapper implements the svc.Handler interface for implementing +// nats-streaming-server as a Windows service. +type winServiceWrapper struct { + sOpts *Options + nOpts *natsd.Options + srvCh chan *StanServer + errCh chan error +} + +var ( + dockerized = false + sysLogInitLock sync.Mutex + sysLog *eventlog.Log + sysLogName = "NATS-Streaming-Server" +) + +func init() { + if v, exists := os.LookupEnv("NATS_DOCKERIZED"); exists && v == "1" { + dockerized = true + } + // Set the default event source name. 
This may be changed when the + // server will configure the logger if a SyslogName option is specified. + natsdLogger.SetSyslogName(sysLogName) + // This is so the gnatsd's signal code works for streaming service + natsd.SetServiceName(serviceName) +} + +// Execute will be called by the package code at the start of +// the service, and the service will exit once Execute completes. +// Inside Execute you must read service change requests from r and +// act accordingly. You must keep service control manager up to date +// about state of your service by writing into s as required. +// args contains service name followed by argument strings passed +// to the service. +// You can provide service exit code in exitCode return parameter, +// with 0 being "no error". You can also indicate if exit code, +// if any, is service specific or not by using svcSpecificEC +// parameter. +func (w *winServiceWrapper) Execute(args []string, changes <-chan svc.ChangeRequest, status chan<- svc.Status) (bool, uint32) { + + status <- svc.Status{State: svc.StartPending} + + if sysLog != nil { + sysLog.Info(1, "Starting NATS Streaming Server...") + } + // Override NoSigs since we are doing signal handling HERE + w.sOpts.HandleSignals = false + server, err := RunServerWithOpts(w.sOpts, w.nOpts) + if err != nil && sysLog != nil { + sysLog.Error(2, fmt.Sprintf("Starting server returned: %v", err)) + } + if err != nil { + w.errCh <- err + // Failed to start. + return true, 1 + } + status <- svc.Status{ + State: svc.Running, + Accepts: svc.AcceptStop | svc.AcceptShutdown | svc.AcceptParamChange | acceptReopenLog, + } + w.srvCh <- server + +loop: + for change := range changes { + switch change.Cmd { + case svc.Interrogate: + status <- change.CurrentStatus + case svc.Stop, svc.Shutdown: + status <- svc.Status{State: svc.StopPending} + server.Shutdown() + break loop + case reopenLogCmd: + // File log re-open for rotating file logs. 
+ server.log.ReopenLogFile() + case svc.ParamChange: + // Ignore for now + default: + server.log.Debugf("Unexpected control request: %v", change.Cmd) + } + } + + status <- svc.Status{State: svc.Stopped} + return false, 0 +} + +// Run starts the NATS Streaming server. This wrapper function allows Windows to add a +// hook for running NATS Streaming as a service. +func Run(sOpts *Options, nOpts *natsd.Options) (*StanServer, error) { + if dockerized { + return RunServerWithOpts(sOpts, nOpts) + } + run := svc.Run + isInteractive, err := svc.IsAnInteractiveSession() + if err != nil { + return nil, err + } + if isInteractive { + run = debug.Run + } else { + sysLogInitLock.Lock() + // We create a syslog here because we want to capture possible startup + // failure message. + if sysLog == nil { + if sOpts.SyslogName != "" { + sysLogName = sOpts.SyslogName + } + err := eventlog.InstallAsEventCreate(sysLogName, eventlog.Info|eventlog.Error|eventlog.Warning) + if err != nil { + if !strings.Contains(err.Error(), "registry key already exists") { + panic(err) + } + } + sysLog, err = eventlog.Open(sysLogName) + if err != nil { + panic(fmt.Sprintf("could not open event log: %v", err)) + } + } + sysLogInitLock.Unlock() + } + wrapper := &winServiceWrapper{ + srvCh: make(chan *StanServer, 1), + errCh: make(chan error, 1), + sOpts: sOpts, + nOpts: nOpts, + } + go func() { + // If no error, we exit here, otherwise, we are getting the + // error down below. + if err := run(serviceName, wrapper); err == nil { + os.Exit(0) + } + }() + + var srv *StanServer + // Wait for server instance to be created + select { + case err = <-wrapper.errCh: + case srv = <-wrapper.srvCh: + } + return srv, err +} + +// isWindowsService indicates if NATS is running as a Windows service. 
+func isWindowsService() bool { + if dockerized { + return false + } + isInteractive, _ := svc.IsAnInteractiveSession() + return !isInteractive +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/signal.go b/vendor/github.com/nats-io/nats-streaming-server/server/signal.go new file mode 100644 index 00000000000..c5868f58842 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/signal.go @@ -0,0 +1,53 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +import ( + "os" + "os/signal" + "syscall" + + natsd "github.com/nats-io/gnatsd/server" +) + +func init() { + // Set the process name so signal code use this process name + // instead of gnatsd. + natsd.SetProcessName("nats-streaming-server") +} + +// Signal Handling +func (s *StanServer) handleSignals() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGUSR1, syscall.SIGHUP) + go func() { + for sig := range c { + // Notify will relay only the signals that we have + // registered, so we don't need a "default" in the + // switch statement. + switch sig { + case syscall.SIGINT, syscall.SIGTERM: + s.Shutdown() + os.Exit(0) + case syscall.SIGUSR1: + // File log re-open for rotating file logs. 
+ s.log.ReopenLogFile() + case syscall.SIGHUP: + // Ignore for now + } + } + }() +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/signal_windows.go b/vendor/github.com/nats-io/nats-streaming-server/server/signal_windows.go new file mode 100644 index 00000000000..42768969633 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/signal_windows.go @@ -0,0 +1,33 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "os" + "os/signal" +) + +// Signal Handling +func (s *StanServer) handleSignals() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + // We register only 1 signal (os.Interrupt) so we don't + // need to check which one we get, since Notify() relays + // only the ones that are registered. + <-c + s.Shutdown() + os.Exit(0) + }() +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/snapshot.go b/vendor/github.com/nats-io/nats-streaming-server/server/snapshot.go new file mode 100644 index 00000000000..1fd0fda9f45 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/snapshot.go @@ -0,0 +1,421 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "io" + "time" + + "github.com/hashicorp/raft" + "github.com/nats-io/go-nats" + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/util" +) + +// serverSnapshot implements the raft.FSMSnapshot interface by snapshotting +// StanServer state. +type serverSnapshot struct { + *StanServer +} + +// Snapshot is used to support log compaction. This call should +// return an FSMSnapshot which can be used to save a point-in-time +// snapshot of the FSM. Apply and Snapshot are not called in multiple +// threads, but Apply will be called concurrently with Persist. This means +// the FSM should be implemented in a fashion that allows for concurrent +// updates while a snapshot is happening. +func (r *raftFSM) Snapshot() (raft.FSMSnapshot, error) { + return &serverSnapshot{r.server}, nil +} + +// Persist should dump all necessary state to the WriteCloser 'sink', +// and call sink.Close() when finished or call sink.Cancel() on error. 
+func (s *serverSnapshot) Persist(sink raft.SnapshotSink) (err error) { + defer func() { + if err != nil { + sink.Cancel() + } + }() + + snap := &spb.RaftSnapshot{} + + s.snapshotClients(snap, sink) + + if err := s.snapshotChannels(snap); err != nil { + return err + } + + var b []byte + for i := 0; i < 2; i++ { + b, err = snap.Marshal() + if err != nil { + return err + } + // Raft assumes that the follower will restore the snapshot from the leader using + // a timeout that is equal to: + // the transport timeout (we provide 2*time.Second) * (snapshot size / TimeoutScale). + // We can't provide an infinite timeout to the nats-log transport otherwise some + // Raft operations will block forever. + // However, we persist only the first/last sequence for a channel snapshot, however, + // the follower will request the leader to send all those messages as part of the + // restore. Since the snapshot size may be small compared to the amout of messages + // to recover, the timeout may be too small. + // To trick the system, we first set the transport's TimeoutScale to 1 (the default + // is 256KB). Then, if we want an overall timeout of 1 hour (3600 seconds), we need + // the size to be at least 1800 bytes. If it is less than that, then we add some + // padding to the snapshot. 
+ if len(b) < 1800 { + snap.Padding = make([]byte, 1800-len(b)) + continue + } + break + } + + var sizeBuf [4]byte + util.ByteOrder.PutUint32(sizeBuf[:], uint32(len(b))) + if _, err := sink.Write(sizeBuf[:]); err != nil { + return err + } + if _, err := sink.Write(b); err != nil { + return err + } + + return sink.Close() +} + +func (s *serverSnapshot) snapshotClients(snap *spb.RaftSnapshot, sink raft.SnapshotSink) { + s.clients.RLock() + defer s.clients.RUnlock() + + numClients := len(s.clients.clients) + if numClients == 0 { + return + } + + snap.Clients = make([]*spb.ClientInfo, numClients) + i := 0 + for _, client := range s.clients.clients { + // Make a copy + info := client.info.ClientInfo + snap.Clients[i] = &info + i++ + } +} + +func (s *serverSnapshot) snapshotChannels(snap *spb.RaftSnapshot) error { + s.channels.RLock() + defer s.channels.RUnlock() + + numChannels := len(s.channels.channels) + if numChannels == 0 { + return nil + } + + snapshotASub := func(sub *subState) *spb.SubscriptionSnapshot { + // Make a copy + state := sub.SubState + snapSub := &spb.SubscriptionSnapshot{State: &state} + if len(sub.acksPending) > 0 { + snapSub.AcksPending = make([]uint64, len(sub.acksPending)) + i := 0 + for seq := range sub.acksPending { + snapSub.AcksPending[i] = seq + i++ + } + } + return snapSub + } + + snap.Channels = make([]*spb.ChannelSnapshot, numChannels) + numChannel := 0 + for _, c := range s.channels.channels { + first, last, err := c.store.Msgs.FirstAndLastSequence() + if err != nil { + return err + } + snapChannel := &spb.ChannelSnapshot{ + Channel: c.name, + First: first, + Last: last, + } + c.ss.RLock() + + // Start with count of all plain subs... 
+ snapSubs := make([]*spb.SubscriptionSnapshot, len(c.ss.psubs)) + i := 0 + for _, sub := range c.ss.psubs { + sub.RLock() + snapSubs[i] = snapshotASub(sub) + sub.RUnlock() + i++ + } + + // Now need to close durables + for _, dur := range c.ss.durables { + dur.RLock() + if dur.IsClosed { + // We need to persist a SubState with a ClientID + // so that we can reconstruct the durable key + // on recovery. So set to the saved value here + // and then clear it after that. + dur.ClientID = dur.savedClientID + snapSubs = append(snapSubs, snapshotASub(dur)) + dur.ClientID = "" + } + dur.RUnlock() + } + + // Snapshot the queue subscriptions + for _, qsub := range c.ss.qsubs { + qsub.RLock() + for _, sub := range qsub.subs { + sub.RLock() + snapSubs = append(snapSubs, snapshotASub(sub)) + sub.RUnlock() + } + // If all members of a durable queue group left the group, + // we need to persist the "shadow" queue member. + if qsub.shadow != nil { + qsub.shadow.RLock() + snapSubs = append(snapSubs, snapshotASub(qsub.shadow)) + qsub.shadow.RUnlock() + } + qsub.RUnlock() + } + if len(snapSubs) > 0 { + snapChannel.Subscriptions = snapSubs + } + + c.ss.RUnlock() + snap.Channels[numChannel] = snapChannel + numChannel++ + } + + return nil +} + +// Release is a no-op. +func (s *serverSnapshot) Release() {} + +// Restore is used to restore an FSM from a snapshot. It is not called +// concurrently with any other command. The FSM must discard all previous +// state. +func (r *raftFSM) Restore(snapshot io.ReadCloser) (retErr error) { + defer snapshot.Close() + + r.Lock() + defer r.Unlock() + + // This function may be invoked directly from raft.NewRaft() when + // the node is initialized and if there were exisiting local snapshots, + // or later, when catching up with a leader. We behave differently + // depending on the situation. So we need to know if we are called + // from NewRaft(). + // + // To do so, we first look at the number of local snapshots before + // calling NewRaft(). 
If the number is > 0, it means that Raft will + // call us within NewRaft(). Raft will restore the latest snapshot + // first, and only in case of Restore() returning an error will move + // to the next (earliest) one. When there are none and Restore() still + // returns an error raft.NewRaft() will return an error. + // + // So on error we decrement the number of snapshots, on success we set + // it to 0. This means that next time Restore() is invoked, we know it + // is restoring from a leader, not from the local snapshots. + inNewRaftCall := r.snapshotsOnInit != 0 + if inNewRaftCall { + defer func() { + if retErr != nil { + r.snapshotsOnInit-- + } else { + r.snapshotsOnInit = 0 + } + }() + } + s := r.server + + // We need to drop current state. The server will recover from snapshot + // and all newer Raft entry logs (basically the entire state is being + // reconstructed from this point on). + for _, c := range s.channels.getAll() { + for _, sub := range c.ss.getAllSubs() { + sub.RLock() + clientID := sub.ClientID + sub.RUnlock() + if err := s.unsubscribeSub(c, clientID, "unsub", sub, false); err != nil { + return err + } + } + } + for clientID := range s.clients.getClients() { + if _, err := s.clients.unregister(clientID); err != nil { + return err + } + } + + sizeBuf := make([]byte, 4) + // Read the snapshot size. + if _, err := io.ReadFull(snapshot, sizeBuf); err != nil { + if err == io.EOF { + return nil + } + return err + } + // Read the snapshot. 
+ size := util.ByteOrder.Uint32(sizeBuf) + buf := make([]byte, size) + if _, err := io.ReadFull(snapshot, buf); err != nil { + return err + } + + serverSnap := &spb.RaftSnapshot{} + if err := serverSnap.Unmarshal(buf); err != nil { + panic(err) + } + if err := r.restoreClientsFromSnapshot(serverSnap); err != nil { + return err + } + return r.restoreChannelsFromSnapshot(serverSnap, inNewRaftCall) +} + +func (r *raftFSM) restoreClientsFromSnapshot(serverSnap *spb.RaftSnapshot) error { + s := r.server + for _, sc := range serverSnap.Clients { + if _, err := s.clients.register(sc); err != nil { + return err + } + } + return nil +} + +func (r *raftFSM) restoreChannelsFromSnapshot(serverSnap *spb.RaftSnapshot, inNewRaftCall bool) error { + s := r.server + + var channelsBeforeRestore map[string]*channel + if !inNewRaftCall { + channelsBeforeRestore = s.channels.getAll() + } + for _, sc := range serverSnap.Channels { + c, err := s.lookupOrCreateChannel(sc.Channel) + if err != nil { + return err + } + // Do not restore messages from snapshot if the server + // just started and is recovering from its own snapshot. + if !inNewRaftCall { + if err := r.restoreMsgsFromSnapshot(c, sc.First, sc.Last); err != nil { + return err + } + delete(channelsBeforeRestore, sc.Channel) + } + for _, ss := range sc.Subscriptions { + s.recoverOneSub(c, ss.State, nil, ss.AcksPending) + } + } + if !inNewRaftCall { + // Now delete channels that we had before the restore. + // This is possible if channels have been deleted while + // this node was not running and snapshot occurred. The + // channels would not be in the snapshot, so we can remove + // them now. 
+ for name := range channelsBeforeRestore { + s.processDeleteChannel(name) + } + } + return nil +} + +func (r *raftFSM) restoreMsgsFromSnapshot(c *channel, first, last uint64) error { + storeFirst, storeLast, err := c.store.Msgs.FirstAndLastSequence() + if err != nil { + return err + } + // If the leader's first sequence is more than our lastSequence+1, + // then we need to empty the store. We don't want to have gaps. + // Same if our first is strictly greater than the leader, or our + // last sequence is more than the leader + if first > storeLast+1 || storeFirst > first || storeLast > last { + if err := c.store.Msgs.Empty(); err != nil { + return err + } + } else if storeLast == last { + // We may have a message with lower sequence than the leader, + // but our last sequence is the same, so nothing to do. + return nil + } else if storeLast > 0 { + // first is less than what we already have, just started + // at our next sequence. + first = storeLast + 1 + } + inbox := nats.NewInbox() + sub, err := c.stan.ncsr.SubscribeSync(inbox) + if err != nil { + return err + } + sub.SetPendingLimits(-1, -1) + defer sub.Unsubscribe() + + subject := fmt.Sprintf("%s.%s.%s", defaultSnapshotPrefix, c.stan.info.ClusterID, c.name) + + var ( + reqBuf [16]byte + reqNext = first + reqStart = first + reqEnd uint64 + batch = uint64(100) + halfBatch = batch / 2 + ) + for seq := first; seq <= last; seq++ { + if seq == reqNext { + reqEnd = reqStart + batch + if reqEnd > last { + reqEnd = last + } + util.ByteOrder.PutUint64(reqBuf[:8], reqStart) + util.ByteOrder.PutUint64(reqBuf[8:16], reqEnd) + if err := c.stan.ncsr.PublishRequest(subject, inbox, reqBuf[:16]); err != nil { + return err + } + if reqEnd != last { + reqNext = reqStart + halfBatch + reqStart = reqEnd + 1 + } + } + resp, err := sub.NextMsg(2 * time.Second) + if err != nil { + return err + } + // It is possible that the leader does not have this message because of + // channel limits. 
If resp.Data is empty, we are in this situation and + // we are done recovering snapshot. + if len(resp.Data) == 0 { + break + } + msg := &pb.MsgProto{} + if err := msg.Unmarshal(resp.Data); err != nil { + panic(err) + } + if _, err := c.store.Msgs.Store(msg); err != nil { + return err + } + select { + case <-r.server.shutdownCh: + return fmt.Errorf("server shutting down") + default: + } + } + return c.store.Msgs.Flush() +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/server/timeout_reader.go b/vendor/github.com/nats-io/nats-streaming-server/server/timeout_reader.go new file mode 100644 index 00000000000..9efebf87a49 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/server/timeout_reader.go @@ -0,0 +1,82 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bufio" + "errors" + "io" + "runtime" + "time" +) + +const bufferSize = 4096 + +var ErrTimeout = errors.New("natslog: read timeout") + +type timeoutReader struct { + b *bufio.Reader + t time.Time + ch <-chan error + closeFunc func() error +} + +func newTimeoutReader(r io.ReadCloser) *timeoutReader { + return &timeoutReader{ + b: bufio.NewReaderSize(r, bufferSize), + closeFunc: func() error { return r.Close() }, + } +} + +// SetDeadline sets the deadline for all future Read calls. 
+func (r *timeoutReader) SetDeadline(t time.Time) { + r.t = t +} + +func (r *timeoutReader) Read(b []byte) (n int, err error) { + if r.ch == nil { + if r.t.IsZero() || r.b.Buffered() > 0 { + return r.b.Read(b) + } + ch := make(chan error, 1) + r.ch = ch + go func() { + _, err := r.b.Peek(1) + ch <- err + }() + runtime.Gosched() + } + if r.t.IsZero() { + err = <-r.ch // Block + } else { + select { + case err = <-r.ch: // Poll + default: + select { + case err = <-r.ch: // Timeout + case <-time.After(time.Until(r.t)): + return 0, ErrTimeout + } + } + } + r.ch = nil + if r.b.Buffered() > 0 { + n, _ = r.b.Read(b) + } + return +} + +func (r *timeoutReader) Close() error { + return r.closeFunc() +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/spb/protocol.pb.go b/vendor/github.com/nats-io/nats-streaming-server/spb/protocol.pb.go new file mode 100644 index 00000000000..f6998633555 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/spb/protocol.pb.go @@ -0,0 +1,4350 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: protocol.proto + +/* + Package spb is a generated protocol buffer package. + + It is generated from these files: + protocol.proto + + It has these top-level messages: + SubState + SubStateDelete + SubStateUpdate + ServerInfo + ClientInfo + ClientDelete + CtrlMsg + RaftJoinRequest + RaftJoinResponse + RaftOperation + Batch + AddSubscription + SubSentAndAck + AddClient + RaftSnapshot + ChannelSnapshot + SubscriptionSnapshot +*/ +package spb + +import proto "github.com/gogo/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "github.com/gogo/protobuf/gogoproto" +import pb "github.com/nats-io/go-nats-streaming/pb" + +import io "io" + +// Reference imports to suppress errors if they are not otherwise used. 
+var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +type CtrlMsg_Type int32 + +const ( + CtrlMsg_SubUnsubscribe CtrlMsg_Type = 0 + CtrlMsg_SubClose CtrlMsg_Type = 1 + CtrlMsg_ConnClose CtrlMsg_Type = 2 + CtrlMsg_FTHeartbeat CtrlMsg_Type = 3 + CtrlMsg_Partitioning CtrlMsg_Type = 4 +) + +var CtrlMsg_Type_name = map[int32]string{ + 0: "SubUnsubscribe", + 1: "SubClose", + 2: "ConnClose", + 3: "FTHeartbeat", + 4: "Partitioning", +} +var CtrlMsg_Type_value = map[string]int32{ + "SubUnsubscribe": 0, + "SubClose": 1, + "ConnClose": 2, + "FTHeartbeat": 3, + "Partitioning": 4, +} + +func (x CtrlMsg_Type) String() string { + return proto.EnumName(CtrlMsg_Type_name, int32(x)) +} +func (CtrlMsg_Type) EnumDescriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{6, 0} } + +type RaftOperation_Type int32 + +const ( + RaftOperation_Publish RaftOperation_Type = 0 + RaftOperation_Subscribe RaftOperation_Type = 1 + RaftOperation_RemoveSubscription RaftOperation_Type = 2 + RaftOperation_CloseSubscription RaftOperation_Type = 3 + RaftOperation_SendAndAck RaftOperation_Type = 4 + RaftOperation_Connect RaftOperation_Type = 6 + RaftOperation_Disconnect RaftOperation_Type = 7 + RaftOperation_DeleteChannel RaftOperation_Type = 8 +) + +var RaftOperation_Type_name = map[int32]string{ + 0: "Publish", + 1: "Subscribe", + 2: "RemoveSubscription", + 3: "CloseSubscription", + 4: "SendAndAck", + 6: "Connect", + 7: "Disconnect", + 8: "DeleteChannel", +} +var RaftOperation_Type_value = map[string]int32{ + "Publish": 0, + "Subscribe": 1, + "RemoveSubscription": 2, + "CloseSubscription": 3, + "SendAndAck": 4, + "Connect": 6, + "Disconnect": 7, + 
"DeleteChannel": 8, +} + +func (x RaftOperation_Type) String() string { + return proto.EnumName(RaftOperation_Type_name, int32(x)) +} +func (RaftOperation_Type) EnumDescriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{9, 0} } + +// SubState represents the state of a Subscription +type SubState struct { + ID uint64 `protobuf:"varint,1,opt,name=ID,proto3" json:"ID,omitempty"` + ClientID string `protobuf:"bytes,2,opt,name=clientID,proto3" json:"clientID,omitempty"` + QGroup string `protobuf:"bytes,3,opt,name=qGroup,proto3" json:"qGroup,omitempty"` + Inbox string `protobuf:"bytes,4,opt,name=inbox,proto3" json:"inbox,omitempty"` + AckInbox string `protobuf:"bytes,5,opt,name=ackInbox,proto3" json:"ackInbox,omitempty"` + MaxInFlight int32 `protobuf:"varint,6,opt,name=maxInFlight,proto3" json:"maxInFlight,omitempty"` + AckWaitInSecs int32 `protobuf:"varint,7,opt,name=ackWaitInSecs,proto3" json:"ackWaitInSecs,omitempty"` + DurableName string `protobuf:"bytes,8,opt,name=durableName,proto3" json:"durableName,omitempty"` + LastSent uint64 `protobuf:"varint,9,opt,name=lastSent,proto3" json:"lastSent,omitempty"` + IsDurable bool `protobuf:"varint,10,opt,name=isDurable,proto3" json:"isDurable,omitempty"` + IsClosed bool `protobuf:"varint,11,opt,name=isClosed,proto3" json:"isClosed,omitempty"` +} + +func (m *SubState) Reset() { *m = SubState{} } +func (m *SubState) String() string { return proto.CompactTextString(m) } +func (*SubState) ProtoMessage() {} +func (*SubState) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{0} } + +// SubStateDelete marks a Subscription as deleted +type SubStateDelete struct { + ID uint64 `protobuf:"varint,1,opt,name=ID,proto3" json:"ID,omitempty"` +} + +func (m *SubStateDelete) Reset() { *m = SubStateDelete{} } +func (m *SubStateDelete) String() string { return proto.CompactTextString(m) } +func (*SubStateDelete) ProtoMessage() {} +func (*SubStateDelete) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, 
[]int{1} } + +// SubStateUpdate represents a subscription update (either Msg or Ack) +type SubStateUpdate struct { + ID uint64 `protobuf:"varint,1,opt,name=ID,proto3" json:"ID,omitempty"` + Seqno uint64 `protobuf:"varint,2,opt,name=seqno,proto3" json:"seqno,omitempty"` +} + +func (m *SubStateUpdate) Reset() { *m = SubStateUpdate{} } +func (m *SubStateUpdate) String() string { return proto.CompactTextString(m) } +func (*SubStateUpdate) ProtoMessage() {} +func (*SubStateUpdate) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{2} } + +// ServerInfo contains basic information regarding the Server +type ServerInfo struct { + ClusterID string `protobuf:"bytes,1,opt,name=ClusterID,proto3" json:"ClusterID,omitempty"` + Discovery string `protobuf:"bytes,2,opt,name=Discovery,proto3" json:"Discovery,omitempty"` + Publish string `protobuf:"bytes,3,opt,name=Publish,proto3" json:"Publish,omitempty"` + Subscribe string `protobuf:"bytes,4,opt,name=Subscribe,proto3" json:"Subscribe,omitempty"` + Unsubscribe string `protobuf:"bytes,5,opt,name=Unsubscribe,proto3" json:"Unsubscribe,omitempty"` + Close string `protobuf:"bytes,6,opt,name=Close,proto3" json:"Close,omitempty"` + SubClose string `protobuf:"bytes,7,opt,name=SubClose,proto3" json:"SubClose,omitempty"` + AcksSubs string `protobuf:"bytes,8,opt,name=AcksSubs,proto3" json:"AcksSubs,omitempty"` + NodeID string `protobuf:"bytes,9,opt,name=NodeID,proto3" json:"NodeID,omitempty"` +} + +func (m *ServerInfo) Reset() { *m = ServerInfo{} } +func (m *ServerInfo) String() string { return proto.CompactTextString(m) } +func (*ServerInfo) ProtoMessage() {} +func (*ServerInfo) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{3} } + +// ClientInfo contains information related to a Client +type ClientInfo struct { + ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` + HbInbox string `protobuf:"bytes,2,opt,name=HbInbox,proto3" json:"HbInbox,omitempty"` + ConnID []byte 
`protobuf:"bytes,3,opt,name=ConnID,proto3" json:"ConnID,omitempty"` + Protocol int32 `protobuf:"varint,4,opt,name=Protocol,proto3" json:"Protocol,omitempty"` + PingInterval int32 `protobuf:"varint,5,opt,name=PingInterval,proto3" json:"PingInterval,omitempty"` + PingMaxOut int32 `protobuf:"varint,6,opt,name=PingMaxOut,proto3" json:"PingMaxOut,omitempty"` +} + +func (m *ClientInfo) Reset() { *m = ClientInfo{} } +func (m *ClientInfo) String() string { return proto.CompactTextString(m) } +func (*ClientInfo) ProtoMessage() {} +func (*ClientInfo) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{4} } + +type ClientDelete struct { + ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` +} + +func (m *ClientDelete) Reset() { *m = ClientDelete{} } +func (m *ClientDelete) String() string { return proto.CompactTextString(m) } +func (*ClientDelete) ProtoMessage() {} +func (*ClientDelete) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{5} } + +type CtrlMsg struct { + MsgType CtrlMsg_Type `protobuf:"varint,1,opt,name=MsgType,proto3,enum=spb.CtrlMsg_Type" json:"MsgType,omitempty"` + ServerID string `protobuf:"bytes,2,opt,name=ServerID,proto3" json:"ServerID,omitempty"` + Data []byte `protobuf:"bytes,3,opt,name=Data,proto3" json:"Data,omitempty"` + RefID string `protobuf:"bytes,4,opt,name=RefID,proto3" json:"RefID,omitempty"` +} + +func (m *CtrlMsg) Reset() { *m = CtrlMsg{} } +func (m *CtrlMsg) String() string { return proto.CompactTextString(m) } +func (*CtrlMsg) ProtoMessage() {} +func (*CtrlMsg) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{6} } + +// RaftJoinRequest is a request to join a Raft group. 
+type RaftJoinRequest struct { + NodeID string `protobuf:"bytes,1,opt,name=NodeID,proto3" json:"NodeID,omitempty"` + NodeAddr string `protobuf:"bytes,2,opt,name=NodeAddr,proto3" json:"NodeAddr,omitempty"` +} + +func (m *RaftJoinRequest) Reset() { *m = RaftJoinRequest{} } +func (m *RaftJoinRequest) String() string { return proto.CompactTextString(m) } +func (*RaftJoinRequest) ProtoMessage() {} +func (*RaftJoinRequest) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{7} } + +// RaftJoinResponse is a response to a RaftJoinRequest. +type RaftJoinResponse struct { + Error string `protobuf:"bytes,1,opt,name=Error,proto3" json:"Error,omitempty"` +} + +func (m *RaftJoinResponse) Reset() { *m = RaftJoinResponse{} } +func (m *RaftJoinResponse) String() string { return proto.CompactTextString(m) } +func (*RaftJoinResponse) ProtoMessage() {} +func (*RaftJoinResponse) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{8} } + +// RaftOperation is a Raft log message. 
+type RaftOperation struct { + OpType RaftOperation_Type `protobuf:"varint,1,opt,name=OpType,proto3,enum=spb.RaftOperation_Type" json:"OpType,omitempty"` + PublishBatch *Batch `protobuf:"bytes,2,opt,name=PublishBatch" json:"PublishBatch,omitempty"` + Sub *AddSubscription `protobuf:"bytes,3,opt,name=Sub" json:"Sub,omitempty"` + Unsub *pb.UnsubscribeRequest `protobuf:"bytes,4,opt,name=Unsub" json:"Unsub,omitempty"` + SubSentAck *SubSentAndAck `protobuf:"bytes,5,opt,name=SubSentAck" json:"SubSentAck,omitempty"` + ClientConnect *AddClient `protobuf:"bytes,7,opt,name=ClientConnect" json:"ClientConnect,omitempty"` + ClientDisconnect *pb.CloseRequest `protobuf:"bytes,8,opt,name=ClientDisconnect" json:"ClientDisconnect,omitempty"` + Channel string `protobuf:"bytes,9,opt,name=Channel,proto3" json:"Channel,omitempty"` +} + +func (m *RaftOperation) Reset() { *m = RaftOperation{} } +func (m *RaftOperation) String() string { return proto.CompactTextString(m) } +func (*RaftOperation) ProtoMessage() {} +func (*RaftOperation) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{9} } + +// Batch is a batch of messages for replication. +type Batch struct { + Messages []*pb.MsgProto `protobuf:"bytes,1,rep,name=Messages" json:"Messages,omitempty"` +} + +func (m *Batch) Reset() { *m = Batch{} } +func (m *Batch) String() string { return proto.CompactTextString(m) } +func (*Batch) ProtoMessage() {} +func (*Batch) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{10} } + +// AddSubscription is used to replicate a new client subscription. 
+type AddSubscription struct { + Request *pb.SubscriptionRequest `protobuf:"bytes,1,opt,name=Request" json:"Request,omitempty"` + AckInbox string `protobuf:"bytes,2,opt,name=AckInbox,proto3" json:"AckInbox,omitempty"` +} + +func (m *AddSubscription) Reset() { *m = AddSubscription{} } +func (m *AddSubscription) String() string { return proto.CompactTextString(m) } +func (*AddSubscription) ProtoMessage() {} +func (*AddSubscription) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{11} } + +// SubSentAndAck is used to replicate a sent and/or ack messages. +type SubSentAndAck struct { + Channel string `protobuf:"bytes,1,opt,name=Channel,proto3" json:"Channel,omitempty"` + AckInbox string `protobuf:"bytes,2,opt,name=AckInbox,proto3" json:"AckInbox,omitempty"` + Sent []uint64 `protobuf:"varint,3,rep,packed,name=Sent" json:"Sent,omitempty"` + Ack []uint64 `protobuf:"varint,4,rep,packed,name=Ack" json:"Ack,omitempty"` +} + +func (m *SubSentAndAck) Reset() { *m = SubSentAndAck{} } +func (m *SubSentAndAck) String() string { return proto.CompactTextString(m) } +func (*SubSentAndAck) ProtoMessage() {} +func (*SubSentAndAck) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{12} } + +// AddClient is used to replicate a new client connection. +type AddClient struct { + Request *pb.ConnectRequest `protobuf:"bytes,1,opt,name=Request" json:"Request,omitempty"` + Refresh bool `protobuf:"varint,2,opt,name=Refresh,proto3" json:"Refresh,omitempty"` +} + +func (m *AddClient) Reset() { *m = AddClient{} } +func (m *AddClient) String() string { return proto.CompactTextString(m) } +func (*AddClient) ProtoMessage() {} +func (*AddClient) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{13} } + +// RaftSnapshot is a snapshot of the state of the server. 
+type RaftSnapshot struct { + Clients []*ClientInfo `protobuf:"bytes,1,rep,name=Clients" json:"Clients,omitempty"` + Channels []*ChannelSnapshot `protobuf:"bytes,2,rep,name=Channels" json:"Channels,omitempty"` + Padding []byte `protobuf:"bytes,3,opt,name=Padding,proto3" json:"Padding,omitempty"` +} + +func (m *RaftSnapshot) Reset() { *m = RaftSnapshot{} } +func (m *RaftSnapshot) String() string { return proto.CompactTextString(m) } +func (*RaftSnapshot) ProtoMessage() {} +func (*RaftSnapshot) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{14} } + +// ChannelSnapshot is a snapshot of a channel +type ChannelSnapshot struct { + Channel string `protobuf:"bytes,1,opt,name=Channel,proto3" json:"Channel,omitempty"` + First uint64 `protobuf:"varint,2,opt,name=First,proto3" json:"First,omitempty"` + Last uint64 `protobuf:"varint,3,opt,name=Last,proto3" json:"Last,omitempty"` + Subscriptions []*SubscriptionSnapshot `protobuf:"bytes,4,rep,name=Subscriptions" json:"Subscriptions,omitempty"` +} + +func (m *ChannelSnapshot) Reset() { *m = ChannelSnapshot{} } +func (m *ChannelSnapshot) String() string { return proto.CompactTextString(m) } +func (*ChannelSnapshot) ProtoMessage() {} +func (*ChannelSnapshot) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{15} } + +// SubscriptionSnaphot is the snapshot of a subscription +type SubscriptionSnapshot struct { + State *SubState `protobuf:"bytes,1,opt,name=State" json:"State,omitempty"` + AcksPending []uint64 `protobuf:"varint,2,rep,packed,name=AcksPending" json:"AcksPending,omitempty"` +} + +func (m *SubscriptionSnapshot) Reset() { *m = SubscriptionSnapshot{} } +func (m *SubscriptionSnapshot) String() string { return proto.CompactTextString(m) } +func (*SubscriptionSnapshot) ProtoMessage() {} +func (*SubscriptionSnapshot) Descriptor() ([]byte, []int) { return fileDescriptorProtocol, []int{16} } + +func init() { + proto.RegisterType((*SubState)(nil), "spb.SubState") + 
proto.RegisterType((*SubStateDelete)(nil), "spb.SubStateDelete") + proto.RegisterType((*SubStateUpdate)(nil), "spb.SubStateUpdate") + proto.RegisterType((*ServerInfo)(nil), "spb.ServerInfo") + proto.RegisterType((*ClientInfo)(nil), "spb.ClientInfo") + proto.RegisterType((*ClientDelete)(nil), "spb.ClientDelete") + proto.RegisterType((*CtrlMsg)(nil), "spb.CtrlMsg") + proto.RegisterType((*RaftJoinRequest)(nil), "spb.RaftJoinRequest") + proto.RegisterType((*RaftJoinResponse)(nil), "spb.RaftJoinResponse") + proto.RegisterType((*RaftOperation)(nil), "spb.RaftOperation") + proto.RegisterType((*Batch)(nil), "spb.Batch") + proto.RegisterType((*AddSubscription)(nil), "spb.AddSubscription") + proto.RegisterType((*SubSentAndAck)(nil), "spb.SubSentAndAck") + proto.RegisterType((*AddClient)(nil), "spb.AddClient") + proto.RegisterType((*RaftSnapshot)(nil), "spb.RaftSnapshot") + proto.RegisterType((*ChannelSnapshot)(nil), "spb.ChannelSnapshot") + proto.RegisterType((*SubscriptionSnapshot)(nil), "spb.SubscriptionSnapshot") + proto.RegisterEnum("spb.CtrlMsg_Type", CtrlMsg_Type_name, CtrlMsg_Type_value) + proto.RegisterEnum("spb.RaftOperation_Type", RaftOperation_Type_name, RaftOperation_Type_value) +} +func (m *SubState) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SubState) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.ID != 0 { + dAtA[i] = 0x8 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.ID)) + } + if len(m.ClientID) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ClientID))) + i += copy(dAtA[i:], m.ClientID) + } + if len(m.QGroup) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.QGroup))) + i += copy(dAtA[i:], m.QGroup) + } + if len(m.Inbox) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Inbox))) + i += 
copy(dAtA[i:], m.Inbox) + } + if len(m.AckInbox) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.AckInbox))) + i += copy(dAtA[i:], m.AckInbox) + } + if m.MaxInFlight != 0 { + dAtA[i] = 0x30 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.MaxInFlight)) + } + if m.AckWaitInSecs != 0 { + dAtA[i] = 0x38 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.AckWaitInSecs)) + } + if len(m.DurableName) > 0 { + dAtA[i] = 0x42 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.DurableName))) + i += copy(dAtA[i:], m.DurableName) + } + if m.LastSent != 0 { + dAtA[i] = 0x48 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.LastSent)) + } + if m.IsDurable { + dAtA[i] = 0x50 + i++ + if m.IsDurable { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + if m.IsClosed { + dAtA[i] = 0x58 + i++ + if m.IsClosed { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + return i, nil +} + +func (m *SubStateDelete) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SubStateDelete) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.ID != 0 { + dAtA[i] = 0x8 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.ID)) + } + return i, nil +} + +func (m *SubStateUpdate) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SubStateUpdate) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.ID != 0 { + dAtA[i] = 0x8 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.ID)) + } + if m.Seqno != 0 { + dAtA[i] = 0x10 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Seqno)) + } + return i, nil +} + +func (m *ServerInfo) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := 
m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ServerInfo) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ClusterID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ClusterID))) + i += copy(dAtA[i:], m.ClusterID) + } + if len(m.Discovery) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Discovery))) + i += copy(dAtA[i:], m.Discovery) + } + if len(m.Publish) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Publish))) + i += copy(dAtA[i:], m.Publish) + } + if len(m.Subscribe) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Subscribe))) + i += copy(dAtA[i:], m.Subscribe) + } + if len(m.Unsubscribe) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Unsubscribe))) + i += copy(dAtA[i:], m.Unsubscribe) + } + if len(m.Close) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Close))) + i += copy(dAtA[i:], m.Close) + } + if len(m.SubClose) > 0 { + dAtA[i] = 0x3a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.SubClose))) + i += copy(dAtA[i:], m.SubClose) + } + if len(m.AcksSubs) > 0 { + dAtA[i] = 0x42 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.AcksSubs))) + i += copy(dAtA[i:], m.AcksSubs) + } + if len(m.NodeID) > 0 { + dAtA[i] = 0x4a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.NodeID))) + i += copy(dAtA[i:], m.NodeID) + } + return i, nil +} + +func (m *ClientInfo) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ClientInfo) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ID))) + i += copy(dAtA[i:], m.ID) + } + if 
len(m.HbInbox) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.HbInbox))) + i += copy(dAtA[i:], m.HbInbox) + } + if len(m.ConnID) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ConnID))) + i += copy(dAtA[i:], m.ConnID) + } + if m.Protocol != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Protocol)) + } + if m.PingInterval != 0 { + dAtA[i] = 0x28 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.PingInterval)) + } + if m.PingMaxOut != 0 { + dAtA[i] = 0x30 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.PingMaxOut)) + } + return i, nil +} + +func (m *ClientDelete) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ClientDelete) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.ID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ID))) + i += copy(dAtA[i:], m.ID) + } + return i, nil +} + +func (m *CtrlMsg) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *CtrlMsg) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.MsgType != 0 { + dAtA[i] = 0x8 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.MsgType)) + } + if len(m.ServerID) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.ServerID))) + i += copy(dAtA[i:], m.ServerID) + } + if len(m.Data) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + if len(m.RefID) > 0 { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.RefID))) + i += copy(dAtA[i:], m.RefID) + } + return i, nil +} + +func (m *RaftJoinRequest) Marshal() (dAtA []byte, err 
error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RaftJoinRequest) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.NodeID) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.NodeID))) + i += copy(dAtA[i:], m.NodeID) + } + if len(m.NodeAddr) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.NodeAddr))) + i += copy(dAtA[i:], m.NodeAddr) + } + return i, nil +} + +func (m *RaftJoinResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RaftJoinResponse) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Error) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Error))) + i += copy(dAtA[i:], m.Error) + } + return i, nil +} + +func (m *RaftOperation) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RaftOperation) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.OpType != 0 { + dAtA[i] = 0x8 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.OpType)) + } + if m.PublishBatch != nil { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.PublishBatch.Size())) + n1, err := m.PublishBatch.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n1 + } + if m.Sub != nil { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Sub.Size())) + n2, err := m.Sub.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n2 + } + if m.Unsub != nil { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Unsub.Size())) + n3, err := 
m.Unsub.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n3 + } + if m.SubSentAck != nil { + dAtA[i] = 0x2a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.SubSentAck.Size())) + n4, err := m.SubSentAck.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n4 + } + if m.ClientConnect != nil { + dAtA[i] = 0x3a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.ClientConnect.Size())) + n5, err := m.ClientConnect.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n5 + } + if m.ClientDisconnect != nil { + dAtA[i] = 0x42 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.ClientDisconnect.Size())) + n6, err := m.ClientDisconnect.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n6 + } + if len(m.Channel) > 0 { + dAtA[i] = 0x4a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Channel))) + i += copy(dAtA[i:], m.Channel) + } + return i, nil +} + +func (m *Batch) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Batch) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Messages) > 0 { + for _, msg := range m.Messages { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(msg.Size())) + n, err := msg.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n + } + } + return i, nil +} + +func (m *AddSubscription) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *AddSubscription) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Request != nil { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Request.Size())) + n7, err := m.Request.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n7 + } + if len(m.AckInbox) > 
0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.AckInbox))) + i += copy(dAtA[i:], m.AckInbox) + } + return i, nil +} + +func (m *SubSentAndAck) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SubSentAndAck) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Channel) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Channel))) + i += copy(dAtA[i:], m.Channel) + } + if len(m.AckInbox) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.AckInbox))) + i += copy(dAtA[i:], m.AckInbox) + } + if len(m.Sent) > 0 { + dAtA9 := make([]byte, len(m.Sent)*10) + var j8 int + for _, num := range m.Sent { + for num >= 1<<7 { + dAtA9[j8] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j8++ + } + dAtA9[j8] = uint8(num) + j8++ + } + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(j8)) + i += copy(dAtA[i:], dAtA9[:j8]) + } + if len(m.Ack) > 0 { + dAtA11 := make([]byte, len(m.Ack)*10) + var j10 int + for _, num := range m.Ack { + for num >= 1<<7 { + dAtA11[j10] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j10++ + } + dAtA11[j10] = uint8(num) + j10++ + } + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(j10)) + i += copy(dAtA[i:], dAtA11[:j10]) + } + return i, nil +} + +func (m *AddClient) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *AddClient) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.Request != nil { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Request.Size())) + n12, err := m.Request.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n12 + } + if m.Refresh { + dAtA[i] = 
0x10 + i++ + if m.Refresh { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + return i, nil +} + +func (m *RaftSnapshot) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RaftSnapshot) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Clients) > 0 { + for _, msg := range m.Clients { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(msg.Size())) + n, err := msg.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if len(m.Channels) > 0 { + for _, msg := range m.Channels { + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(msg.Size())) + n, err := msg.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n + } + } + if len(m.Padding) > 0 { + dAtA[i] = 0x1a + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Padding))) + i += copy(dAtA[i:], m.Padding) + } + return i, nil +} + +func (m *ChannelSnapshot) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ChannelSnapshot) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Channel) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(len(m.Channel))) + i += copy(dAtA[i:], m.Channel) + } + if m.First != 0 { + dAtA[i] = 0x10 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.First)) + } + if m.Last != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.Last)) + } + if len(m.Subscriptions) > 0 { + for _, msg := range m.Subscriptions { + dAtA[i] = 0x22 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(msg.Size())) + n, err := msg.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n + } + } + return i, nil +} + +func (m *SubscriptionSnapshot) Marshal() 
(dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SubscriptionSnapshot) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if m.State != nil { + dAtA[i] = 0xa + i++ + i = encodeVarintProtocol(dAtA, i, uint64(m.State.Size())) + n13, err := m.State.MarshalTo(dAtA[i:]) + if err != nil { + return 0, err + } + i += n13 + } + if len(m.AcksPending) > 0 { + dAtA15 := make([]byte, len(m.AcksPending)*10) + var j14 int + for _, num := range m.AcksPending { + for num >= 1<<7 { + dAtA15[j14] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j14++ + } + dAtA15[j14] = uint8(num) + j14++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintProtocol(dAtA, i, uint64(j14)) + i += copy(dAtA[i:], dAtA15[:j14]) + } + return i, nil +} + +func encodeVarintProtocol(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *SubState) Size() (n int) { + var l int + _ = l + if m.ID != 0 { + n += 1 + sovProtocol(uint64(m.ID)) + } + l = len(m.ClientID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.QGroup) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Inbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.AckInbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.MaxInFlight != 0 { + n += 1 + sovProtocol(uint64(m.MaxInFlight)) + } + if m.AckWaitInSecs != 0 { + n += 1 + sovProtocol(uint64(m.AckWaitInSecs)) + } + l = len(m.DurableName) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.LastSent != 0 { + n += 1 + sovProtocol(uint64(m.LastSent)) + } + if m.IsDurable { + n += 2 + } + if m.IsClosed { + n += 2 + } + return n +} + +func (m *SubStateDelete) Size() (n int) { + var l int + _ = l + if m.ID != 0 { + n += 1 + sovProtocol(uint64(m.ID)) + } + 
return n +} + +func (m *SubStateUpdate) Size() (n int) { + var l int + _ = l + if m.ID != 0 { + n += 1 + sovProtocol(uint64(m.ID)) + } + if m.Seqno != 0 { + n += 1 + sovProtocol(uint64(m.Seqno)) + } + return n +} + +func (m *ServerInfo) Size() (n int) { + var l int + _ = l + l = len(m.ClusterID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Discovery) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Publish) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Subscribe) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Unsubscribe) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Close) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.SubClose) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.AcksSubs) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.NodeID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *ClientInfo) Size() (n int) { + var l int + _ = l + l = len(m.ID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.HbInbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.ConnID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.Protocol != 0 { + n += 1 + sovProtocol(uint64(m.Protocol)) + } + if m.PingInterval != 0 { + n += 1 + sovProtocol(uint64(m.PingInterval)) + } + if m.PingMaxOut != 0 { + n += 1 + sovProtocol(uint64(m.PingMaxOut)) + } + return n +} + +func (m *ClientDelete) Size() (n int) { + var l int + _ = l + l = len(m.ID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *CtrlMsg) Size() (n int) { + var l int + _ = l + if m.MsgType != 0 { + n += 1 + sovProtocol(uint64(m.MsgType)) + } + l = len(m.ServerID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.RefID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + 
return n +} + +func (m *RaftJoinRequest) Size() (n int) { + var l int + _ = l + l = len(m.NodeID) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.NodeAddr) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *RaftJoinResponse) Size() (n int) { + var l int + _ = l + l = len(m.Error) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *RaftOperation) Size() (n int) { + var l int + _ = l + if m.OpType != 0 { + n += 1 + sovProtocol(uint64(m.OpType)) + } + if m.PublishBatch != nil { + l = m.PublishBatch.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + if m.Sub != nil { + l = m.Sub.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + if m.Unsub != nil { + l = m.Unsub.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + if m.SubSentAck != nil { + l = m.SubSentAck.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + if m.ClientConnect != nil { + l = m.ClientConnect.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + if m.ClientDisconnect != nil { + l = m.ClientDisconnect.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.Channel) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *Batch) Size() (n int) { + var l int + _ = l + if len(m.Messages) > 0 { + for _, e := range m.Messages { + l = e.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + } + return n +} + +func (m *AddSubscription) Size() (n int) { + var l int + _ = l + if m.Request != nil { + l = m.Request.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.AckInbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *SubSentAndAck) Size() (n int) { + var l int + _ = l + l = len(m.Channel) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + l = len(m.AckInbox) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if len(m.Sent) > 0 { + l = 0 + for _, e := range m.Sent { + l += sovProtocol(uint64(e)) + } + n += 1 + sovProtocol(uint64(l)) + l + } + if len(m.Ack) > 0 { 
+ l = 0 + for _, e := range m.Ack { + l += sovProtocol(uint64(e)) + } + n += 1 + sovProtocol(uint64(l)) + l + } + return n +} + +func (m *AddClient) Size() (n int) { + var l int + _ = l + if m.Request != nil { + l = m.Request.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + if m.Refresh { + n += 2 + } + return n +} + +func (m *RaftSnapshot) Size() (n int) { + var l int + _ = l + if len(m.Clients) > 0 { + for _, e := range m.Clients { + l = e.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + } + if len(m.Channels) > 0 { + for _, e := range m.Channels { + l = e.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + } + l = len(m.Padding) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + return n +} + +func (m *ChannelSnapshot) Size() (n int) { + var l int + _ = l + l = len(m.Channel) + if l > 0 { + n += 1 + l + sovProtocol(uint64(l)) + } + if m.First != 0 { + n += 1 + sovProtocol(uint64(m.First)) + } + if m.Last != 0 { + n += 1 + sovProtocol(uint64(m.Last)) + } + if len(m.Subscriptions) > 0 { + for _, e := range m.Subscriptions { + l = e.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + } + return n +} + +func (m *SubscriptionSnapshot) Size() (n int) { + var l int + _ = l + if m.State != nil { + l = m.State.Size() + n += 1 + l + sovProtocol(uint64(l)) + } + if len(m.AcksPending) > 0 { + l = 0 + for _, e := range m.AcksPending { + l += sovProtocol(uint64(e)) + } + n += 1 + sovProtocol(uint64(l)) + l + } + return n +} + +func sovProtocol(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozProtocol(x uint64) (n int) { + return sovProtocol(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *SubState) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) 
<< shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SubState: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SubState: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field ID", wireType) + } + m.ID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.ID |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClientID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field QGroup", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.QGroup = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + 
return fmt.Errorf("proto: wrong wireType = %d for field Inbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Inbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AckInbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AckInbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field MaxInFlight", wireType) + } + m.MaxInFlight = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.MaxInFlight |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 7: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field AckWaitInSecs", wireType) + } + m.AckWaitInSecs = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.AckWaitInSecs |= (int32(b) & 0x7F) << shift + if b < 0x80 
{ + break + } + } + case 8: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field DurableName", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.DurableName = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 9: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field LastSent", wireType) + } + m.LastSent = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.LastSent |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 10: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field IsDurable", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.IsDurable = bool(v != 0) + case 11: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field IsClosed", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.IsClosed = bool(v != 0) + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if 
(iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SubStateDelete) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SubStateDelete: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SubStateDelete: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field ID", wireType) + } + m.ID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.ID |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SubStateUpdate) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return 
fmt.Errorf("proto: SubStateUpdate: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SubStateUpdate: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field ID", wireType) + } + m.ID = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.ID |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Seqno", wireType) + } + m.Seqno = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Seqno |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ServerInfo) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ServerInfo: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ServerInfo: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return 
fmt.Errorf("proto: wrong wireType = %d for field ClusterID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ClusterID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Discovery", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Discovery = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Publish", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Publish = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Subscribe", wireType) + } + var 
stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Subscribe = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Unsubscribe", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Unsubscribe = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Close", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Close = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SubClose", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 
ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SubClose = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 8: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AcksSubs", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AcksSubs = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 9: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field NodeID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.NodeID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF 
+ } + return nil +} +func (m *ClientInfo) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ClientInfo: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ClientInfo: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field HbInbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.HbInbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return 
fmt.Errorf("proto: wrong wireType = %d for field ConnID", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ConnID = append(m.ConnID[:0], dAtA[iNdEx:postIndex]...) + if m.ConnID == nil { + m.ConnID = []byte{} + } + iNdEx = postIndex + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Protocol", wireType) + } + m.Protocol = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Protocol |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 5: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field PingInterval", wireType) + } + m.PingInterval = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.PingInterval |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 6: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field PingMaxOut", wireType) + } + m.PingMaxOut = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.PingMaxOut |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += 
skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ClientDelete) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ClientDelete: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ClientDelete: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *CtrlMsg) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := 
dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: CtrlMsg: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: CtrlMsg: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field MsgType", wireType) + } + m.MsgType = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.MsgType |= (CtrlMsg_Type(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ServerID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ServerID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) 
+ if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RefID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RefID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *RaftJoinRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RaftJoinRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RaftJoinRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field NodeID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := 
dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.NodeID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field NodeAddr", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.NodeAddr = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *RaftJoinResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RaftJoinResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RaftJoinResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + 
case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Error = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *RaftOperation) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RaftOperation: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RaftOperation: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field OpType", wireType) + } + m.OpType = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.OpType |= (RaftOperation_Type(b) & 0x7F) << shift + if b < 
0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field PublishBatch", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.PublishBatch == nil { + m.PublishBatch = &Batch{} + } + if err := m.PublishBatch.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Sub", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Sub == nil { + m.Sub = &AddSubscription{} + } + if err := m.Sub.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Unsub", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Unsub == nil { + m.Unsub = &pb.UnsubscribeRequest{} + } + if err := 
m.Unsub.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SubSentAck", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.SubSentAck == nil { + m.SubSentAck = &SubSentAndAck{} + } + if err := m.SubSentAck.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientConnect", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.ClientConnect == nil { + m.ClientConnect = &AddClient{} + } + if err := m.ClientConnect.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 8: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ClientDisconnect", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + 
} + if m.ClientDisconnect == nil { + m.ClientDisconnect = &pb.CloseRequest{} + } + if err := m.ClientDisconnect.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 9: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Channel", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Channel = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Batch) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Batch: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Batch: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Messages", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 
{ + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Messages = append(m.Messages, &pb.MsgProto{}) + if err := m.Messages[len(m.Messages)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *AddSubscription) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: AddSubscription: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: AddSubscription: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Request", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + 
return io.ErrUnexpectedEOF + } + if m.Request == nil { + m.Request = &pb.SubscriptionRequest{} + } + if err := m.Request.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AckInbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AckInbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SubSentAndAck) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SubSentAndAck: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SubSentAndAck: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Channel", wireType) + } + var stringLen uint64 + for shift := 
uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Channel = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AckInbox", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AckInbox = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType == 0 { + var v uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Sent = append(m.Sent, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 
ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Sent = append(m.Sent, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Sent", wireType) + } + case 4: + if wireType == 0 { + var v uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Ack = append(m.Ack, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Ack = append(m.Ack, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Ack", wireType) + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *AddClient) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 
ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: AddClient: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: AddClient: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Request", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Request == nil { + m.Request = &pb.ConnectRequest{} + } + if err := m.Request.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Refresh", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Refresh = bool(v != 0) + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *RaftSnapshot) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire 
uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RaftSnapshot: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RaftSnapshot: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Clients", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Clients = append(m.Clients, &ClientInfo{}) + if err := m.Clients[len(m.Clients)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Channels", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Channels = append(m.Channels, &ChannelSnapshot{}) + if err := m.Channels[len(m.Channels)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong 
wireType = %d for field Padding", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Padding = append(m.Padding[:0], dAtA[iNdEx:postIndex]...) + if m.Padding == nil { + m.Padding = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ChannelSnapshot) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ChannelSnapshot: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ChannelSnapshot: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Channel", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := 
int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Channel = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field First", wireType) + } + m.First = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.First |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Last", wireType) + } + m.Last = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Last |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Subscriptions", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Subscriptions = append(m.Subscriptions, &SubscriptionSnapshot{}) + if err := m.Subscriptions[len(m.Subscriptions)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } 
+ return nil +} +func (m *SubscriptionSnapshot) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SubscriptionSnapshot: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SubscriptionSnapshot: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field State", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.State == nil { + m.State = &SubState{} + } + if err := m.State.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType == 0 { + var v uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.AcksPending = append(m.AcksPending, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + 
break + } + } + if packedLen < 0 { + return ErrInvalidLengthProtocol + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProtocol + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.AcksPending = append(m.AcksPending, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field AcksPending", wireType) + } + default: + iNdEx = preIndex + skippy, err := skipProtocol(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthProtocol + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipProtocol(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthProtocol + } + return iNdEx, nil + case 3: + for { + var 
innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowProtocol + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipProtocol(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthProtocol = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowProtocol = fmt.Errorf("proto: integer overflow") +) + +func init() { proto.RegisterFile("protocol.proto", fileDescriptorProtocol) } + +var fileDescriptorProtocol = []byte{ + // 1223 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x56, 0x5f, 0x8f, 0xdb, 0x44, + 0x10, 0x3f, 0xe7, 0xcf, 0x25, 0x99, 0x24, 0x77, 0xee, 0xea, 0x68, 0x4d, 0x85, 0xa2, 0xc8, 0x20, + 0x14, 0x44, 0x9b, 0xa3, 0x01, 0xf5, 0x09, 0x09, 0x5d, 0x2f, 0x2d, 0x0d, 0xe2, 0xda, 0xd3, 0xa6, + 0x15, 0x12, 0x12, 0x12, 0x6b, 0x67, 0xcf, 0x67, 0xc5, 0xb7, 0x76, 0xbd, 0xeb, 0x53, 0xfb, 0x01, + 0x90, 0x78, 0x84, 0x17, 0x9e, 0xf8, 0x0c, 0x7c, 0x06, 0x1e, 0xfb, 0xc8, 0x23, 0x8f, 0x50, 0xbe, + 0x08, 0xda, 0xd9, 0xb5, 0xe3, 0xdc, 0x55, 0x7d, 0xdb, 0xdf, 0xcc, 0xec, 0x64, 0xe6, 0xf7, 0x9b, + 0x59, 0x07, 0xf6, 0xb2, 0x3c, 0x55, 0x69, 0x98, 0x26, 0x53, 0x3c, 0x90, 0xa6, 0xcc, 0x82, 0xdb, + 0x77, 0xa3, 0x58, 0x9d, 0x17, 0xc1, 0x34, 0x4c, 0x2f, 0x0e, 0xa3, 0x34, 0x4a, 0x0f, 0xd1, 0x17, + 0x14, 0x67, 0x88, 0x10, 0xe0, 0xc9, 0xdc, 0xb9, 0x7d, 0xbf, 0x16, 0x2e, 0x98, 0x92, 0x77, 0x63, + 0xed, 0xbe, 0x8b, 0x47, 0xa9, 0x72, 0xce, 0x2e, 0x62, 0x11, 0x1d, 0x66, 0xc1, 0xe1, 0xf6, 0x6f, 
+ 0xf9, 0x7f, 0x36, 0xa0, 0xbb, 0x2c, 0x82, 0xa5, 0x62, 0x8a, 0x93, 0x3d, 0x68, 0x2c, 0xe6, 0x9e, + 0x33, 0x76, 0x26, 0x2d, 0xda, 0x58, 0xcc, 0xc9, 0x6d, 0xe8, 0x86, 0x49, 0xcc, 0x85, 0x5a, 0xcc, + 0xbd, 0xc6, 0xd8, 0x99, 0xf4, 0x68, 0x85, 0xc9, 0x4d, 0xd8, 0x7d, 0xf1, 0x75, 0x9e, 0x16, 0x99, + 0xd7, 0x44, 0x8f, 0x45, 0xe4, 0x00, 0xda, 0xb1, 0x08, 0xd2, 0x97, 0x5e, 0x0b, 0xcd, 0x06, 0xe8, + 0x4c, 0x2c, 0x5c, 0x2f, 0xd0, 0xd1, 0x36, 0x99, 0x4a, 0x4c, 0xc6, 0xd0, 0xbf, 0x60, 0x2f, 0x17, + 0xe2, 0x51, 0x12, 0x47, 0xe7, 0xca, 0xdb, 0x1d, 0x3b, 0x93, 0x36, 0xad, 0x9b, 0xc8, 0x47, 0x30, + 0x64, 0xe1, 0xfa, 0x3b, 0x16, 0xab, 0x85, 0x58, 0xf2, 0x50, 0x7a, 0x1d, 0x8c, 0xd9, 0x36, 0xea, + 0x3c, 0xab, 0x22, 0x67, 0x41, 0xc2, 0x9f, 0xb0, 0x0b, 0xee, 0x75, 0xf1, 0x67, 0xea, 0x26, 0x5d, + 0x45, 0xc2, 0xa4, 0x5a, 0x72, 0xa1, 0xbc, 0x1e, 0x76, 0x59, 0x61, 0xf2, 0x01, 0xf4, 0x62, 0x39, + 0x37, 0xc1, 0x1e, 0x8c, 0x9d, 0x49, 0x97, 0x6e, 0x0c, 0xfa, 0x66, 0x2c, 0x8f, 0x93, 0x54, 0xf2, + 0x95, 0xd7, 0x47, 0x67, 0x85, 0xfd, 0x31, 0xec, 0x95, 0x0c, 0xce, 0x79, 0xc2, 0xaf, 0xf3, 0xe8, + 0xdf, 0xdf, 0x44, 0x3c, 0xcf, 0x56, 0x6f, 0x63, 0xfa, 0x00, 0xda, 0x92, 0xbf, 0x10, 0x29, 0xd2, + 0xdc, 0xa2, 0x06, 0xf8, 0x3f, 0x37, 0x00, 0x96, 0x3c, 0xbf, 0xe4, 0xf9, 0x42, 0x9c, 0xa5, 0xba, + 0xc4, 0xe3, 0xa4, 0x90, 0x8a, 0xe7, 0xf6, 0x6e, 0x8f, 0x6e, 0x0c, 0xda, 0x3b, 0x8f, 0x65, 0x98, + 0x5e, 0xf2, 0xfc, 0x95, 0x55, 0x6b, 0x63, 0x20, 0x1e, 0x74, 0x4e, 0x8b, 0x20, 0x89, 0xe5, 0xb9, + 0xd5, 0xab, 0x84, 0xfa, 0xde, 0xb2, 0x08, 0x64, 0x98, 0xc7, 0x01, 0xb7, 0xa2, 0x6d, 0x0c, 0x9a, + 0xd4, 0xe7, 0x42, 0x56, 0x7e, 0xa3, 0x5d, 0xdd, 0xa4, 0x4b, 0x47, 0x22, 0x50, 0xb8, 0x1e, 0x35, + 0x40, 0x13, 0xb6, 0x2c, 0x02, 0xe3, 0xe8, 0x18, 0xc1, 0x4b, 0xac, 0x7d, 0x47, 0xe1, 0x5a, 0xea, + 0x1f, 0xb1, 0x2a, 0x55, 0x58, 0x8f, 0xd5, 0x93, 0x74, 0xc5, 0x17, 0x73, 0x14, 0xa8, 0x47, 0x2d, + 0xf2, 0xff, 0x70, 0x00, 0x8e, 0xcd, 0xec, 0x69, 0x2a, 0x36, 0xfc, 0xf5, 0x90, 0x3f, 0x0f, 0x3a, + 0x8f, 0x03, 0x33, 0x5e, 0xa6, 0xf5, 
0x12, 0xea, 0x84, 0xc7, 0xa9, 0x10, 0x8b, 0x39, 0xf6, 0x3d, + 0xa0, 0x16, 0xe9, 0x22, 0x4e, 0xed, 0x2a, 0x60, 0xd7, 0x6d, 0x5a, 0x61, 0xe2, 0xc3, 0xe0, 0x34, + 0x16, 0xd1, 0x42, 0x28, 0x9e, 0x5f, 0xb2, 0x04, 0xbb, 0x6e, 0xd3, 0x2d, 0x1b, 0x19, 0x01, 0x68, + 0x7c, 0xc2, 0x5e, 0x3e, 0x2d, 0xca, 0xa1, 0xad, 0x59, 0xfc, 0x11, 0x0c, 0x4c, 0xbd, 0xd7, 0x66, + 0x02, 0x2b, 0xf6, 0xff, 0x76, 0xa0, 0x73, 0xac, 0xf2, 0xe4, 0x44, 0x46, 0xe4, 0x53, 0xe8, 0x9c, + 0xc8, 0xe8, 0xd9, 0xab, 0x8c, 0x63, 0xc0, 0xde, 0xec, 0xc6, 0x54, 0x66, 0xc1, 0xd4, 0xba, 0xa7, + 0xda, 0x41, 0xcb, 0x08, 0x64, 0xd6, 0xcc, 0x44, 0xb5, 0x94, 0x25, 0x26, 0x04, 0x5a, 0x73, 0xa6, + 0x98, 0x6d, 0x15, 0xcf, 0x5a, 0x1f, 0xca, 0xcf, 0x16, 0xf3, 0x72, 0x21, 0x11, 0xf8, 0xdf, 0x43, + 0x0b, 0xb3, 0x11, 0x1c, 0xcd, 0x9a, 0x9e, 0xee, 0x0e, 0x19, 0x6c, 0xb4, 0x73, 0x1d, 0x32, 0x84, + 0x9e, 0xa6, 0xcc, 0xc0, 0x06, 0xd9, 0x87, 0xfe, 0xa3, 0x67, 0x8f, 0x39, 0xcb, 0x55, 0xc0, 0x99, + 0x72, 0x9b, 0xc4, 0x85, 0xc1, 0x29, 0xcb, 0x55, 0xac, 0xe2, 0x54, 0xc4, 0x22, 0x72, 0x5b, 0xfe, + 0x43, 0xd8, 0xa7, 0xec, 0x4c, 0x7d, 0x93, 0xc6, 0x82, 0xf2, 0x17, 0x05, 0x97, 0xaa, 0x26, 0xab, + 0x53, 0x97, 0x55, 0x37, 0xa3, 0x4f, 0x47, 0xab, 0x55, 0x5e, 0x36, 0x53, 0x62, 0x7f, 0x02, 0xee, + 0x26, 0x8d, 0xcc, 0x52, 0x21, 0x71, 0xd8, 0x1e, 0xe6, 0x79, 0x9a, 0xdb, 0x34, 0x06, 0xf8, 0xbf, + 0xb7, 0x60, 0xa8, 0x43, 0x9f, 0x66, 0x3c, 0x67, 0xba, 0x0e, 0x72, 0x08, 0xbb, 0x4f, 0xb3, 0x1a, + 0xa1, 0xb7, 0x90, 0xd0, 0xad, 0x18, 0x43, 0xab, 0x0d, 0x23, 0x53, 0x18, 0xd8, 0x85, 0x78, 0xc0, + 0x54, 0x78, 0x8e, 0xc5, 0xf4, 0x67, 0x80, 0xd7, 0xd0, 0x42, 0xb7, 0xfc, 0xe4, 0x63, 0x68, 0x2e, + 0x8b, 0x00, 0x89, 0xee, 0xcf, 0x0e, 0x30, 0xec, 0x68, 0xb5, 0xb2, 0x7b, 0x93, 0xe9, 0xfc, 0x54, + 0x07, 0x90, 0x3b, 0xd0, 0x46, 0x72, 0x91, 0xfd, 0xfe, 0xec, 0xe6, 0x34, 0x0b, 0xa6, 0x35, 0xb6, + 0x2d, 0x3f, 0xd4, 0x04, 0x91, 0x19, 0x80, 0x7e, 0x28, 0xb8, 0x50, 0x47, 0xe1, 0x1a, 0xc7, 0xae, + 0x3f, 0x23, 0x98, 0xbc, 0x34, 0x8b, 0xd5, 0x51, 0xb8, 0xa6, 0xb5, 0x28, 
0xf2, 0x05, 0x0c, 0xcd, + 0xa0, 0x69, 0x95, 0x78, 0xa8, 0x70, 0xdd, 0xfa, 0xb3, 0xbd, 0xb2, 0x26, 0xe3, 0xa4, 0xdb, 0x41, + 0xe4, 0x4b, 0x70, 0xed, 0x78, 0xea, 0x27, 0xc2, 0x5c, 0xec, 0xe2, 0x45, 0x57, 0x97, 0x88, 0x6a, + 0x97, 0xc5, 0x5d, 0x8b, 0xd4, 0xeb, 0x76, 0x7c, 0xce, 0x84, 0xe0, 0x89, 0x5d, 0xd3, 0x12, 0xfa, + 0xbf, 0x3a, 0x76, 0xb0, 0xfa, 0xd5, 0x83, 0xe3, 0xee, 0xe8, 0x19, 0xaa, 0x9e, 0x14, 0xd7, 0x21, + 0x37, 0x81, 0x50, 0x7e, 0x91, 0x5e, 0xf2, 0x3a, 0x5f, 0x6e, 0x83, 0xbc, 0x07, 0x37, 0xf0, 0x87, + 0xb7, 0xcc, 0x4d, 0xb2, 0xa7, 0x5f, 0x41, 0xb1, 0x32, 0xbd, 0xbb, 0x2d, 0x9d, 0xda, 0xb6, 0xe1, + 0xee, 0x6a, 0xe7, 0xa6, 0x30, 0xb7, 0x43, 0x6e, 0xc0, 0xd0, 0x6c, 0x9c, 0xad, 0xc8, 0xed, 0xfa, + 0xf7, 0xa0, 0x6d, 0x44, 0x9b, 0x40, 0xf7, 0x84, 0x4b, 0xc9, 0x22, 0x2e, 0x3d, 0x67, 0xdc, 0x9c, + 0xf4, 0x67, 0x03, 0xdd, 0xec, 0x89, 0x8c, 0x70, 0xf5, 0x69, 0xe5, 0xf5, 0x7f, 0x84, 0xfd, 0x2b, + 0x72, 0x92, 0x7b, 0xd0, 0xb1, 0x84, 0xe0, 0x4c, 0xf5, 0x67, 0xb7, 0xa6, 0x46, 0x97, 0x8d, 0xe2, + 0x96, 0xaf, 0x32, 0xce, 0x3e, 0x74, 0xf5, 0x67, 0xa9, 0xc2, 0xfe, 0x1a, 0x86, 0x5b, 0x9a, 0xd6, + 0x39, 0x75, 0xb6, 0x38, 0x7d, 0x57, 0x1a, 0xbd, 0xf1, 0xf8, 0x39, 0x6b, 0x8e, 0x9b, 0x93, 0x16, + 0xc5, 0x33, 0x71, 0xa1, 0xa9, 0xc7, 0xa7, 0x85, 0x26, 0x7d, 0xf4, 0x97, 0xd0, 0xab, 0x26, 0x81, + 0xdc, 0xb9, 0xda, 0x08, 0x41, 0xc5, 0x0d, 0x83, 0xd7, 0x7a, 0xf0, 0x74, 0xf4, 0x59, 0xce, 0xa5, + 0xd9, 0x89, 0x2e, 0x2d, 0xa1, 0xff, 0x93, 0x03, 0x03, 0xbd, 0x51, 0x4b, 0xc1, 0x32, 0x79, 0x9e, + 0x2a, 0xf2, 0x09, 0x74, 0xcc, 0x4f, 0x94, 0xec, 0xee, 0x9b, 0x67, 0xac, 0x7a, 0xb6, 0x69, 0xe9, + 0x27, 0x9f, 0x41, 0xd7, 0x76, 0x27, 0xbd, 0x06, 0xc6, 0x9a, 0x1d, 0xb2, 0xc6, 0x32, 0x25, 0xad, + 0xa2, 0xf0, 0x03, 0xc6, 0x56, 0xab, 0x58, 0x44, 0xf6, 0x75, 0x2b, 0xa1, 0xff, 0x9b, 0x03, 0xfb, + 0x57, 0xee, 0xbd, 0x83, 0xcc, 0x03, 0x68, 0x3f, 0x8a, 0x73, 0xa9, 0xca, 0x2f, 0x2d, 0x02, 0x4d, + 0xe3, 0xb7, 0x4c, 0x2a, 0x4c, 0xdd, 0xa2, 0x78, 0x26, 0x5f, 0xa1, 0x42, 0x95, 0xba, 0x12, 0x09, + 0xed, 0xcf, 
0xde, 0x2f, 0xf7, 0xb1, 0xf2, 0x54, 0xd5, 0x6e, 0xc7, 0xfb, 0x3f, 0xc0, 0xc1, 0xdb, + 0xc2, 0xc8, 0x87, 0xd0, 0xc6, 0xff, 0x02, 0x96, 0xfe, 0x61, 0xb5, 0xe0, 0xda, 0x48, 0x8d, 0x4f, + 0x7f, 0x78, 0xf5, 0x47, 0xf1, 0x94, 0x0b, 0xec, 0xb9, 0x81, 0x62, 0xd6, 0x4d, 0x0f, 0x0e, 0x5e, + 0xff, 0x3b, 0xda, 0x79, 0xfd, 0x66, 0xe4, 0xfc, 0xf5, 0x66, 0xe4, 0xfc, 0xf3, 0x66, 0xe4, 0xfc, + 0xf2, 0xdf, 0x68, 0x27, 0xd8, 0xc5, 0xff, 0x75, 0x9f, 0xff, 0x1f, 0x00, 0x00, 0xff, 0xff, 0xb6, + 0xf3, 0x4c, 0x5c, 0x55, 0x0a, 0x00, 0x00, +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/stores/common.go b/vendor/github.com/nats-io/nats-streaming-server/stores/common.go new file mode 100644 index 00000000000..b2413700393 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/stores/common.go @@ -0,0 +1,413 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stores + +import ( + "sync" + + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nats-streaming-server/logger" + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/util" +) + +// format string used to report that limit is reached when storing +// messages. 
+var droppingMsgsFmt = "WARNING: Reached limits for store %q (msgs=%v/%v bytes=%v/%v), " + + "dropping old messages to make room for new ones" + +// commonStore contains everything that is common to any type of store +type commonStore struct { + sync.RWMutex + closed bool + log logger.Logger +} + +// genericStore is the generic store implementation with a map of channels. +type genericStore struct { + commonStore + limits *StoreLimits + sublist *util.Sublist + name string + channels map[string]*Channel +} + +// Used as the value for the genericSubStore's subs map. +var emptySub = struct{}{} + +// genericSubStore is the generic store implementation that manages subscriptions +// for a given channel. +type genericSubStore struct { + commonStore + limits SubStoreLimits + subs map[uint64]interface{} + maxSubID uint64 +} + +// genericMsgStore is the generic store implementation that manages messages +// for a given channel. +type genericMsgStore struct { + commonStore + limits MsgStoreLimits + subject string // Can't be wildcard + first uint64 + last uint64 + totalCount int + totalBytes uint64 + hitLimit bool // indicates if store had to drop messages due to limit +} + +//////////////////////////////////////////////////////////////////////////// +// genericStore methods +//////////////////////////////////////////////////////////////////////////// + +// init initializes the structure of a generic store +func (gs *genericStore) init(name string, log logger.Logger, limits *StoreLimits) error { + gs.name = name + if limits == nil { + limits = &DefaultStoreLimits + } + if err := gs.setLimits(limits); err != nil { + return err + } + gs.log = log + // Do not use limits values to create the map. + gs.channels = make(map[string]*Channel) + return nil +} + +// GetExclusiveLock implements the Store interface. +func (gs *genericStore) GetExclusiveLock() (bool, error) { + // Need to be implementation specific. 
+ return false, ErrNotSupported +} + +// Init can be used to initialize the store with server's information. +func (gs *genericStore) Init(info *spb.ServerInfo) error { + return nil +} + +// Name returns the type name of this store +func (gs *genericStore) Name() string { + return gs.name +} + +// Recover implements the Store interface. +func (gs *genericStore) Recover() (*RecoveredState, error) { + // Implementations that can recover their state need to + // override this. + return nil, nil +} + +// setLimits makes a copy of the given StoreLimits, +// validates the limits and if ok, applies the inheritance. +func (gs *genericStore) setLimits(limits *StoreLimits) error { + // Make a copy. + gs.limits = limits.Clone() + // Build will validate and apply inheritance if no error. + if err := gs.limits.Build(); err != nil { + return err + } + // We don't need the PerChannel map and the sublist. So replace + // the map with the sublist instead. + gs.sublist = util.NewSublist() + for key, val := range gs.limits.PerChannel { + // val is already a copy of the original limits.PerChannel[key], + // so don't need to make a copy again, we own this. + gs.sublist.Insert(key, val) + } + // Get rid of the map now. + gs.limits.PerChannel = nil + return nil +} + +// Returns the appropriate limits for this channel based on inheritance. +// The channel is assumed to be a literal, and the store lock held on entry. +func (gs *genericStore) getChannelLimits(channel string) *ChannelLimits { + r := gs.sublist.Match(channel) + if len(r) == 0 { + // If there is no match, that means we need to use the global limits. + return &gs.limits.ChannelLimits + } + // If there is a match, use the limits from the last element because + // we know that the returned array is ordered from widest to narrowest, + // and the only literal that there is would be the channel we are + // looking up. 
+ return r[len(r)-1].(*ChannelLimits) +} + +// GetChannelLimits implements the Store interface +func (gs *genericStore) GetChannelLimits(channel string) *ChannelLimits { + gs.RLock() + defer gs.RUnlock() + c := gs.channels[channel] + if c == nil { + return nil + } + // Return a copy + cl := *gs.getChannelLimits(channel) + return &cl +} + +// SetLimits sets limits for this store +func (gs *genericStore) SetLimits(limits *StoreLimits) error { + gs.Lock() + err := gs.setLimits(limits) + gs.Unlock() + return err +} + +// CreateChannel implements the Store interface +func (gs *genericStore) CreateChannel(channel string) (*Channel, error) { + return nil, nil +} + +// DeleteChannel implements the Store interface +func (gs *genericStore) DeleteChannel(channel string) error { + gs.Lock() + err := gs.deleteChannel(channel) + gs.Unlock() + return err +} + +func (gs *genericStore) deleteChannel(channel string) error { + c := gs.channels[channel] + if c == nil { + return ErrNotFound + } + err := c.Msgs.Close() + if lerr := c.Subs.Close(); lerr != nil && err == nil { + err = lerr + } + if err != nil { + return err + } + delete(gs.channels, channel) + return nil +} + +// canAddChannel returns true if the current number of channels is below the limit. +// If a channel named `channelName` alreadt exists, an error is returned. +// Store lock is assumed to be locked. 
+func (gs *genericStore) canAddChannel(name string) error { + if gs.channels[name] != nil { + return ErrAlreadyExists + } + if gs.limits.MaxChannels > 0 && len(gs.channels) >= gs.limits.MaxChannels { + return ErrTooManyChannels + } + return nil +} + +// AddClient implements the Store interface +func (gs *genericStore) AddClient(info *spb.ClientInfo) (*Client, error) { + return &Client{*info}, nil +} + +// DeleteClient implements the Store interface +func (gs *genericStore) DeleteClient(clientID string) error { + return nil +} + +// Close closes all stores +func (gs *genericStore) Close() error { + gs.Lock() + defer gs.Unlock() + if gs.closed { + return nil + } + gs.closed = true + return gs.close() +} + +// close closes all stores. Store lock is assumed held on entry +func (gs *genericStore) close() error { + var err error + var lerr error + + for _, cs := range gs.channels { + lerr = cs.Subs.Close() + if lerr != nil && err == nil { + err = lerr + } + lerr = cs.Msgs.Close() + if lerr != nil && err == nil { + err = lerr + } + } + return err +} + +//////////////////////////////////////////////////////////////////////////// +// genericMsgStore methods +//////////////////////////////////////////////////////////////////////////// + +// init initializes this generic message store +func (gms *genericMsgStore) init(subject string, log logger.Logger, limits *MsgStoreLimits) { + gms.subject = subject + gms.limits = *limits + gms.log = log +} + +// State returns some statistics related to this store +func (gms *genericMsgStore) State() (numMessages int, byteSize uint64, err error) { + gms.RLock() + c, b := gms.totalCount, gms.totalBytes + gms.RUnlock() + return c, b, nil +} + +// Store implements the MsgStore interface +func (gms *genericMsgStore) Store(msg *pb.MsgProto) (uint64, error) { + // no-op + return 0, nil +} + +// FirstSequence returns sequence for first message stored. 
+func (gms *genericMsgStore) FirstSequence() (uint64, error) { + gms.RLock() + first := gms.first + gms.RUnlock() + return first, nil +} + +// LastSequence returns sequence for last message stored. +func (gms *genericMsgStore) LastSequence() (uint64, error) { + gms.RLock() + last := gms.last + gms.RUnlock() + return last, nil +} + +// FirstAndLastSequence returns sequences for the first and last messages stored. +func (gms *genericMsgStore) FirstAndLastSequence() (uint64, uint64, error) { + gms.RLock() + first, last := gms.first, gms.last + gms.RUnlock() + return first, last, nil +} + +// Lookup returns the stored message with given sequence number. +func (gms *genericMsgStore) Lookup(seq uint64) (*pb.MsgProto, error) { + return nil, nil +} + +// FirstMsg returns the first message stored. +func (gms *genericMsgStore) FirstMsg() (*pb.MsgProto, error) { + return nil, nil +} + +// LastMsg returns the last message stored. +func (gms *genericMsgStore) LastMsg() (*pb.MsgProto, error) { + return nil, nil +} + +func (gms *genericMsgStore) Flush() error { + return nil +} + +// GetSequenceFromTimestamp returns the sequence of the first message whose +// timestamp is greater or equal to given timestamp. +func (gms *genericMsgStore) GetSequenceFromTimestamp(timestamp int64) (uint64, error) { + return 0, nil +} + +// Empty implements the MsgStore interface +func (gms *genericMsgStore) Empty() error { + return nil +} + +func (gms *genericMsgStore) empty() { + gms.first, gms.last, gms.totalCount, gms.totalBytes, gms.hitLimit = 0, 0, 0, 0, false +} + +// Close closes this store. 
+func (gms *genericMsgStore) Close() error { + return nil +} + +//////////////////////////////////////////////////////////////////////////// +// genericSubStore methods +//////////////////////////////////////////////////////////////////////////// + +// init initializes the structure of a generic sub store +func (gss *genericSubStore) init(log logger.Logger, limits *SubStoreLimits) { + gss.limits = *limits + gss.log = log + gss.subs = make(map[uint64]interface{}) +} + +// CreateSub records a new subscription represented by SubState. On success, +// it records the subscription's ID in SubState.ID. This ID is to be used +// by the other SubStore methods. +func (gss *genericSubStore) CreateSub(sub *spb.SubState) error { + gss.Lock() + err := gss.createSub(sub) + gss.Unlock() + return err +} + +// UpdateSub updates a given subscription represented by SubState. +func (gss *genericSubStore) UpdateSub(sub *spb.SubState) error { + return nil +} + +// createSub checks that the number of subscriptions is below the max +// and if so, assigns a new subscription ID and keep track of it in a map. +// Lock is assumed to be held on entry. +func (gss *genericSubStore) createSub(sub *spb.SubState) error { + if gss.limits.MaxSubscriptions > 0 && len(gss.subs) >= gss.limits.MaxSubscriptions { + return ErrTooManySubs + } + + // Bump the max value before assigning it to the new subscription. + gss.maxSubID++ + + // This new subscription has the max value. + sub.ID = gss.maxSubID + // Store anything. Some implementations may replace with specific + // object. + gss.subs[sub.ID] = emptySub + + return nil +} + +// DeleteSub invalidates this subscription. +func (gss *genericSubStore) DeleteSub(subid uint64) error { + gss.Lock() + delete(gss.subs, subid) + gss.Unlock() + return nil +} + +// AddSeqPending adds the given message seqno to the given subscription. 
+func (gss *genericSubStore) AddSeqPending(subid, seqno uint64) error { + return nil +} + +// AckSeqPending records that the given message seqno has been acknowledged +// by the given subscription. +func (gss *genericSubStore) AckSeqPending(subid, seqno uint64) error { + return nil +} + +// Flush is for stores that may buffer operations and need them to be persisted. +func (gss *genericSubStore) Flush() error { + return nil +} + +// Close closes this store +func (gss *genericSubStore) Close() error { + return nil +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/stores/filestore.go b/vendor/github.com/nats-io/nats-streaming-server/stores/filestore.go new file mode 100644 index 00000000000..e76f1487368 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/stores/filestore.go @@ -0,0 +1,4076 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stores + +import ( + "bufio" + "errors" + "fmt" + "hash/crc32" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nats-streaming-server/logger" + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/util" +) + +const ( + // Our file version. + fileVersion = 1 + + // Prefix for message log files + msgFilesPrefix = "msgs." 
+ + // Data files suffix + datSuffix = ".dat" + + // Index files suffix + idxSuffix = ".idx" + + // Backup file suffix + bakSuffix = ".bak" + + // Name of the subscriptions file. + subsFileName = "subs" + datSuffix + + // Name of the clients file. + clientsFileName = "clients" + datSuffix + + // Name of the server file. + serverFileName = "server" + datSuffix + + // Number of bytes required to store a CRC-32 checksum + crcSize = crc32.Size + + // Size of a record header. + // 4 bytes: For typed records: 1 byte for type, 3 bytes for buffer size + // For non typed rec: buffer size + // +4 bytes for CRC-32 + recordHeaderSize = 4 + crcSize + + // defaultBufSize is used for various buffered IO operations + defaultBufSize = 10 * 1024 * 1024 + + // Size of a message index record + // Seq - Offset - Timestamp - Size - CRC + msgIndexRecSize = 8 + 8 + 8 + 4 + crcSize + + // msgRecordOverhead is the number of bytes to count toward the size + // of a serialized message so that file slice size is closer to + // channels and/or file slice limits. + msgRecordOverhead = recordHeaderSize + msgIndexRecSize + + // Percentage of buffer usage to decide if the buffer should shrink + bufShrinkThreshold = 50 + + // Interval when to check/try to shrink buffer writers + defaultBufShrinkInterval = 5 * time.Second + + // Interval an unused file slice is left opened + defaultSliceCloseInterval = time.Second + + // If FileStoreOption's BufferSize is > 0, the buffer writer is initially + // created with this size (unless this is > than BufferSize, in which case + // BufferSize is used). When possible, the buffer will shrink but not lower + // than this value. This is for FileSubStore's + subBufMinShrinkSize = 128 + + // If FileStoreOption's BufferSize is > 0, the buffer writer is initially + // created with this size (unless this is > than BufferSize, in which case + // BufferSize is used). When possible, the buffer will shrink but not lower + // than this value.
This is for FileMsgStore's + msgBufMinShrinkSize = 512 + + // This is the sleep time in the background tasks go routine. + defaultBkgTasksSleepDuration = time.Second + + // This is the default amount of time a message is cached. + defaultCacheTTL = time.Second + + // defaultFileFlags are the default file flags used when opening a file + defaultFileFlags = os.O_RDWR | os.O_CREATE | os.O_APPEND + + // Lock file name + lockFileName = ".rootdir.lck" + + // Witness file for TruncateUnexpectedEOF option + truncateBadEOFFileName = ".truncate.lck" +) + +// FileStoreOption is a function on the options for a File Store +type FileStoreOption func(*FileStoreOptions) error + +// FileStoreOptions can be used to customize a File Store +type FileStoreOptions struct { + // BufferSize is the size of the buffer used during store operations. + BufferSize int + + // CompactEnabled allows to enable/disable files compaction. + CompactEnabled bool + + // CompactInterval indicates the minimum interval (in seconds) between compactions. + CompactInterval int + + // CompactFragmentation indicates the minimum ratio of fragmentation + // to trigger compaction. For instance, 50 means that compaction + // would not happen until fragmentation is more than 50%. + CompactFragmentation int + + // CompactMinFileSize indicates the minimum file size before compaction + // can be performed, regardless of the current file fragmentation. + CompactMinFileSize int64 + + // DoCRC enables (or disables) CRC checksum verification on read operations. + DoCRC bool + + // CRCPoly is a polynomial used to make the table used in CRC computation. + CRCPolynomial int64 + + // DoSync indicates if `File.Sync()`` is called during a flush. + DoSync bool + + // Regardless of channel limits, the options below allow to split a message + // log in smaller file chunks. If all those options were to be set to 0, + // some file slice limit will be selected automatically based on the channel + // limits. 
+ // SliceMaxMsgs defines how many messages can fit in a file slice (0 means + // count is not checked). + SliceMaxMsgs int + // SliceMaxBytes defines how many bytes can fit in a file slice, including + // the corresponding index file (0 means size is not checked). + SliceMaxBytes int64 + // SliceMaxAge defines the period of time covered by a slice starting when + // the first message is stored (0 means time is not checked). + SliceMaxAge time.Duration + // SliceArchiveScript is the path to a script to be invoked when a file + // slice (and the corresponding index file) is going to be removed. + // The script will be invoked with the channel name and names of data and + // index files (which both have been previously renamed with a '.bak' + // extension). It is the responsibility of the script to move/remove + // those files. + SliceArchiveScript string + + // FileDescriptorsLimit is a soft limit hinting at FileStore to try to + // limit the number of concurrent opened files to that limit. + FileDescriptorsLimit int64 + + // Number of channels recovered in parallel (default is 1). + ParallelRecovery int + + // TruncateUnexpectedEOF set to true means that if recovery reports + // an error about unexpected end of file, the last bad record will be + // removed (the file is truncated at the beginning of the first incomplete + // record). Dataloss may occur. + TruncateUnexpectedEOF bool +} + +// This is an internal error to detect situations where we do +// not get an EOF but all data we read are zeros. The file +// will be rewound to previous position and use this as the +// first write position. +var errNeedRewind = errors.New("end of file padded with zeros") + +// DefaultFileStoreOptions defines the default options for a File Store.
+var DefaultFileStoreOptions = FileStoreOptions{ + BufferSize: 2 * 1024 * 1024, // 2MB + CompactEnabled: true, + CompactInterval: 5 * 60, // 5 minutes + CompactFragmentation: 50, + CompactMinFileSize: 1024 * 1024, + DoCRC: true, + CRCPolynomial: int64(crc32.IEEE), + DoSync: true, + SliceMaxBytes: 64 * 1024 * 1024, // 64MB + ParallelRecovery: 1, +} + +// BufferSize is a FileStore option that sets the size of the buffer used +// during store writes. This can help improve write performance. +func BufferSize(size int) FileStoreOption { + return func(o *FileStoreOptions) error { + if size < 0 { + return fmt.Errorf("buffer size value must be a positive number") + } + o.BufferSize = size + return nil + } +} + +// CompactEnabled is a FileStore option that enables or disables file compaction. +// The value false will disable compaction. +func CompactEnabled(enabled bool) FileStoreOption { + return func(o *FileStoreOptions) error { + o.CompactEnabled = enabled + return nil + } +} + +// CompactInterval is a FileStore option that defines the minimum compaction interval. +// Compaction is not timer based, but instead when things get "deleted". This value +// prevents compaction to happen too often. +func CompactInterval(seconds int) FileStoreOption { + return func(o *FileStoreOptions) error { + if seconds <= 0 { + return fmt.Errorf("compact interval value must at least be 1 seconds") + } + o.CompactInterval = seconds + return nil + } +} + +// CompactFragmentation is a FileStore option that defines the fragmentation ratio +// below which compaction would not occur. For instance, specifying 50 means that +// if other variables would allow for compaction, the compaction would occur only +// after 50% of the file has data that is no longer valid. 
+func CompactFragmentation(fragmentation int) FileStoreOption { + return func(o *FileStoreOptions) error { + if fragmentation <= 0 { + return fmt.Errorf("compact fragmentation value must at least be 1") + } + o.CompactFragmentation = fragmentation + return nil + } +} + +// CompactMinFileSize is a FileStore option that defines the minimum file size below +// which compaction would not occur. Specify `0` if you don't want any minimum. +func CompactMinFileSize(fileSize int64) FileStoreOption { + return func(o *FileStoreOptions) error { + if fileSize < 0 { + return fmt.Errorf("compact minimum file size value must be a positive number") + } + o.CompactMinFileSize = fileSize + return nil + } +} + +// DoCRC is a FileStore option that defines if a CRC checksum verification should +// be performed when records are read from disk. +func DoCRC(enableCRC bool) FileStoreOption { + return func(o *FileStoreOptions) error { + o.DoCRC = enableCRC + return nil + } +} + +// CRCPolynomial is a FileStore option that defines the polynomial to use to create +// the table used for CRC-32 Checksum. +// See https://golang.org/pkg/hash/crc32/#MakeTable +func CRCPolynomial(polynomial int64) FileStoreOption { + return func(o *FileStoreOptions) error { + if polynomial <= 0 || polynomial > int64(0xFFFFFFFF) { + return fmt.Errorf("crc polynomial should be between 1 and %v", int64(0xFFFFFFFF)) + } + o.CRCPolynomial = polynomial + return nil + } +} + +// DoSync is a FileStore option that defines if `File.Sync()` should be called +// during a `Flush()` call. +func DoSync(enableFileSync bool) FileStoreOption { + return func(o *FileStoreOptions) error { + o.DoSync = enableFileSync + return nil + } +} + +// SliceConfig is a FileStore option that allows the configuration of +// file slice limits and optional archive script file name. 
+func SliceConfig(maxMsgs int, maxBytes int64, maxAge time.Duration, script string) FileStoreOption { + return func(o *FileStoreOptions) error { + if maxMsgs < 0 || maxBytes < 0 || maxAge < 0 { + return fmt.Errorf("slice max values must be positive numbers") + } + o.SliceMaxMsgs = maxMsgs + o.SliceMaxBytes = maxBytes + o.SliceMaxAge = maxAge + o.SliceArchiveScript = script + return nil + } +} + +// FileDescriptorsLimit is a soft limit hinting at FileStore to try to +// limit the number of concurrent opened files to that limit. +func FileDescriptorsLimit(limit int64) FileStoreOption { + return func(o *FileStoreOptions) error { + if limit < 0 { + return fmt.Errorf("file descriptor limit must be a positive number") + } + o.FileDescriptorsLimit = limit + return nil + } +} + +// ParallelRecovery is a FileStore option that allows the parallel +// recovery of channels. When running with SSDs, try to use a higher +// value than the default number of 1. When running with HDDs, +// performance may be better if it stays at 1. +func ParallelRecovery(count int) FileStoreOption { + return func(o *FileStoreOptions) error { + if count <= 0 { + return fmt.Errorf("parallel recovery value must be at least 1") + } + o.ParallelRecovery = count + return nil + } +} + +// TruncateUnexpectedEOF indicates if on recovery the store should +// truncate a file that reports an unexpected end-of-file (EOF) on recovery. +// If set to true, the invalid record byte content is printed but the store +// will truncate the file prior to this bad record and proceed with recovery. +// Dataloss may occur. +func TruncateUnexpectedEOF(truncate bool) FileStoreOption { + return func(o *FileStoreOptions) error { + o.TruncateUnexpectedEOF = truncate + return nil + } +} + +// AllOptions is a convenient option to pass all options from a FileStoreOptions +// structure to the constructor. 
+func AllOptions(opts *FileStoreOptions) FileStoreOption { + return func(o *FileStoreOptions) error { + if err := BufferSize(opts.BufferSize)(o); err != nil { + return err + } + if err := CompactInterval(opts.CompactInterval)(o); err != nil { + return err + } + if err := CompactFragmentation(opts.CompactFragmentation)(o); err != nil { + return err + } + if err := CompactMinFileSize(opts.CompactMinFileSize)(o); err != nil { + return err + } + if err := CRCPolynomial(opts.CRCPolynomial)(o); err != nil { + return err + } + if err := SliceConfig(opts.SliceMaxMsgs, opts.SliceMaxBytes, opts.SliceMaxAge, opts.SliceArchiveScript)(o); err != nil { + return err + } + if err := FileDescriptorsLimit(opts.FileDescriptorsLimit)(o); err != nil { + return err + } + if err := ParallelRecovery(opts.ParallelRecovery)(o); err != nil { + return err + } + o.CompactEnabled = opts.CompactEnabled + o.DoCRC = opts.DoCRC + o.DoSync = opts.DoSync + o.TruncateUnexpectedEOF = opts.TruncateUnexpectedEOF + return nil + } +} + +// Type for the records in the subscriptions file +type recordType byte + +// Protobufs do not share a common interface, yet, when saving a +// record on disk, we have to get the size and marshal the record in +// a buffer. These methods are available in all the protobuf. +// So we create this interface with those two methods to be used by the +// writeRecord method. 
+type record interface { + Size() int + MarshalTo([]byte) (int, error) +} + +// This is use for cases when the record is not typed +const recNoType = recordType(0) + +// Record types for subscription file +const ( + subRecNew = recordType(iota) + 1 + subRecUpdate + subRecDel + subRecAck + subRecMsg +) + +// Record types for client store +const ( + addClient = recordType(iota) + 1 + delClient +) + +type fileID int64 + +type beforeFileClose func() error + +const ( + invalidFileID fileID = -1 + + fileOpened = int32(1) + fileInUse = int32(2) + fileClosing = int32(3) + fileClosed = int32(4) + fileRemoved = int32(5) + fmClosed = int32(6) +) + +type file struct { + // Atomic need to be memory aligned. Put them first in the + // structure definition. + state int32 + + id fileID + handle *os.File + name string + flags int + beforeClose beforeFileClose +} + +type filesManager struct { + sync.Mutex + openedFDs int64 + limit int64 + rootDir string + files map[fileID]*file + nextID fileID + isClosed bool +} + +// FileStore is the storage interface for STAN servers, backed by files. +type FileStore struct { + genericStore + fm *filesManager + serverFile *file + clientsFile *file + opts FileStoreOptions + compactItvl time.Duration + clients map[string]*Client + delClientRec spb.ClientDelete + cliFileSize int64 + cliDeleteRecs int // Number of deleted client records + cliCompactTS time.Time + crcTable *crc32.Table + lockFile util.LockFile +} + +type subscription struct { + sub *spb.SubState + seqnos map[uint64]struct{} +} + +type bufferedWriter struct { + buf *bufio.Writer + bufSize int // current buffer size + minShrinkSize int // minimum shrink size. Note that this can be bigger than maxSize (see setSizes) + maxSize int // maximum size the buffer can grow + shrinkReq bool // used to decide if buffer should shrink +} + +// FileSubStore is a subscription store in files. 
+type FileSubStore struct { + genericSubStore + fstore *FileStore + fm *filesManager + tmpSubBuf []byte + file *file + bw *bufferedWriter + delSub spb.SubStateDelete + updateSub spb.SubStateUpdate + opts *FileStoreOptions // points to options from FileStore + compactItvl time.Duration + fileSize int64 + numRecs int // Number of records (sub and msgs) + delRecs int // Number of delete (or ack) records + compactTS time.Time + crcTable *crc32.Table // reference to the one from FileStore + activity bool // was there any write between two flush calls + writer io.Writer // this is either `bw` or `file` depending if buffer writer is used or not + shrinkTimer *time.Timer // timer associated with callback shrinking buffer when possible + allDone sync.WaitGroup +} + +// fileSlice represents one of the message store file (there are a number +// of files for a MsgStore on a given channel). +type fileSlice struct { + file *file + idxFile *file + firstSeq uint64 + lastSeq uint64 + rmCount int // Count of messages "removed" from the slice due to limits. + msgsCount int + msgsSize uint64 + firstWrite int64 // Time the first message was added to this slice (used for slice age limit) + lastUsed int64 +} + +// msgIndex contains the message's offset in the data file, its timestamp +// and size, which allows quick recovery of message and reconstructing of +// file slices. It also helps for GetSequenceFromTimestamp by not having +// to recover actual messages to find out the correct message sequence +// based on timestamp. +type msgIndex struct { + offset int64 + timestamp int64 + msgSize uint32 +} + +// bufferedMsg is required to keep track of a message and msgRecord when +// file buffering is used. It is possible that a message and index is +// not flushed on disk while the message gets removed from the store +// due to limit. We need a map that keeps a reference to message and +// record until the file is flushed. 
+type bufferedMsg struct { + msg *pb.MsgProto + index *msgIndex +} + +// cachedMsg is a structure that contains a reference to a message +// and cache expiration value. The cache has a map and list so +// that cached messages can be ordered by expiration time. +type cachedMsg struct { + expiration int64 + msg *pb.MsgProto + prev *cachedMsg + next *cachedMsg +} + +// msgsCache is the file store cache. +type msgsCache struct { + tryEvict int32 + seqMaps map[uint64]*cachedMsg + head *cachedMsg + tail *cachedMsg +} + +// FileMsgStore is a per channel message file store. +type FileMsgStore struct { + genericMsgStore + // Atomic operations require 64bit aligned fields to be able + // to run with 32bit processes. + checkSlices int64 // used with atomic operations + timeTick int64 // time captured in background tasks go routine + + tmpMsgBuf []byte + fm *filesManager // shortcut to ms.fstore.fm + hasFDsLimit bool // shortcut to ms.fstore.opts.FileDescriptorsLimit > 0 + bw *bufferedWriter + writer io.Writer // this is `bw.buf` or `file` depending if buffer writer is used or not + files map[int]*fileSlice + writeSlice *fileSlice + channelName string + firstFSlSeq int // First file slice sequence number + lastFSlSeq int // Last file slice sequence number + slCountLim int + slSizeLim uint64 + slAgeLim int64 + slHasLimits bool + fstore *FileStore // pointer to file store object + cache *msgsCache + wOffset int64 + firstMsg *pb.MsgProto + lastMsg *pb.MsgProto + expiration int64 + bufferedSeqs []uint64 + bufferedMsgs map[uint64]*bufferedMsg + bkgTasksDone chan bool // signal the background tasks go routine to stop + bkgTasksWake chan bool // signal the background tasks go routine to get out of a sleep + allDone sync.WaitGroup +} + +// some variables based on constants but that we can change +// for tests puposes. 
+var ( + bufShrinkInterval = defaultBufShrinkInterval + bkgTaskMu sync.Mutex + bkgTaskRefs int + bkgTasksSleepDuration = defaultBkgTasksSleepDuration + cacheTTL = int64(defaultCacheTTL) + sliceCloseInterval = defaultSliceCloseInterval +) + +// FileStoreTestSetBackgroundTaskInterval is used by tests to reduce the interval +// at which some tasks are performed in the background +func FileStoreTestSetBackgroundTaskInterval(wait time.Duration) { + // It is possible that both the server test package and + // stores test package run in parallel. Ensure that only + // one is setting the value to avoid races. + bkgTaskMu.Lock() + if bkgTaskRefs == 0 { + bkgTasksSleepDuration = wait + } + bkgTaskRefs++ + bkgTaskMu.Unlock() +} + +// openFile opens the file specified by `filename`. +// If the file exists, it checks that the version is supported. +// The file is created if not present, opened in Read/Write and Append mode. +var openFile = func(fileName string) (*os.File, error) { + return openFileWithFlags(fileName, defaultFileFlags) +} + +// openFileWithFlags opens the file specified by `filename`, using +// the `flags` as open flags. +// If the file exists, it checks that the version is supported. +// If no open mode override is provided, the file is created if not present, +// opened in Read/Write and Append mode.
+func openFileWithFlags(fileName string, flags int) (*os.File, error) { + checkVersion := false + + // Check if file already exists + if s, err := os.Stat(fileName); s != nil && err == nil { + checkVersion = true + } + file, err := os.OpenFile(fileName, flags, 0666) + if err != nil { + return nil, err + } + + if checkVersion { + err = checkFileVersion(file) + } else { + // This is a new file, write our file version + err = util.WriteInt(file, fileVersion) + } + if err != nil { + file.Close() + file = nil + } + return file, err +} + +// check that the version of the file is understood by this interface +func checkFileVersion(r io.Reader) error { + fv, err := util.ReadInt(r) + if err != nil { + return fmt.Errorf("unable to verify file version: %v", err) + } + if fv == 0 || fv > fileVersion { + return fmt.Errorf("unsupported file version: %v (supports [1..%v])", fv, fileVersion) + } + return nil +} + +// writeRecord writes a record to `w`. +// The record layout is as follows: +// 8 bytes: 4 bytes for type and/or size combined +// 4 bytes for CRC-32 +// variable bytes: payload. +// If a buffer is provided, this function uses it and expands it if necessary. +// The function returns the buffer (possibly changed due to expansion) and the +// number of bytes written into that buffer. 
+func writeRecord(w io.Writer, buf []byte, recType recordType, rec record, recSize int, crcTable *crc32.Table) ([]byte, int, error) { + // This is the header + payload size + totalSize := recordHeaderSize + recSize + // Alloc or realloc as needed + buf = util.EnsureBufBigEnough(buf, totalSize) + // If there is a record type, encode it + headerFirstInt := 0 + if recType != recNoType { + if recSize > 0xFFFFFF { + panic("record size too big") + } + // Encode the type in the high byte of the header + headerFirstInt = int(recType)<<24 | recSize + } else { + // The header is the size of the record + headerFirstInt = recSize + } + // Write the first part of the header at the beginning of the buffer + util.ByteOrder.PutUint32(buf[:4], uint32(headerFirstInt)) + // Marshal the record into the given buffer, after the header offset + if _, err := rec.MarshalTo(buf[recordHeaderSize:totalSize]); err != nil { + // Return the buffer because the caller may have provided one + return buf, 0, err + } + // Compute CRC + crc := crc32.Checksum(buf[recordHeaderSize:totalSize], crcTable) + // Write it in the buffer + util.ByteOrder.PutUint32(buf[4:recordHeaderSize], crc) + // Are we dealing with a buffered writer? + bw, isBuffered := w.(*bufio.Writer) + // if so, make sure that if what we are about to "write" is more + // than what's available, then first flush the buffer. + // This is to reduce the risk of partial writes. + if isBuffered && (bw.Buffered() > 0) && (bw.Available() < totalSize) { + if err := bw.Flush(); err != nil { + return buf, 0, err + } + } + // Write the content of our slice into the writer `w` + if _, err := w.Write(buf[:totalSize]); err != nil { + // Return the tmpBuf because the caller may have provided one + return buf, 0, err + } + return buf, totalSize, nil +} + +// readRecord reads a record from `r`, possibly checking the CRC-32 checksum. +// When `buf`` is not nil, this function ensures the buffer is big enough to +// hold the payload (expanding if necessary). 
Therefore, this call always +// return `buf`, regardless if there is an error or not. +// The caller is indicating if the record is supposed to be typed or not. +func readRecord(r io.Reader, buf []byte, recTyped bool, crcTable *crc32.Table, checkCRC bool) ([]byte, int, recordType, error) { + _header := [recordHeaderSize]byte{} + header := _header[:] + if _, err := io.ReadFull(r, header); err != nil { + return buf, 0, recNoType, err + } + recType := recNoType + recSize := 0 + firstInt := int(util.ByteOrder.Uint32(header[:4])) + if recTyped { + recType = recordType(firstInt >> 24 & 0xFF) + recSize = firstInt & 0xFFFFFF + } else { + recSize = firstInt + } + if recSize == 0 && recType == 0 { + crc := util.ByteOrder.Uint32(header[4:recordHeaderSize]) + if crc == 0 { + return buf, 0, 0, errNeedRewind + } + } + // Now we are going to read the payload + buf = util.EnsureBufBigEnough(buf, recSize) + if _, err := io.ReadFull(r, buf[:recSize]); err != nil { + return buf, 0, recNoType, err + } + if checkCRC { + crc := util.ByteOrder.Uint32(header[4:recordHeaderSize]) + // check CRC against what was stored + if c := crc32.Checksum(buf[:recSize], crcTable); c != crc { + return buf, 0, recNoType, fmt.Errorf("corrupted data, expected crc to be 0x%08x, got 0x%08x", crc, c) + } + } + return buf, recSize, recType, nil +} + +// setSize sets the initial buffer size and keep track of min/max allowed sizes +func newBufferWriter(minShrinkSize, maxSize int) *bufferedWriter { + w := &bufferedWriter{minShrinkSize: minShrinkSize, maxSize: maxSize} + w.bufSize = minShrinkSize + // The minSize is the minimum size the buffer can shrink to. + // However, if the given max size is smaller than the min + // shrink size, use that instead. + if maxSize < minShrinkSize { + w.bufSize = maxSize + } + return w +} + +// createNewWriter creates a new buffer writer for `file` with +// the bufferedWriter's current buffer size. 
+func (w *bufferedWriter) createNewWriter(file *os.File) io.Writer { + w.buf = bufio.NewWriterSize(file, w.bufSize) + return w.buf +} + +// expand the buffer (first flushing the buffer if not empty) +func (w *bufferedWriter) expand(file *os.File, required int) (io.Writer, error) { + // If there was a request to shrink the buffer, cancel that. + w.shrinkReq = false + // If there was something, flush first + if w.buf.Buffered() > 0 { + if err := w.buf.Flush(); err != nil { + return w.buf, err + } + } + // Double the size + w.bufSize *= 2 + // If still smaller than what is required, adjust + if w.bufSize < required { + w.bufSize = required + } + // But cap it. + if w.bufSize > w.maxSize { + w.bufSize = w.maxSize + } + w.buf = bufio.NewWriterSize(file, w.bufSize) + return w.buf, nil +} + +// tryShrinkBuffer checks and possibly shrinks the buffer +func (w *bufferedWriter) tryShrinkBuffer(file *os.File) (io.Writer, error) { + // Nothing to do if we are already at the lowest + // or file not set/opened. + if w.bufSize == w.minShrinkSize || file == nil { + return w.buf, nil + } + + if !w.shrinkReq { + percentFilled := w.buf.Buffered() * 100 / w.bufSize + if percentFilled <= bufShrinkThreshold { + w.shrinkReq = true + } + // Wait for next tick to see if we can shrink + return w.buf, nil + } + if err := w.buf.Flush(); err != nil { + return w.buf, err + } + // Reduce size, but ensure it does not go below the limit + w.bufSize /= 2 + if w.bufSize < w.minShrinkSize { + w.bufSize = w.minShrinkSize + } + w.buf = bufio.NewWriterSize(file, w.bufSize) + // Don't reset shrinkReq unless we are down to the limit + if w.bufSize == w.minShrinkSize { + w.shrinkReq = true + } + return w.buf, nil +} + +// checkShrinkRequest checks how full the buffer is, and if is above a certain +// threshold, cancels the shrink request +func (w *bufferedWriter) checkShrinkRequest() { + percentFilled := w.buf.Buffered() * 100 / w.bufSize + // If above the threshold, cancel the request. 
+ if percentFilled > bufShrinkThreshold { + w.shrinkReq = false + } +} + +//////////////////////////////////////////////////////////////////////////// +// filesManager methods +//////////////////////////////////////////////////////////////////////////// + +// createFilesManager returns an instance of the files manager. +func createFilesManager(rootDir string, openedFilesLimit int64) *filesManager { + fm := &filesManager{ + rootDir: rootDir, + limit: openedFilesLimit, + files: make(map[fileID]*file), + } + return fm +} + +// closeUnusedFiles closes files that are opened and not currently in-use. +// Since the number of opened files is a soft limit, and if this function +// is unable to close any file, the caller will still attempt to create/open +// the requested file. If the system's file descriptor limit is reached, +// opening the file will fail and that error will be returned to the caller. +// Lock is required on entry. +func (fm *filesManager) closeUnusedFiles(idToSkip fileID) { + for _, file := range fm.files { + if file.id == idToSkip { + continue + } + if atomic.CompareAndSwapInt32(&file.state, fileOpened, fileClosing) { + fm.doClose(file) + if fm.openedFDs < fm.limit { + break + } + } + } +} + +// createFile creates a file, opens it, adds it to the list of files and returns +// an instance of `*file` with the state set to `fileInUse`. +// This call will possibly cause opened but unused files to be closed if the +// number of open file requests is above the set limit.
func (fm *filesManager) createFile(name string, flags int, bfc beforeFileClose) (*file, error) {
	fm.Lock()
	if fm.isClosed {
		fm.Unlock()
		return nil, fmt.Errorf("unable to create file %q, store is being closed", name)
	}
	// The FD limit is soft: try to make room by closing files that are
	// opened but not locked. Passing 0 means there is no file to skip.
	if fm.limit > 0 && fm.openedFDs >= fm.limit {
		fm.closeUnusedFiles(0)
	}
	fileName := filepath.Join(fm.rootDir, name)
	handle, err := openFileWithFlags(fileName, flags)
	if err != nil {
		fm.Unlock()
		return nil, err
	}
	// IDs increase monotonically and are never reused, so an id uniquely
	// identifies a file for the lifetime of this manager.
	fm.nextID++
	newFile := &file{
		state:       fileInUse,
		id:          fm.nextID,
		handle:      handle,
		name:        fileName,
		flags:       flags,
		beforeClose: bfc,
	}
	fm.files[newFile.id] = newFile
	fm.openedFDs++
	fm.Unlock()
	return newFile, nil
}

// openFile opens the given file and sets its state to `fileInUse`.
// If the file manager has been closed or the file removed, this call
// returns an error.
// Otherwise, if the file's state is not `fileClosed` this call will panic.
// This call will possibly cause opened but unused files to be closed if the
// number of open file requests is above the set limit.
func (fm *filesManager) openFile(file *file) error {
	fm.Lock()
	if fm.isClosed {
		fm.Unlock()
		return fmt.Errorf("unable to open file %q, store is being closed", file.name)
	}
	curState := atomic.LoadInt32(&file.state)
	if curState == fileRemoved {
		fm.Unlock()
		return fmt.Errorf("unable to open file %q, it has been removed", file.name)
	}
	// A non-closed state (or a lingering handle) here means the caller
	// violated the state machine: this is a programmer error, so panic.
	if curState != fileClosed || file.handle != nil {
		fm.Unlock()
		panic(fmt.Errorf("request to open file %q but invalid state: handle=%v - state=%v", file.name, file.handle, file.state))
	}
	var err error
	// Skip this very file when trying to free FDs, since we are about
	// to (re)open it.
	if fm.limit > 0 && fm.openedFDs >= fm.limit {
		fm.closeUnusedFiles(file.id)
	}
	file.handle, err = openFileWithFlags(file.name, file.flags)
	if err == nil {
		// Opening on behalf of a caller means the file is handed back locked.
		atomic.StoreInt32(&file.state, fileInUse)
		fm.openedFDs++
	}
	fm.Unlock()
	return err
}

// closeLockedFile closes the handle of the given file, but only if the caller
// has locked the file. Will panic otherwise.
// If the file's beforeClose callback is not nil, this callback is invoked
// before the file handle is closed.
func (fm *filesManager) closeLockedFile(file *file) error {
	if !atomic.CompareAndSwapInt32(&file.state, fileInUse, fileClosing) {
		panic(fmt.Errorf("file %q is requested to be closed but was not locked by caller", file.name))
	}
	fm.Lock()
	err := fm.doClose(file)
	fm.Unlock()
	return err
}

// closeFileIfOpened closes the handle of the given file, but only if the
// file is opened and not currently locked. Does not return any error or panic
// if file is in any other state.
// If the file's beforeClose callback is not nil, this callback is invoked
// before the file handle is closed.
func (fm *filesManager) closeFileIfOpened(file *file) error {
	if !atomic.CompareAndSwapInt32(&file.state, fileOpened, fileClosing) {
		return nil
	}
	fm.Lock()
	err := fm.doClose(file)
	fm.Unlock()
	return err
}

// closeLockedOrOpenedFile closes the handle of the given file if this file
// is either locked or opened. Does not return any error or panic if file
// is in any other state.
// If the file's beforeClose callback is not nil, this callback is invoked
// before the file handle is closed.
func (fm *filesManager) closeLockedOrOpenedFile(file *file) error {
	// Check first locked files
	if !atomic.CompareAndSwapInt32(&file.state, fileInUse, fileClosing) {
		// then opened but unlocked files
		if !atomic.CompareAndSwapInt32(&file.state, fileOpened, fileClosing) {
			return nil
		}
	}
	fm.Lock()
	err := fm.doClose(file)
	fm.Unlock()
	return err
}

// doClose closes the file handle, setting it to nil and switching state to `fileClosed`.
// If a `beforeClose` callback was registered on file creation, it is invoked
// before the file handler is actually closed.
// Lock is required on entry.
+func (fm *filesManager) doClose(file *file) error { + var err error + if file.beforeClose != nil { + err = file.beforeClose() + } + util.CloseFile(err, file.handle) + // Regardless of error, we need to change the state to closed. + file.handle = nil + atomic.StoreInt32(&file.state, fileClosed) + fm.openedFDs-- + return err +} + +// lockFile locks the given file. +// If the file was already opened, the boolean returned is true, +// otherwise, the file is opened and the call returns false. +func (fm *filesManager) lockFile(file *file) (bool, error) { + if atomic.CompareAndSwapInt32(&file.state, fileOpened, fileInUse) { + return true, nil + } + return false, fm.openFile(file) +} + +// lockFileIfOpened is like lockFile but returns true only if the +// file is already opened, false otherwise (and the file remain closed). +func (fm *filesManager) lockFileIfOpened(file *file) bool { + return atomic.CompareAndSwapInt32(&file.state, fileOpened, fileInUse) +} + +// unlockFile unlocks the file if currently locked, otherwise panic. +func (fm *filesManager) unlockFile(file *file) { + if !atomic.CompareAndSwapInt32(&file.state, fileInUse, fileOpened) { + panic(fmt.Errorf("failed to switch state from fileInUse to fileOpened for file %q, state=%v", + file.name, file.state)) + } +} + +// trySwitchState attempts to switch an initial state of `fileOpened` +// or `fileClosed` to the given newState. If it can't it will return an +// error, otherwise, returned a boolean to indicate if the initial state +// was `fileOpened`. 
func (fm *filesManager) trySwitchState(file *file, newState int32) (bool, error) {
	wasOpened := false
	wasClosed := false
	// Spin until the file leaves the locked (fileInUse) state, trying to
	// grab it from either the opened or closed state.
	for i := 0; i < 10000; i++ {
		if atomic.CompareAndSwapInt32(&file.state, fileOpened, newState) {
			wasOpened = true
			break
		}
		if atomic.CompareAndSwapInt32(&file.state, fileClosed, newState) {
			wasClosed = true
			break
		}
		// Sleep 1ms every 1000 iterations to avoid a pure busy-wait.
		// NOTE(review): `i%1000 == 1` sleeps at i=1,1001,... (not at
		// multiples of 1000) — presumably intentional to sleep early;
		// confirm. Max total wait is ~10ms.
		if i%1000 == 1 {
			time.Sleep(time.Millisecond)
		}
	}
	if !wasOpened && !wasClosed {
		return false, fmt.Errorf("file %q is still probably locked", file.name)
	}
	return wasOpened, nil
}

// remove removes a file from the list of files. The initial state must be
// either `fileOpened` or `fileClosed`. This call will loop until it can switch
// the file's state from one of these states to `fileRemoved`, or return false
// if the change can't be made after a certain number of attempts.
// When removed, this call returns true and the given `file` is untouched (except
// for its state). So it is still possible for caller to read/write (if handle is
// valid) or close this file.
func (fm *filesManager) remove(file *file) bool {
	fm.Lock()
	wasOpened, err := fm.trySwitchState(file, fileRemoved)
	if err != nil {
		fm.Unlock()
		return false
	}
	// With code above, we can't be removing a file twice, so no need to check if
	// file is present in map.
	delete(fm.files, file.id)
	// If the file was opened, its FD is still accounted for; release it.
	if wasOpened {
		fm.openedFDs--
	}
	fm.Unlock()
	return true
}

// setBeforeCloseCb sets the beforeFileClose callback for this file.
// When this callback is set, and the files manager closes a file,
// the callback is invoked prior to actual closing of the file handle.
// This allows the caller to perform some work before the file is
// asynchronously (from its perspective) closed.
func (fm *filesManager) setBeforeCloseCb(file *file, bccb beforeFileClose) {
	fm.Lock()
	file.beforeClose = bccb
	fm.Unlock()
}

// truncateFile truncates the file to the given offset.
// The file is assumed to be locked on entry.
+// If the file's flags indicate that this file is opened with O_APPEND, it +// is first closed, reopened in non append mode, truncated, then reopened +// (and locked) with original flags. +func (fm *filesManager) truncateFile(file *file, offset int64) error { + reopen := false + fd := file.handle + if file.flags&os.O_APPEND != 0 { + if err := fm.closeLockedFile(file); err != nil { + return err + } + var err error + fd, err = openFileWithFlags(file.name, os.O_RDWR) + if err != nil { + return err + } + reopen = true + } + newPos := offset + if err := fd.Truncate(newPos); err != nil { + return err + } + pos, err := fd.Seek(newPos, io.SeekStart) // or Seek(0, io.SeekEnd) + if err != nil { + return err + } + if pos != newPos { + return fmt.Errorf("unable to set position of file %q to %v", file.name, newPos) + } + if reopen { + if err := fd.Close(); err != nil { + return err + } + if err := fm.openFile(file); err != nil { + return err + } + } + return nil +} + +// close the files manager, including all files currently opened. +// Returns the first error encountered when closing the files. +func (fm *filesManager) close() error { + fm.Lock() + if fm.isClosed { + fm.Unlock() + return nil + } + fm.isClosed = true + + files := make([]*file, 0, len(fm.files)) + for _, file := range fm.files { + files = append(files, file) + } + fm.files = nil + fm.Unlock() + + var err error + for _, file := range files { + wasOpened, sserr := fm.trySwitchState(file, fmClosed) + if sserr != nil { + if err == nil { + err = sserr + } + } else if wasOpened { + fm.Lock() + if cerr := fm.doClose(file); cerr != nil && err == nil { + err = cerr + } + fm.Unlock() + } + } + return err +} + +//////////////////////////////////////////////////////////////////////////// +// FileStore methods +//////////////////////////////////////////////////////////////////////////// + +// NewFileStore returns a factory for stores backed by files. 
+// If not limits are provided, the store will be created with +// DefaultStoreLimits. +func NewFileStore(log logger.Logger, rootDir string, limits *StoreLimits, options ...FileStoreOption) (*FileStore, error) { + if rootDir == "" { + return nil, fmt.Errorf("for %v stores, root directory must be specified", TypeFile) + } + + fs := &FileStore{opts: DefaultFileStoreOptions, clients: make(map[string]*Client)} + if err := fs.init(TypeFile, log, limits); err != nil { + return nil, err + } + + for _, opt := range options { + if err := opt(&fs.opts); err != nil { + return nil, err + } + } + // Create filesManager based on options' FD limit + fs.fm = createFilesManager(rootDir, fs.opts.FileDescriptorsLimit) + // Convert the compact interval in time.Duration + fs.compactItvl = time.Duration(fs.opts.CompactInterval) * time.Second + // Create the table using polynomial in options + if fs.opts.CRCPolynomial == int64(crc32.IEEE) { + fs.crcTable = crc32.IEEETable + } else { + fs.crcTable = crc32.MakeTable(uint32(fs.opts.CRCPolynomial)) + } + + if err := os.MkdirAll(rootDir, os.ModeDir+os.ModePerm); err != nil && !os.IsExist(err) { + return nil, fmt.Errorf("unable to create the root directory [%s]: %v", rootDir, err) + } + + // If the TruncateUnexpectedEOF is set, check that the witness + // file is not present. If it is, fail starting. If it isn't, + // create the witness file. + truncateFName := filepath.Join(rootDir, truncateBadEOFFileName) + if fs.opts.TruncateUnexpectedEOF { + // Try to create the file, if it exists, this is an error. 
+ f, err := os.OpenFile(truncateFName, os.O_CREATE|os.O_EXCL, 0666) + if f != nil { + f.Close() + } + if err != nil { + return nil, fmt.Errorf("file store should not be opened consecutively with the TruncateUnexpectedEOF option set to true") + } + } else { + // Delete possible TruncateUnexpectedEOF witness file + os.Remove(truncateFName) + } + + return fs, nil +} + +type channelRecoveryCtx struct { + wg *sync.WaitGroup + poolCh chan struct{} + errCh chan error + recoverCh chan *recoveredChannel +} + +type recoveredChannel struct { + name string + rc *RecoveredChannel +} + +// Recover implements the Store interface +func (fs *FileStore) Recover() (*RecoveredState, error) { + fs.Lock() + defer fs.Unlock() + var ( + err error + recoveredState *RecoveredState + serverInfo *spb.ServerInfo + recoveredClients []*Client + recoveredChannels = make(map[string]*RecoveredChannel) + channels []os.FileInfo + ) + + // Ensure store is closed in case of return with error + defer func() { + if fs.serverFile != nil { + fs.fm.unlockFile(fs.serverFile) + } + if fs.clientsFile != nil { + fs.fm.unlockFile(fs.clientsFile) + } + }() + + // Open/Create the server file (note that this file must not be opened, + // in APPEND mode to allow truncate to work). + fs.serverFile, err = fs.fm.createFile(serverFileName, os.O_RDWR|os.O_CREATE, nil) + if err != nil { + return nil, err + } + + // Open/Create the client file. + fs.clientsFile, err = fs.fm.createFile(clientsFileName, defaultFileFlags, nil) + if err != nil { + return nil, err + } + + // Recover the server file. + serverInfo, err = fs.recoverServerInfo() + if err != nil { + return nil, fmt.Errorf("unable to recover server file %q: %v", fs.serverFile.name, err) + } + // If the server file is empty, then we are done + if serverInfo == nil { + // We return the file store instance, but no recovered state. 
+ return nil, nil + } + + // Recover the clients file + recoveredClients, err = fs.recoverClients() + if err != nil { + return nil, fmt.Errorf("unable to recover client file %q: %v", fs.clientsFile.name, err) + } + + // Get the channels (there are subdirectories of rootDir) + channels, err = ioutil.ReadDir(fs.fm.rootDir) + if err != nil { + return nil, err + } + if len(channels) > 0 { + wg, poolCh, errCh, recoverCh := initParalleRecovery(fs.opts.ParallelRecovery, len(channels)) + ctx := &channelRecoveryCtx{wg: wg, poolCh: poolCh, errCh: errCh, recoverCh: recoverCh} + for _, c := range channels { + // Channels are directories. Ignore simple files + if !c.IsDir() { + continue + } + channel := c.Name() + channelDirName := filepath.Join(fs.fm.rootDir, channel) + limits := fs.genericStore.getChannelLimits(channel) + // This will block if the max number of go-routines is reached. + // When one of the go-routine finishes, it will add back to the + // pool and we will be able to start the recovery of another + // channel. + <-poolCh + wg.Add(1) + go fs.recoverOneChannel(channelDirName, channel, limits, ctx) + // Fail as soon as we detect that a go routine has encountered + // an error + if len(errCh) > 0 { + break + } + } + // We need to wait for all current go routines to exit + wg.Wait() + // Also, even if there was an error, we need to collect + // all channels that were recovered so that we can close + // the msgs/subs stores on exit. 
+ done := false + for !done { + select { + case rc := <-recoverCh: + recoveredChannels[rc.name] = rc.rc + fs.channels[rc.name] = rc.rc.Channel + default: + done = true + } + } + select { + case err = <-errCh: + return nil, err + default: + } + } + // Create the recovered state to return + recoveredState = &RecoveredState{ + Info: serverInfo, + Clients: recoveredClients, + Channels: recoveredChannels, + } + return recoveredState, nil +} + +func initParalleRecovery(maxGoRoutines, foundChannels int) (*sync.WaitGroup, chan struct{}, chan error, chan *recoveredChannel) { + wg := sync.WaitGroup{} + poolCh := make(chan struct{}, maxGoRoutines) + for i := 0; i < maxGoRoutines; i++ { + poolCh <- struct{}{} + } + errCh := make(chan error, 1) + // foundChannels is the number of directories (channels) found + // in the root directory. It is the max number of elements we will + // put in this channel during the recovery process. + recoverCh := make(chan *recoveredChannel, foundChannels) + return &wg, poolCh, errCh, recoverCh +} + +func (fs *FileStore) recoverOneChannel(dir, name string, limits *ChannelLimits, ctx *channelRecoveryCtx) { + var ( + msgStore *FileMsgStore + subStore *FileSubStore + err error + ) + defer func() { + if err != nil { + select { + case ctx.errCh <- err: + default: + } + } + ctx.poolCh <- struct{}{} + ctx.wg.Done() + }() + msgStore, err = fs.newFileMsgStore(dir, name, &limits.MsgStoreLimits, true) + if err != nil { + return + } + subStore, err = fs.newFileSubStore(name, &limits.SubStoreLimits, true) + if err != nil { + msgStore.Close() + return + } + + recoveredChannel := &recoveredChannel{ + name: name, + rc: &RecoveredChannel{ + Channel: &Channel{ + Subs: subStore, + Msgs: msgStore, + }, + Subscriptions: make([]*RecoveredSubscription, 0, len(subStore.subs)), + }, + } + + // Fill that array with what we got from newFileSubStore. 
+ for _, subi := range subStore.subs { + sub := subi.(*subscription) + // The server is making a copy of rss.Sub, still it is not + // a good idea to return a pointer to an object that belong + // to the store. So make a copy and return the pointer to + // that copy. + csub := *sub.sub + rs := &RecoveredSubscription{ + Sub: &csub, + Pending: make(PendingAcks), + } + // If we recovered any seqno... + if len(sub.seqnos) > 0 { + // Lookup messages, and if we find those, update the + // Pending map. + for seq := range sub.seqnos { + rs.Pending[seq] = struct{}{} + } + } + // Add to the array of recovered subscriptions + recoveredChannel.rc.Subscriptions = append(recoveredChannel.rc.Subscriptions, rs) + } + // Push our recovered info into the recovered channel. + ctx.recoverCh <- recoveredChannel +} + +// GetExclusiveLock implements the Store interface +func (fs *FileStore) GetExclusiveLock() (bool, error) { + fs.Lock() + defer fs.Unlock() + if fs.lockFile != nil { + return true, nil + } + f, err := util.CreateLockFile(filepath.Join(fs.fm.rootDir, lockFileName)) + if err != nil { + if err == util.ErrUnableToLockNow { + return false, nil + } + return false, err + } + // We must keep a reference to the file, otherwise, it `f` is GC'ed, + // its file descriptor is closed, which automatically releases the lock. + fs.lockFile = f + return true, nil +} + +// Init is used to persist server's information after the first start +func (fs *FileStore) Init(info *spb.ServerInfo) error { + fs.Lock() + defer fs.Unlock() + + if fs.serverFile == nil { + var err error + // Open/Create the server file (note that this file must not be opened, + // in APPEND mode to allow truncate to work). + fs.serverFile, err = fs.fm.createFile(serverFileName, os.O_RDWR|os.O_CREATE, nil) + if err != nil { + return err + } + } else { + if _, err := fs.fm.lockFile(fs.serverFile); err != nil { + return err + } + } + f := fs.serverFile.handle + // defer is ok for this function... 
+ defer fs.fm.unlockFile(fs.serverFile) + + // Truncate the file (4 is the size of the fileVersion record) + if err := f.Truncate(4); err != nil { + return err + } + // Move offset to 4 (truncate does not do that) + if _, err := f.Seek(4, io.SeekStart); err != nil { + return err + } + // ServerInfo record is not typed. We also don't pass a reusable buffer. + if _, _, err := writeRecord(f, nil, recNoType, info, info.Size(), fs.crcTable); err != nil { + return err + } + return nil +} + +// recoverClients reads the client files and returns an array of RecoveredClient +func (fs *FileStore) recoverClients() ([]*Client, error) { + var err error + var recType recordType + var recSize int + + _buf := [256]byte{} + buf := _buf[:] + offset := int64(4) + + // Create a buffered reader to speed-up recovery + br := bufio.NewReaderSize(fs.clientsFile.handle, defaultBufSize) + + for { + buf, recSize, recType, err = readRecord(br, buf, true, fs.crcTable, fs.opts.DoCRC) + if err != nil { + switch err { + case io.EOF: + err = nil + case errNeedRewind: + err = fs.fm.truncateFile(fs.clientsFile, offset) + default: + err = fs.handleUnexpectedEOF(err, fs.clientsFile, offset, true) + } + if err == nil { + break + } + return nil, err + } + readBytes := int64(recSize + recordHeaderSize) + offset += readBytes + fs.cliFileSize += readBytes + switch recType { + case addClient: + c := &Client{} + if err := c.ClientInfo.Unmarshal(buf[:recSize]); err != nil { + return nil, err + } + // Add to the map. Note that if one already exists, which should + // not, just replace with this most recent one. 
+ fs.clients[c.ID] = c + case delClient: + c := spb.ClientDelete{} + if err := c.Unmarshal(buf[:recSize]); err != nil { + return nil, err + } + delete(fs.clients, c.ID) + fs.cliDeleteRecs++ + default: + return nil, fmt.Errorf("invalid client record type: %v", recType) + } + } + clients := make([]*Client, len(fs.clients)) + i := 0 + // Convert the map into an array + for _, c := range fs.clients { + clients[i] = c + i++ + } + return clients, nil +} + +// recoverServerInfo reads the server file and returns a ServerInfo structure +func (fs *FileStore) recoverServerInfo() (*spb.ServerInfo, error) { + info := &spb.ServerInfo{} + buf, size, _, err := readRecord(fs.serverFile.handle, nil, false, fs.crcTable, fs.opts.DoCRC) + if err != nil { + if err == io.EOF { + // We are done, no state recovered + return nil, nil + } + fs.log.Errorf("Server file %q corrupted: %v", fs.serverFile.name, err) + fs.log.Errorf("Follow instructions in documentation in order to recover from this") + return nil, err + } + // Check that the size of the file is consistent with the size + // of the record we are supposed to recover. Account for the + // 12 bytes (4 + recordHeaderSize) corresponding to the fileVersion and + // record header. + fstat, err := fs.serverFile.handle.Stat() + if err != nil { + return nil, err + } + expectedSize := int64(size + 4 + recordHeaderSize) + if fstat.Size() != expectedSize { + return nil, fmt.Errorf("incorrect file size, expected %v bytes, got %v bytes", + expectedSize, fstat.Size()) + } + // Reconstruct now + if err := info.Unmarshal(buf[:size]); err != nil { + return nil, err + } + return info, nil +} + +// CreateChannel implements the Store interface +func (fs *FileStore) CreateChannel(channel string) (*Channel, error) { + fs.Lock() + defer fs.Unlock() + + // Verify that it does not already exist or that we did not hit the limits + if err := fs.canAddChannel(channel); err != nil { + return nil, err + } + + // We create the channel here... 
+ + channelDirName := filepath.Join(fs.fm.rootDir, channel) + if err := os.MkdirAll(channelDirName, os.ModeDir+os.ModePerm); err != nil { + return nil, err + } + + var err error + var msgStore MsgStore + var subStore SubStore + + channelLimits := fs.genericStore.getChannelLimits(channel) + + msgStore, err = fs.newFileMsgStore(channelDirName, channel, &channelLimits.MsgStoreLimits, false) + if err != nil { + return nil, err + } + subStore, err = fs.newFileSubStore(channel, &channelLimits.SubStoreLimits, false) + if err != nil { + msgStore.Close() + return nil, err + } + + c := &Channel{ + Subs: subStore, + Msgs: msgStore, + } + + fs.channels[channel] = c + + return c, nil +} + +// DeleteChannel implements the Store interface +func (fs *FileStore) DeleteChannel(channel string) error { + fs.Lock() + defer fs.Unlock() + err := fs.deleteChannel(channel) + if err != nil { + return err + } + return os.RemoveAll(filepath.Join(fs.fm.rootDir, channel)) +} + +// AddClient implements the Store interface +func (fs *FileStore) AddClient(info *spb.ClientInfo) (*Client, error) { + fs.Lock() + if _, err := fs.fm.lockFile(fs.clientsFile); err != nil { + fs.Unlock() + return nil, err + } + _, size, err := writeRecord(fs.clientsFile.handle, nil, addClient, info, info.Size(), fs.crcTable) + if err != nil { + fs.fm.unlockFile(fs.clientsFile) + fs.Unlock() + return nil, err + } + fs.cliFileSize += int64(size) + fs.fm.unlockFile(fs.clientsFile) + client := &Client{*info} + fs.clients[client.ID] = client + fs.Unlock() + return client, nil +} + +// DeleteClient implements the Store interface +func (fs *FileStore) DeleteClient(clientID string) error { + fs.Lock() + if _, err := fs.fm.lockFile(fs.clientsFile); err != nil { + fs.Unlock() + return err + } + fs.delClientRec = spb.ClientDelete{ID: clientID} + _, size, err := writeRecord(fs.clientsFile.handle, nil, delClient, &fs.delClientRec, fs.delClientRec.Size(), fs.crcTable) + // Even if there is an error, proceed. 
If we compact the file, + // this may resolve the issue. + delete(fs.clients, clientID) + fs.cliDeleteRecs++ + fs.cliFileSize += int64(size) + // Check if this triggers a need for compaction + if fs.shouldCompactClientFile() { + // close the file now + // If we can't close the file, it does not make sense + // to proceed with compaction. + if lerr := fs.fm.closeLockedFile(fs.clientsFile); lerr != nil { + fs.Unlock() + return lerr + } + // compact (this uses a temporary file) + // Override writeRecord error with the result of compaction. + // If compaction works, the original error is no longer an issue + // since the file has been replaced. + err = fs.compactClientFile(fs.clientsFile.name) + } else { + fs.fm.unlockFile(fs.clientsFile) + } + fs.Unlock() + return err +} + +// shouldCompactClientFile returns true if the client file should be compacted +// Lock is held by caller +func (fs *FileStore) shouldCompactClientFile() bool { + // Global switch + if !fs.opts.CompactEnabled { + return false + } + // Check that if minimum file size is set, the client file + // is at least at the minimum. + if fs.opts.CompactMinFileSize > 0 && fs.cliFileSize < fs.opts.CompactMinFileSize { + return false + } + // Check fragmentation + frag := fs.cliDeleteRecs * 100 / (fs.cliDeleteRecs + len(fs.clients)) + if frag < fs.opts.CompactFragmentation { + return false + } + // Check that we don't do too often + if time.Since(fs.cliCompactTS) < fs.compactItvl { + return false + } + return true +} + +// Rewrite the content of the clients map into a temporary file, +// then swap back to active file. 
+// Store lock held on entry +func (fs *FileStore) compactClientFile(orgFileName string) error { + // Open a temporary file + tmpFile, err := getTempFile(fs.fm.rootDir, clientsFileName) + if err != nil { + return err + } + defer func() { + if tmpFile != nil { + tmpFile.Close() + os.Remove(tmpFile.Name()) + } + }() + bw := bufio.NewWriterSize(tmpFile, defaultBufSize) + fileSize := int64(0) + size := 0 + _buf := [256]byte{} + buf := _buf[:] + // Dump the content of active clients into the temporary file. + for _, c := range fs.clients { + buf, size, err = writeRecord(bw, buf, addClient, &c.ClientInfo, c.ClientInfo.Size(), fs.crcTable) + if err != nil { + return err + } + fileSize += int64(size) + } + // Flush the buffer on disk + if err := bw.Flush(); err != nil { + return err + } + // Start by closing the temporary file. + if err := tmpFile.Close(); err != nil { + return err + } + // Rename the tmp file to original file name + if err := os.Rename(tmpFile.Name(), orgFileName); err != nil { + return err + } + // Avoid unnecessary attempt to cleanup + tmpFile = nil + + fs.cliDeleteRecs = 0 + fs.cliFileSize = fileSize + fs.cliCompactTS = time.Now() + return nil +} + +// Return a temporary file (including file version) +func getTempFile(rootDir, prefix string) (*os.File, error) { + tmpFile, err := ioutil.TempFile(rootDir, prefix) + if err != nil { + return nil, err + } + if err := util.WriteInt(tmpFile, fileVersion); err != nil { + return nil, err + } + return tmpFile, nil +} + +// Close closes all stores. 
+func (fs *FileStore) Close() error { + fs.Lock() + if fs.closed { + fs.Unlock() + return nil + } + fs.closed = true + + err := fs.genericStore.close() + + fm := fs.fm + lockFile := fs.lockFile + fs.Unlock() + + if fm != nil { + if fmerr := fm.close(); fmerr != nil && err == nil { + err = fmerr + } + } + if lockFile != nil { + err = util.CloseFile(err, lockFile) + } + return err +} + +func (fs *FileStore) handleUnexpectedEOF(recoveryErr error, f *file, offset int64, recTyped bool) error { + // Regardless the recoveryErr, we will dump the bytes for + // the corrupted record, however, we attempt to fix only + // for io.ErrUnexpectedEOF. + if recoveryErr == io.ErrUnexpectedEOF { + fs.log.Errorf("Unexpected EOF for file %q", f.name) + if !fs.opts.TruncateUnexpectedEOF { + fs.log.Errorf("It is recommended that you make a copy of the whole datatstore %q.", fs.fm.rootDir) + fs.log.Errorf("Restart with the ContinueOnUnexpectedEOF flag to truncate this file to this offset: %v.", offset) + fs.log.Errorf("Dataloss may occur. 
Details about the first corrupted record follows...") + } + } else { + fs.log.Errorf("Corrupted record in file %q: %v", f.name, recoveryErr) + } + if _, err := f.handle.Seek(offset, io.SeekStart); err != nil { + panic(fmt.Errorf("Unable to set position of file %q to %v: %v", f.name, offset, err)) + } + var ( + expectedSize int + read int + part string + ) + fs.log.Errorf("Record header:") + part = "record header" + expectedSize = recordHeaderSize + var ( + _header = [recordHeaderSize]byte{} + header = _header[:] + ) + read, _ = io.ReadFull(f.handle, header) + fs.log.Errorf(" Bytes:") + dumpBytes(fs.log, header[:read], false) + if read >= recordHeaderSize { + recType := recNoType + recSize := 0 + firstInt := int(util.ByteOrder.Uint32(header[:4])) + if recTyped { + recType = recordType(firstInt >> 24 & 0xFF) + recSize = firstInt & 0xFFFFFF + } else { + recSize = firstInt + } + crc := util.ByteOrder.Uint32(header[4:recordHeaderSize]) + if recTyped { + fs.log.Errorf(" Type: %v", recType) + } + fs.log.Errorf(" Size: %v", recSize) + fs.log.Errorf(" CRC : 0x%08x", crc) + fs.log.Errorf("Record payload:") + + part = "record payload" + expectedSize = recSize + buf := util.EnsureBufBigEnough(nil, recSize) + read, _ = io.ReadFull(f.handle, buf) + dumpBytes(fs.log, buf[:read], true) + } + if recoveryErr == io.ErrUnexpectedEOF { + if fs.opts.TruncateUnexpectedEOF { + if err := fs.fm.truncateFile(f, offset); err != nil { + return fmt.Errorf("unable to repair file %q by truncating at offset %v: %v", f.name, offset, err) + } + fs.log.Noticef("File %q has been truncated to offset: %v", f.name, offset) + fs.log.Noticef("Recovery resumes...") + return nil + } + fs.log.Errorf("%s expected to be %v bytes, only read %v", part, expectedSize, read) + } + return recoveryErr +} + +func dumpBytes(log logger.Logger, buf []byte, printTxt bool) { + lines := len(buf) / 20 + start := 0 + for i := 0; i < lines+1; i++ { + if start >= len(buf) { + break + } + end := len(buf) - start + if end > 20 { + 
end = 20 + } + bl := fmt.Sprintf("% x", buf[start:start+end]) + if printTxt { + tl := "" + for b := start; b < start+end; b++ { + c := buf[b] + if int(c) < 32 || int(c) > 128 { + c = '.' + } + tl = fmt.Sprintf("%s%s", tl, []byte{c}) + } + var paddingStr string + padding := 3 * (20 - end) + if padding > 0 { + paddingStr = fmt.Sprintf("%*s", padding, " ") + } + log.Errorf("%s%s - %s", bl, paddingStr, tl) + } else { + log.Errorf(bl) + } + start += end + } +} + +//////////////////////////////////////////////////////////////////////////// +// FileMsgStore methods +//////////////////////////////////////////////////////////////////////////// + +// newFileMsgStore returns a new instace of a file MsgStore. +func (fs *FileStore) newFileMsgStore(channelDirName, channel string, limits *MsgStoreLimits, doRecover bool) (*FileMsgStore, error) { + // Create an instance and initialize + ms := &FileMsgStore{ + fm: fs.fm, + hasFDsLimit: fs.opts.FileDescriptorsLimit > 0, + fstore: fs, + wOffset: int64(4), // The very first record starts after the file version record + files: make(map[int]*fileSlice), + channelName: channel, + bkgTasksDone: make(chan bool, 1), + bkgTasksWake: make(chan bool, 1), + } + ms.init(channel, fs.log, limits) + + ms.setSliceLimits() + ms.initCache() + + maxBufSize := fs.opts.BufferSize + if maxBufSize > 0 { + ms.bw = newBufferWriter(msgBufMinShrinkSize, maxBufSize) + ms.bufferedSeqs = make([]uint64, 0, 1) + ms.bufferedMsgs = make(map[uint64]*bufferedMsg) + } + + // Use this variable for all errors below so we can do the cleanup + var err error + + // Recovery case + if doRecover { + var dirFiles []os.FileInfo + var fseq int64 + var datFile, idxFile *file + var useIdxFile bool + + dirFiles, err = ioutil.ReadDir(channelDirName) + for _, file := range dirFiles { + if file.IsDir() { + continue + } + fileName := file.Name() + if !strings.HasPrefix(fileName, msgFilesPrefix) || !strings.HasSuffix(fileName, datSuffix) { + continue + } + // Remove suffix + 
fileNameWithoutSuffix := strings.TrimSuffix(fileName, datSuffix) + // Remove prefix + fileNameWithoutPrefixAndSuffix := strings.TrimPrefix(fileNameWithoutSuffix, msgFilesPrefix) + // Get the file sequence number + fseq, err = strconv.ParseInt(fileNameWithoutPrefixAndSuffix, 10, 64) + if err != nil { + err = fmt.Errorf("message log has an invalid name: %v", fileName) + break + } + idxFName := fmt.Sprintf("%s%v%s", msgFilesPrefix, fseq, idxSuffix) + useIdxFile = false + if s, statErr := os.Stat(filepath.Join(channelDirName, idxFName)); s != nil && statErr == nil { + useIdxFile = true + } + datFile, err = ms.fm.createFile(filepath.Join(channel, fileName), defaultFileFlags, nil) + if err != nil { + break + } + idxFile, err = ms.fm.createFile(filepath.Join(channel, idxFName), defaultFileFlags, nil) + if err != nil { + ms.fm.unlockFile(datFile) + break + } + // Create the slice + fslice := &fileSlice{file: datFile, idxFile: idxFile, lastUsed: time.Now().UnixNano()} + // Recover the file slice + err = ms.recoverOneMsgFile(fslice, int(fseq), useIdxFile) + if err != nil { + break + } + } + if err == nil && ms.lastFSlSeq > 0 { + // Now that all file slices have been recovered, we know which + // one is the last, so use it as the write slice. + ms.writeSlice = ms.files[ms.lastFSlSeq] + // Need to set the writer, etc.. + ms.fm.lockFile(ms.writeSlice.file) + err = ms.setFile(ms.writeSlice, -1) + ms.fm.unlockFile(ms.writeSlice.file) + if err == nil { + // Set the beforeFileClose callback to the slices now that + // we are done recovering. + for _, fslice := range ms.files { + ms.fm.setBeforeCloseCb(fslice.file, ms.beforeDataFileCloseCb(fslice)) + ms.fm.setBeforeCloseCb(fslice.idxFile, ms.beforeIndexFileCloseCb(fslice)) + } + ms.checkSlices = 1 + } + } + if err == nil { + // Apply message limits (no need to check if there are limits + // defined, the call won't do anything if they aren't). 
+ err = ms.enforceLimits(false, true) + } + } + if err == nil { + ms.Lock() + ms.allDone.Add(1) + // Capture the time here first, it will then be captured + // in the go routine we are about to start. + ms.timeTick = time.Now().UnixNano() + // On recovery, if there is age limit set and at least one message... + if doRecover { + if ms.limits.MaxAge > 0 && ms.totalCount > 0 { + // Force the execution of the expireMsgs method. + // This will take care of expiring messages that should have + // expired while the server was stopped. + ms.expireMsgs(ms.timeTick, int64(ms.limits.MaxAge)) + } + // Now that we are done with recovery, close the write slice + if ms.writeSlice != nil { + ms.fm.closeFileIfOpened(ms.writeSlice.file) + ms.fm.closeFileIfOpened(ms.writeSlice.idxFile) + } + } + // Start the background tasks go routine + go ms.backgroundTasks() + ms.Unlock() + } + // Cleanup on error + if err != nil { + // The buffer writer may not be fully set yet + if ms.bw != nil && ms.bw.buf == nil { + ms.bw = nil + } + ms.Close() + ms = nil + action := "create" + if doRecover { + action = "recover" + } + err = fmt.Errorf("unable to %s message store for [%s]: %v", action, channel, err) + return nil, err + } + + return ms, nil +} + +// beforeDataFileCloseCb returns a beforeFileClose callback to be used +// by FileMsgStore's files when a data file for that slice is being closed. +// This is invoked asynchronously and should not acquire the store's lock. +// That being said, we have the guarantee that this will be not be invoked +// concurrently for a given file and that the store will not be using this file. 
+func (ms *FileMsgStore) beforeDataFileCloseCb(fslice *fileSlice) beforeFileClose { + return func() error { + if fslice != ms.writeSlice { + return nil + } + if ms.bw != nil && ms.bw.buf != nil && ms.bw.buf.Buffered() > 0 { + if err := ms.bw.buf.Flush(); err != nil { + return err + } + } + if ms.fstore.opts.DoSync { + if err := fslice.file.handle.Sync(); err != nil { + return err + } + } + ms.writer = nil + return nil + } +} + +// beforeIndexFileCloseCb returns a beforeFileClose callback to be used +// by FileMsgStore's files when an index file for that slice is being closed. +// This is invoked asynchronously and should not acquire the store's lock. +// That being said, we have the guarantee that this will be not be invoked +// concurrently for a given file and that the store will not be using this file. +func (ms *FileMsgStore) beforeIndexFileCloseCb(fslice *fileSlice) beforeFileClose { + return func() error { + if fslice != ms.writeSlice { + return nil + } + if len(ms.bufferedMsgs) > 0 { + if err := ms.processBufferedMsgs(fslice); err != nil { + return err + } + } + if ms.fstore.opts.DoSync { + if err := fslice.idxFile.handle.Sync(); err != nil { + return err + } + } + return nil + } +} + +// setFile sets the current data and index file. +// The buffered writer is recreated. 
+func (ms *FileMsgStore) setFile(fslice *fileSlice, offset int64) error { + var err error + file := fslice.file.handle + ms.writer = file + if file != nil && ms.bw != nil { + ms.writer = ms.bw.createNewWriter(file) + } + if offset == -1 { + ms.wOffset, err = file.Seek(0, io.SeekEnd) + } else { + ms.wOffset = offset + } + return err +} + +func (ms *FileMsgStore) doLockFiles(fslice *fileSlice, onlyIndexFile bool) error { + var datWasOpened, idxWasOpened bool + var err error + + if !onlyIndexFile { + datWasOpened, err = ms.fm.lockFile(fslice.file) + if err != nil { + return err + } + } + idxWasOpened, err = ms.fm.lockFile(fslice.idxFile) + if err != nil { + if !datWasOpened { + ms.fm.unlockFile(fslice.file) + } + return err + } + if !onlyIndexFile { + // We need to reset writer/offset only if the data file is opened + // in this call and it is the slice to which we are currently + // writing to. + if fslice == ms.writeSlice && !datWasOpened { + err = ms.setFile(fslice, -1) + } + } + // If we try to limit FDs use or simply not the write slice, then + // we need to notify the background task code that it should + // try to close unused slices. + if ms.hasFDsLimit || fslice != ms.writeSlice { + if !datWasOpened || !idxWasOpened { + atomic.StoreInt64(&ms.checkSlices, 1) + } + if fslice.lastUsed == 0 { + fslice.lastUsed = atomic.LoadInt64(&ms.timeTick) + } else { + fslice.lastUsed++ + } + } + return err +} + +// lockFiles locks the data and index files of the given file slice. +// If files were closed they are opened in this call, and if so, +// and if this slice is the write slice, the writer and offset are reset. +func (ms *FileMsgStore) lockFiles(fslice *fileSlice) error { + return ms.doLockFiles(fslice, false) +} + +// lockIndexFile locks the index file of the given file slice. +// If the file was closed it is opened in this call. 
+func (ms *FileMsgStore) lockIndexFile(fslice *fileSlice) error { + return ms.doLockFiles(fslice, true) +} + +// unlockIndexFile unlocks the already locked index file of the given file slice. +func (ms *FileMsgStore) unlockIndexFile(fslice *fileSlice) { + ms.fm.unlockFile(fslice.idxFile) +} + +// unlockFiles unlocks both data and index files of the given file slice. +func (ms *FileMsgStore) unlockFiles(fslice *fileSlice) { + ms.fm.unlockFile(fslice.file) + ms.fm.unlockFile(fslice.idxFile) +} + +// closeLockedFiles (unlocks and) closes the files of the given file slice. +func (ms *FileMsgStore) closeLockedFiles(fslice *fileSlice) error { + err := ms.fm.closeLockedFile(fslice.file) + if idxErr := ms.fm.closeLockedFile(fslice.idxFile); idxErr != nil && err == nil { + err = idxErr + } + return err +} + +// recovers one of the file +func (ms *FileMsgStore) recoverOneMsgFile(fslice *fileSlice, fseq int, useIdxFile bool) error { + var err error + + msgSize := 0 + var msg *pb.MsgProto + var mindex *msgIndex + var seq uint64 + + // Select which file to recover based on presence of index file + file := fslice.file + if useIdxFile { + file = fslice.idxFile + } + + // Create a buffered reader to speed-up recovery + br := bufio.NewReaderSize(file.handle, defaultBufSize) + + // The first record starts after the file version record + offset := int64(4) + + if useIdxFile { + var ( + lastIndex *msgIndex + lastSeq uint64 + ) + for { + seq, mindex, err = ms.readIndex(br) + if err != nil { + switch err { + case io.EOF: + // We are done, reset err + err = nil + case errNeedRewind: + err = ms.fm.truncateFile(file, offset) + } + break + } + + // Update file slice + if fslice.firstSeq == 0 { + fslice.firstSeq = seq + } + fslice.lastSeq = seq + fslice.msgsCount++ + // For size, add the message record size, the record header and the size + // required for the corresponding index record. 
+ fslice.msgsSize += uint64(mindex.msgSize + msgRecordOverhead) + if fslice.firstWrite == 0 { + fslice.firstWrite = mindex.timestamp + } + lastIndex = mindex + lastSeq = seq + offset += msgIndexRecSize + } + if err == nil { + if lastIndex != nil { + err = ms.ensureLastMsgAndIndexMatch(fslice, lastSeq, lastIndex) + if err != nil { + ms.fstore.log.Errorf(err.Error()) + if _, serr := fslice.file.handle.Seek(4, io.SeekStart); serr != nil { + panic(fmt.Errorf("File %q: unable to set position to beginning of file: %v", fslice.file.name, serr)) + } + } + } else { + // Nothing recovered from the index file, try to recover + // from data file in case it is not empty. + useIdxFile = false + } + } + // We can get an error either because the index file was corrupted, + // or because the data file is. In both case, we truncate the index + // file and recover from data file. The handling of unexpected EOF + // is handled in the data file recovery down below. + if err != nil { + ms.fstore.log.Errorf("Error with index file %q: %v. 
Truncating and recovering from data file", fslice.idxFile.name, err) + if terr := ms.fm.truncateFile(fslice.idxFile, 4); terr != nil { + panic(fmt.Errorf("Error during recovery of file %q: %v, you need "+ + "to manually remove index file %q (truncate failed with err: %v)", + fslice.file.name, err, fslice.idxFile.name, terr)) + } + fslice.firstSeq = 0 + fslice.lastSeq = 0 + fslice.msgsCount = 0 + fslice.msgsSize = 0 + fslice.firstWrite = 0 + file = fslice.file + err = nil + useIdxFile = false + } + } + // No `else` here because in case of error recovering index file, we will do data file recovery + if !useIdxFile { + // Get these from the file store object + crcTable := ms.fstore.crcTable + doCRC := ms.fstore.opts.DoCRC + + // Create a buffered reader from the data file to speed-up recovery + br := bufio.NewReaderSize(fslice.file.handle, defaultBufSize) + + // We are going to write the index file while recovering the data file + bw := bufio.NewWriterSize(fslice.idxFile.handle, msgIndexRecSize*1000) + + for { + ms.tmpMsgBuf, msgSize, _, err = readRecord(br, ms.tmpMsgBuf, false, crcTable, doCRC) + if err != nil { + switch err { + case io.EOF: + // We are done, reset err + err = nil + case errNeedRewind: + err = ms.fm.truncateFile(file, offset) + default: + err = ms.fstore.handleUnexpectedEOF(err, file, offset, false) + } + break + } + + // Recover this message + msg = &pb.MsgProto{} + err = msg.Unmarshal(ms.tmpMsgBuf[:msgSize]) + if err != nil { + break + } + + if fslice.firstSeq == 0 { + fslice.firstSeq = msg.Sequence + } + fslice.lastSeq = msg.Sequence + fslice.msgsCount++ + // For size, add the message record size, the record header and the size + // required for the corresponding index record. 
+ fslice.msgsSize += uint64(msgSize + msgRecordOverhead) + if fslice.firstWrite == 0 { + fslice.firstWrite = msg.Timestamp + } + + // There was no index file, update it + err = ms.writeIndex(bw, msg.Sequence, offset, msg.Timestamp, msgSize) + if err != nil { + break + } + // Move offset + offset += int64(recordHeaderSize + msgSize) + } + if err == nil { + err = bw.Flush() + if err == nil { + err = fslice.idxFile.handle.Sync() + } + } + // Since there was no index and there was an error, remove the index + // file so when server restarts, it recovers again from the data file. + if err != nil { + // Close the index file + ms.fm.closeLockedFile(fslice.idxFile) + // Remove form store's map + ms.fm.remove(fslice.idxFile) + // Remove it, and panic if we can't + if rmErr := os.Remove(fslice.idxFile.name); rmErr != nil { + panic(fmt.Errorf("Error during recovery of file %q: %v, you need "+ + "to manually remove index file %q (remove failed with err: %v)", + fslice.file.name, err, fslice.idxFile.name, rmErr)) + } + // Close the data file + ms.fm.closeLockedFile(fslice.file) + return err + } + } + + // Close the files + ms.fm.closeLockedFile(fslice.file) + ms.fm.closeLockedFile(fslice.idxFile) + + // If no error and slice is not empty... + if fslice.msgsCount > 0 { + if ms.first == 0 || ms.first > fslice.firstSeq { + ms.first = fslice.firstSeq + } + if ms.last < fslice.lastSeq { + ms.last = fslice.lastSeq + } + ms.totalCount += fslice.msgsCount + ms.totalBytes += fslice.msgsSize + + // On success, add to the map of file slices and + // update first/last file slice sequence. + ms.files[fseq] = fslice + if ms.firstFSlSeq == 0 || ms.firstFSlSeq > fseq { + ms.firstFSlSeq = fseq + } + if ms.lastFSlSeq < fseq { + ms.lastFSlSeq = fseq + } + return nil + } + // Slice was empty and not recovered. Need to remove those from store's files manager. 
+ ms.fm.remove(fslice.file) + ms.fm.remove(fslice.idxFile) + return nil +} + +func (ms *FileMsgStore) ensureLastMsgAndIndexMatch(fslice *fileSlice, seq uint64, index *msgIndex) error { + var ( + msgSize int + err error + startErr = fmt.Sprintf("Verification of last message for file %q failed", fslice.file.name) + ) + fd := fslice.file.handle + // Position for the last record + if _, err := fd.Seek(index.offset, io.SeekStart); err != nil { + return fmt.Errorf("%s: unable to set position to %v", startErr, index.offset) + } + ms.tmpMsgBuf, msgSize, _, err = readRecord(fd, ms.tmpMsgBuf, false, ms.fstore.crcTable, true) + if err != nil { + return fmt.Errorf("%s: unable to read last record: %v", startErr, err) + } + if uint32(msgSize) != index.msgSize { + return fmt.Errorf("%s: last message size in index is %v, data file is %v", + startErr, index.msgSize, msgSize) + } + // Recover this message + msg := &pb.MsgProto{} + if err := msg.Unmarshal(ms.tmpMsgBuf[:msgSize]); err != nil { + return fmt.Errorf("%s: error decoding message: %v", startErr, err) + } + if msg.Sequence != seq { + return fmt.Errorf("%s: last message sequence in index is %v, data file is %v", + startErr, seq, msg.Sequence) + } + return nil +} + +// setSliceLimits sets the limits of a file slice based on options and/or +// channel limits. +func (ms *FileMsgStore) setSliceLimits() { + // First set slice limits based on slice configuration. + ms.slCountLim = ms.fstore.opts.SliceMaxMsgs + ms.slSizeLim = uint64(ms.fstore.opts.SliceMaxBytes) + ms.slAgeLim = int64(ms.fstore.opts.SliceMaxAge) + // Did we configure any of the "dimension"? + ms.slHasLimits = ms.slCountLim > 0 || ms.slSizeLim > 0 || ms.slAgeLim > 0 + + // If so, we are done. We will use those limits to decide + // when to move to a new slice. + if ms.slHasLimits { + return + } + + // Slices limits were not configured. We will set a limit based on channel limits. 
+ if ms.limits.MaxMsgs > 0 { + limit := ms.limits.MaxMsgs / 4 + if limit == 0 { + limit = 1 + } + ms.slCountLim = limit + } + if ms.limits.MaxBytes > 0 { + limit := uint64(ms.limits.MaxBytes) / 4 + if limit == 0 { + limit = 1 + } + ms.slSizeLim = limit + } + if ms.limits.MaxAge > 0 { + limit := time.Duration(int64(ms.limits.MaxAge) / 4) + if limit < time.Second { + limit = time.Second + } + ms.slAgeLim = int64(limit) + } + // Refresh our view of slices having limits. + ms.slHasLimits = ms.slCountLim > 0 || ms.slSizeLim > 0 || ms.slAgeLim > 0 +} + +// writeIndex writes a message index record to the writer `w` +func (ms *FileMsgStore) writeIndex(w io.Writer, seq uint64, offset, timestamp int64, msgSize int) error { + _buf := [msgIndexRecSize]byte{} + buf := _buf[:] + ms.addIndex(buf, seq, offset, timestamp, msgSize) + _, err := w.Write(buf[:msgIndexRecSize]) + return err +} + +// addIndex adds a message index record in the given buffer +func (ms *FileMsgStore) addIndex(buf []byte, seq uint64, offset, timestamp int64, msgSize int) { + util.ByteOrder.PutUint64(buf, seq) + util.ByteOrder.PutUint64(buf[8:], uint64(offset)) + util.ByteOrder.PutUint64(buf[16:], uint64(timestamp)) + util.ByteOrder.PutUint32(buf[24:], uint32(msgSize)) + crc := crc32.Checksum(buf[:msgIndexRecSize-crcSize], ms.fstore.crcTable) + util.ByteOrder.PutUint32(buf[msgIndexRecSize-crcSize:], crc) +} + +// readIndex reads a message index record from the given reader +// and returns an allocated msgIndex object. 
+func (ms *FileMsgStore) readIndex(r io.Reader) (uint64, *msgIndex, error) { + _buf := [msgIndexRecSize]byte{} + buf := _buf[:] + if _, err := io.ReadFull(r, buf); err != nil { + return 0, nil, err + } + mindex := &msgIndex{} + seq := util.ByteOrder.Uint64(buf) + mindex.offset = int64(util.ByteOrder.Uint64(buf[8:])) + mindex.timestamp = int64(util.ByteOrder.Uint64(buf[16:])) + mindex.msgSize = util.ByteOrder.Uint32(buf[24:]) + // If all zeros, return that caller should rewind (for recovery) + if seq == 0 && mindex.offset == 0 && mindex.timestamp == 0 && mindex.msgSize == 0 { + storedCRC := util.ByteOrder.Uint32(buf[msgIndexRecSize-crcSize:]) + if storedCRC == 0 { + return 0, nil, errNeedRewind + } + } + if ms.fstore.opts.DoCRC { + storedCRC := util.ByteOrder.Uint32(buf[msgIndexRecSize-crcSize:]) + crc := crc32.Checksum(buf[:msgIndexRecSize-crcSize], ms.fstore.crcTable) + if storedCRC != crc { + return 0, nil, fmt.Errorf("corrupted data, expected crc to be 0x%08x, got 0x%08x", storedCRC, crc) + } + } + return seq, mindex, nil +} + +// Store a given message. +func (ms *FileMsgStore) Store(m *pb.MsgProto) (uint64, error) { + ms.Lock() + defer ms.Unlock() + + if m.Sequence <= ms.last { + // We've already seen this message. + return m.Sequence, nil + } + + fslice := ms.writeSlice + if fslice != nil { + if err := ms.lockFiles(fslice); err != nil { + return 0, err + } + } + + // Is there a gap in message sequence? + if ms.last > 0 && m.Sequence > ms.last+1 { + if err := ms.fillGaps(fslice, m); err != nil { + ms.unlockFiles(fslice) + return 0, err + } + } + + // Check if we need to move to next file slice + if fslice == nil || ms.slHasLimits { + if fslice == nil || + (ms.slSizeLim > 0 && fslice.msgsSize >= ms.slSizeLim) || + (ms.slCountLim > 0 && fslice.msgsCount >= ms.slCountLim) || + (ms.slAgeLim > 0 && atomic.LoadInt64(&ms.timeTick)-fslice.firstWrite >= ms.slAgeLim) { + + // Don't change store variable until success... 
+ newSliceSeq := ms.lastFSlSeq + 1 + + // Close the current file slice (if applicable) and open the next slice + if fslice != nil { + if err := ms.closeLockedFiles(fslice); err != nil { + return 0, err + } + } + // Create new slice + datFName := filepath.Join(ms.channelName, fmt.Sprintf("%s%v%s", msgFilesPrefix, newSliceSeq, datSuffix)) + idxFName := filepath.Join(ms.channelName, fmt.Sprintf("%s%v%s", msgFilesPrefix, newSliceSeq, idxSuffix)) + datFile, err := ms.fm.createFile(datFName, defaultFileFlags, nil) + if err != nil { + return 0, err + } + idxFile, err := ms.fm.createFile(idxFName, defaultFileFlags, nil) + if err != nil { + ms.fm.closeLockedFile(datFile) + ms.fm.remove(datFile) + return 0, err + } + // Success, update the store's variables + newSlice := &fileSlice{ + file: datFile, + idxFile: idxFile, + lastUsed: atomic.LoadInt64(&ms.timeTick), + } + ms.fm.setBeforeCloseCb(datFile, ms.beforeDataFileCloseCb(newSlice)) + ms.fm.setBeforeCloseCb(idxFile, ms.beforeIndexFileCloseCb(newSlice)) + ms.files[newSliceSeq] = newSlice + ms.writeSlice = newSlice + if ms.firstFSlSeq == 0 { + ms.firstFSlSeq = newSliceSeq + } + ms.lastFSlSeq = newSliceSeq + ms.setFile(newSlice, 4) + + // If we added a second slice and the first slice was empty but not removed + // because it was the only one, we remove it now. + if len(ms.files) == 2 && fslice.msgsCount == fslice.rmCount { + ms.removeFirstSlice() + } + // Update the fslice reference to new slice for rest of function + fslice = ms.writeSlice + } + } + + // !! IMPORTANT !! + // We want to reduce use of defer in functions that are in the fast path, + // so after this point, on error, use goto processErr instead of return. 
+ // It means that we should not use local errors like this: + // if err := this(); err != nil { + // goto processErr + // } + + seq := m.Sequence + + msgInBuffer := false + + var recSize int + var err error + var mindex *msgIndex + var size uint64 + + var bwBuf *bufio.Writer + if ms.bw != nil { + bwBuf = ms.bw.buf + } + msgSize := m.Size() + if bwBuf != nil { + required := msgSize + recordHeaderSize + if required > bwBuf.Available() { + ms.writer, err = ms.bw.expand(fslice.file.handle, required) + if err != nil { + goto processErr + } + if len(ms.bufferedMsgs) > 0 { + err = ms.processBufferedMsgs(fslice) + if err != nil { + goto processErr + } + } + // Refresh this since it has changed. + bwBuf = ms.bw.buf + } + } + ms.tmpMsgBuf, recSize, err = writeRecord(ms.writer, ms.tmpMsgBuf, recNoType, m, msgSize, ms.fstore.crcTable) + if err != nil { + goto processErr + } + if bwBuf != nil { + // Check to see if we should cancel a buffer shrink request + if ms.bw.shrinkReq { + ms.bw.checkShrinkRequest() + } + // If message was added to the buffer we need to also save a reference + // to that message outside of the cache, until the buffer is flushed. + if bwBuf.Buffered() >= recSize { + ms.bufferedSeqs = append(ms.bufferedSeqs, seq) + mindex = &msgIndex{offset: ms.wOffset, timestamp: m.Timestamp, msgSize: uint32(msgSize)} + ms.bufferedMsgs[seq] = &bufferedMsg{msg: m, index: mindex} + msgInBuffer = true + } + } + // Message was flushed to disk, write corresponding index + if !msgInBuffer { + err = ms.writeIndex(fslice.idxFile.handle, seq, ms.wOffset, m.Timestamp, msgSize) + if err != nil { + goto processErr + } + } + + if ms.first == 0 || ms.first == seq { + // First ever message or after all messages expired and this is the + // first new message. 
+ ms.first = seq + ms.firstMsg = m + if maxAge := ms.limits.MaxAge; maxAge > 0 { + ms.expiration = m.Timestamp + int64(maxAge) + if len(ms.bkgTasksWake) == 0 { + ms.bkgTasksWake <- true + } + } + } + ms.last = seq + ms.lastMsg = m + ms.cache.add(seq, m, true) + ms.wOffset += int64(recSize) + + // For size, add the message record size, the record header and the size + // required for the corresponding index record. + size = uint64(msgSize + msgRecordOverhead) + + // Total stats + ms.totalCount++ + ms.totalBytes += size + + // Stats per file slice + fslice.msgsCount++ + fslice.msgsSize += size + if fslice.firstWrite == 0 { + fslice.firstWrite = m.Timestamp + } + + // Save references to first and last sequences for this slice + if fslice.firstSeq == 0 { + fslice.firstSeq = seq + } + fslice.lastSeq = seq + + if ms.limits.MaxMsgs > 0 || ms.limits.MaxBytes > 0 { + // Enfore limits and update file slice if needed. + err = ms.enforceLimits(true, false) + if err != nil { + goto processErr + } + } + ms.unlockFiles(fslice) + return seq, nil + +processErr: + ms.unlockFiles(fslice) + return 0, err +} + +func (ms *FileMsgStore) fillGaps(fslice *fileSlice, upToMsg *pb.MsgProto) error { + // flush possible buffered messages. 
+ if err := ms.flush(fslice); err != nil { + return err + } + + var ( + recSize int + err error + msgSize int + ) + + ms.lastMsg = nil + emptyMsg := &pb.MsgProto{ + Subject: ms.channelName, + Timestamp: upToMsg.Timestamp, + } + for i := ms.last + 1; i < upToMsg.Sequence; i++ { + emptyMsg.Sequence = i + msgSize = emptyMsg.Size() + ms.tmpMsgBuf, recSize, err = writeRecord(fslice.file.handle, ms.tmpMsgBuf, recNoType, emptyMsg, msgSize, ms.fstore.crcTable) + if err != nil { + return err + } + if err := ms.writeIndex(fslice.idxFile.handle, i, ms.wOffset, emptyMsg.Timestamp, msgSize); err != nil { + return err + } + ms.wOffset += int64(recSize) + ms.last++ + ms.totalCount++ + size := uint64(msgSize + msgRecordOverhead) + ms.totalBytes += size + fslice.lastSeq = i + fslice.msgsCount++ + fslice.msgsSize += size + } + return nil +} + +// processBufferedMsgs adds message index records in the given buffer +// for every pending buffered messages. +func (ms *FileMsgStore) processBufferedMsgs(fslice *fileSlice) error { + idxBufferSize := len(ms.bufferedMsgs) * msgIndexRecSize + ms.tmpMsgBuf = util.EnsureBufBigEnough(ms.tmpMsgBuf, idxBufferSize) + bufOffset := 0 + for _, pseq := range ms.bufferedSeqs { + bm := ms.bufferedMsgs[pseq] + if bm != nil { + mindex := bm.index + // We add the index info for this flushed message + ms.addIndex(ms.tmpMsgBuf[bufOffset:], pseq, mindex.offset, + mindex.timestamp, int(mindex.msgSize)) + bufOffset += msgIndexRecSize + delete(ms.bufferedMsgs, pseq) + } + } + if bufOffset > 0 { + if _, err := fslice.idxFile.handle.Write(ms.tmpMsgBuf[:bufOffset]); err != nil { + return err + } + } + ms.bufferedSeqs = ms.bufferedSeqs[:0] + return nil +} + +// expireMsgs ensures that messages don't stay in the log longer than the +// limit's MaxAge. 
+// Returns the time of the next expiration (possibly 0 if no message left) +// The store's lock is assumed to be held on entry +func (ms *FileMsgStore) expireMsgs(now, maxAge int64) int64 { + if ms.first == 0 { + ms.expiration = 0 + return ms.expiration + } + var m *msgIndex + var slice *fileSlice + for { + m = nil + if ms.first <= ms.last { + if slice == nil || ms.first > slice.lastSeq { + // If slice is not nil, it means that we have expired all + // messages belong to that slice, and the slice itslef. + // So there is no need to unlock it since this has already + // been done. + slice = ms.getFileSliceForSeq(ms.first) + if slice == nil { + // If we did not find a slice for this sequence, it could + // be cause there is a gap in message sequence due to + // file truncation following unexpected EOF on recovery. + // So set the first seq to the first sequence of the now + // first slice. + slice = ms.files[ms.firstFSlSeq] + if slice != nil { + ms.first = slice.firstSeq + } + } + if slice != nil { + if err := ms.lockIndexFile(slice); err != nil { + slice = nil + break + } + } + } + if slice != nil { + m = ms.getMsgIndex(slice, ms.first) + } + } + if m == nil { + ms.expiration = 0 + break + } + elapsed := now - m.timestamp + if elapsed >= maxAge { + ms.removeFirstMsg(m, false) + } else { + if elapsed < 0 { + ms.expiration = m.timestamp + maxAge + } else { + ms.expiration = now + (maxAge - elapsed) + } + break + } + } + if slice != nil { + ms.unlockIndexFile(slice) + } + return ms.expiration +} + +// enforceLimits checks total counts with current msg store's limits, +// removing a file slice and/or updating slices' count as necessary. +func (ms *FileMsgStore) enforceLimits(reportHitLimit, lockFile bool) error { + // Check if we need to remove any (but leave at least the last added). + // Note that we may have to remove more than one msg if we are here + // after a restart with smaller limits than originally set, or if + // message is quite big, etc... 
+ maxMsgs := ms.limits.MaxMsgs + maxBytes := ms.limits.MaxBytes + for ms.totalCount > 1 && + ((maxMsgs > 0 && ms.totalCount > maxMsgs) || + (maxBytes > 0 && ms.totalBytes > uint64(maxBytes))) { + + // Remove first message from first slice, potentially removing + // the slice, etc... + ms.removeFirstMsg(nil, lockFile) + if reportHitLimit && !ms.hitLimit { + ms.hitLimit = true + ms.log.Noticef(droppingMsgsFmt, ms.subject, ms.totalCount, ms.limits.MaxMsgs, + util.FriendlyBytes(int64(ms.totalBytes)), util.FriendlyBytes(ms.limits.MaxBytes)) + } + } + return nil +} + +// getMsgIndex returns a msgIndex object for message with sequence `seq`, +// or nil if message is not found (or no longer valid: expired, removed +// due to limits, etc). +// This call first checks that the record is not present in +// ms.bufferedMsgs since it is possible that message and index are not +// yet stored on disk. +func (ms *FileMsgStore) getMsgIndex(slice *fileSlice, seq uint64) *msgIndex { + bm := ms.bufferedMsgs[seq] + if bm != nil { + return bm.index + } + return ms.readMsgIndex(slice, seq) +} + +// readMsgIndex reads a message index record from disk and returns a msgIndex +// object. Same than getMsgIndex but without checking for message in +// ms.bufferedMsgs first. +func (ms *FileMsgStore) readMsgIndex(slice *fileSlice, seq uint64) *msgIndex { + // Compute the offset in the index file itself. + idxFileOffset := 4 + (int64(seq-slice.firstSeq)+int64(slice.rmCount))*msgIndexRecSize + // Then position the file pointer of the index file. + if _, err := slice.idxFile.handle.Seek(idxFileOffset, io.SeekStart); err != nil { + return nil + } + // Read the index record and ensure we have what we expect + seqInIndexFile, msgIndex, err := ms.readIndex(slice.idxFile.handle) + if seqInIndexFile != seq || err != nil { + return nil + } + return msgIndex +} + +// removeFirstMsg "removes" the first message of the first slice. +// If the slice is "empty" the file slice is removed. 
+func (ms *FileMsgStore) removeFirstMsg(mindex *msgIndex, lockFile bool) { + // Work with the first slice + slice := ms.files[ms.firstFSlSeq] + // Get the message index for the first valid message in this slice + if mindex == nil { + if lockFile || slice != ms.writeSlice { + ms.lockIndexFile(slice) + } + mindex = ms.getMsgIndex(slice, slice.firstSeq) + if lockFile || slice != ms.writeSlice { + ms.unlockIndexFile(slice) + } + } + // Size of the first message in this slice + firstMsgSize := mindex.msgSize + // For size, we count the size of serialized message + record header + + // the corresponding index record + size := uint64(firstMsgSize + msgRecordOverhead) + // Keep track of number of "removed" messages in this slice + slice.rmCount++ + // Update total counts + ms.totalCount-- + ms.totalBytes -= size + // Messages sequence is incremental with no gap on a given msgstore. + ms.first++ + // Invalidate ms.firstMsg, it will be looked-up on demand. + ms.firstMsg = nil + // Invalidate ms.lastMsg if it was the last message being removed. + if ms.first > ms.last { + ms.lastMsg = nil + } + // Is file slice is "empty" and not the last one + if slice.msgsCount == slice.rmCount && len(ms.files) > 1 { + ms.removeFirstSlice() + } else { + // This is the new first message in this slice. + slice.firstSeq = ms.first + } +} + +// removeFirstSlice removes the first file slice. +// Should not be called if first slice is also last! +func (ms *FileMsgStore) removeFirstSlice() { + sl := ms.files[ms.firstFSlSeq] + // We may or may not have the first slice locked, so need to close + // the file knowing that files can be in either state. + ms.fm.closeLockedOrOpenedFile(sl.file) + ms.fm.remove(sl.file) + // Close index file too. 
+ ms.fm.closeLockedOrOpenedFile(sl.idxFile) + ms.fm.remove(sl.idxFile) + // Assume we will remove the files + remove := true + // If there is an archive script invoke it first + script := ms.fstore.opts.SliceArchiveScript + if script != "" { + datBak := sl.file.name + bakSuffix + idxBak := sl.idxFile.name + bakSuffix + + var err error + if err = os.Rename(sl.file.name, datBak); err == nil { + if err = os.Rename(sl.idxFile.name, idxBak); err != nil { + // Remove first backup file + os.Remove(datBak) + } + } + if err == nil { + // Files have been successfully renamed, so don't attempt + // to remove the original files. + remove = false + + // We run the script in a go routine to not block the server. + ms.allDone.Add(1) + go func(subj, dat, idx string) { + defer ms.allDone.Done() + cmd := exec.Command(script, subj, dat, idx) + output, err := cmd.CombinedOutput() + if err != nil { + ms.log.Noticef("Error invoking archive script %q: %v (output=%v)", script, err, string(output)) + } else { + ms.log.Noticef("Output of archive script for %s (%s and %s): %v", subj, dat, idx, string(output)) + } + }(ms.subject, datBak, idxBak) + } + } + // Remove files + if remove { + os.Remove(sl.file.name) + os.Remove(sl.idxFile.name) + } + // Remove slice from map + delete(ms.files, ms.firstFSlSeq) + // Normally, file slices have an incremental sequence number with + // no gap. However, we want to support the fact that an user could + // copy back some old file slice to be recovered, and so there + // may be a gap. So find out what is the new first file sequence. + for ms.firstFSlSeq < ms.lastFSlSeq { + ms.firstFSlSeq++ + if _, ok := ms.files[ms.firstFSlSeq]; ok { + break + } + } + // This should not happen! + if ms.firstFSlSeq > ms.lastFSlSeq { + panic("Removed last slice!") + } +} + +// getFileSliceForSeq returns the file slice where the message of the +// given sequence is stored, or nil if the message is not found in any +// of the file slices. 
+func (ms *FileMsgStore) getFileSliceForSeq(seq uint64) *fileSlice { + if len(ms.files) == 0 { + return nil + } + // Start with write slice + slice := ms.writeSlice + if (slice.firstSeq <= seq) && (seq <= slice.lastSeq) { + return slice + } + // We want to support possible gaps in file slice sequence, so + // no dichotomy, but simple iteration of the map, which in Go is + // random. + for _, slice := range ms.files { + if (slice.firstSeq <= seq) && (seq <= slice.lastSeq) { + return slice + } + } + return nil +} + +// backgroundTasks performs some background tasks related to this +// messages store. +func (ms *FileMsgStore) backgroundTasks() { + defer ms.allDone.Done() + + ms.RLock() + hasBuffer := ms.bw != nil + maxAge := int64(ms.limits.MaxAge) + nextExpiration := ms.expiration + lastCacheCheck := ms.timeTick + lastBufShrink := ms.timeTick + ms.RUnlock() + + for { + // Update time + timeTick := time.Now().UnixNano() + atomic.StoreInt64(&ms.timeTick, timeTick) + + // Close unused file slices + if atomic.LoadInt64(&ms.checkSlices) == 1 { + ms.Lock() + opened := 0 + for _, slice := range ms.files { + // If no FD limit and this is the write slice, skip. + if !ms.hasFDsLimit && slice == ms.writeSlice { + continue + } + opened++ + if slice.lastUsed > 0 && time.Duration(timeTick-slice.lastUsed) >= sliceCloseInterval { + slice.lastUsed = 0 + ms.fm.closeFileIfOpened(slice.file) + ms.fm.closeFileIfOpened(slice.idxFile) + opened-- + } + } + if opened == 0 { + // We can update this without atomic since we are under store lock + // and this go routine is the only place where we check the value. 
+ ms.checkSlices = 0 + } + ms.Unlock() + } + + // Shrink the buffer if applicable + if hasBuffer && time.Duration(timeTick-lastBufShrink) >= bufShrinkInterval { + ms.Lock() + if ms.writeSlice != nil { + file := ms.writeSlice.file + if ms.fm.lockFileIfOpened(file) { + ms.writer, _ = ms.bw.tryShrinkBuffer(file.handle) + ms.fm.unlockFile(file) + } + } + ms.Unlock() + lastBufShrink = timeTick + } + + // Check for expiration + if maxAge > 0 && nextExpiration > 0 && timeTick >= nextExpiration { + ms.Lock() + // Expire messages + nextExpiration = ms.expireMsgs(timeTick, maxAge) + ms.Unlock() + } + + // Check for message caching + if timeTick >= lastCacheCheck+cacheTTL { + tryEvict := atomic.LoadInt32(&ms.cache.tryEvict) + if tryEvict == 1 { + ms.Lock() + // Possibly remove some/all cached messages + ms.cache.evict(timeTick) + ms.Unlock() + } + lastCacheCheck = timeTick + } + + select { + case <-ms.bkgTasksDone: + return + case <-ms.bkgTasksWake: + // wake up from a possible sleep to run the loop + ms.RLock() + nextExpiration = ms.expiration + ms.RUnlock() + case <-time.After(bkgTasksSleepDuration): + // go back to top of for loop. + } + } +} + +// lookup returns the message for the given sequence number, possibly +// reading the message from disk. +// Store write lock is assumed to be held on entry +func (ms *FileMsgStore) lookup(seq uint64) (*pb.MsgProto, error) { + // Reject message for sequence outside valid range + if seq < ms.first || seq > ms.last { + return nil, nil + } + // Check first if it's in the cache. + msg := ms.cache.get(seq) + if msg == nil && ms.bufferedMsgs != nil { + // Possibly in bufferedMsgs + bm := ms.bufferedMsgs[seq] + if bm != nil { + msg = bm.msg + ms.cache.add(seq, msg, false) + } + } + // If not, we need to read it from disk... 
+ if msg == nil { + fslice := ms.getFileSliceForSeq(seq) + if fslice == nil { + return nil, nil + } + err := ms.lockFiles(fslice) + if err != nil { + return nil, err + } + msgIndex := ms.readMsgIndex(fslice, seq) + if msgIndex != nil { + file := fslice.file.handle + // Position file to message's offset. 0 means from start. + _, err = file.Seek(msgIndex.offset, io.SeekStart) + if err == nil { + ms.tmpMsgBuf, _, _, err = readRecord(file, ms.tmpMsgBuf, false, ms.fstore.crcTable, ms.fstore.opts.DoCRC) + } + } + ms.unlockFiles(fslice) + if err != nil || msgIndex == nil { + return nil, err + } + // Recover this message + msg = &pb.MsgProto{} + err = msg.Unmarshal(ms.tmpMsgBuf[:msgIndex.msgSize]) + if err != nil { + return nil, err + } + ms.cache.add(seq, msg, false) + } + return msg, nil +} + +// Lookup returns the stored message with given sequence number. +func (ms *FileMsgStore) Lookup(seq uint64) (*pb.MsgProto, error) { + ms.Lock() + msg, err := ms.lookup(seq) + ms.Unlock() + return msg, err +} + +// FirstMsg returns the first message stored. +func (ms *FileMsgStore) FirstMsg() (*pb.MsgProto, error) { + var err error + ms.RLock() + if ms.firstMsg == nil { + ms.firstMsg, err = ms.lookup(ms.first) + } + m := ms.firstMsg + ms.RUnlock() + return m, err +} + +// LastMsg returns the last message stored. +func (ms *FileMsgStore) LastMsg() (*pb.MsgProto, error) { + var err error + ms.RLock() + if ms.lastMsg == nil { + ms.lastMsg, err = ms.lookup(ms.last) + } + m := ms.lastMsg + ms.RUnlock() + return m, err +} + +// GetSequenceFromTimestamp returns the sequence of the first message whose +// timestamp is greater or equal to given timestamp. 
+func (ms *FileMsgStore) GetSequenceFromTimestamp(timestamp int64) (uint64, error) { + ms.RLock() + defer ms.RUnlock() + + // No message ever stored + if ms.first == 0 { + return 0, nil + } + // All messages have expired + if ms.first > ms.last { + return ms.last + 1, nil + } + // If we have some state, try to quickly get the sequence + if ms.firstMsg != nil && ms.firstMsg.Timestamp >= timestamp { + return ms.first, nil + } + if ms.lastMsg != nil && timestamp >= ms.lastMsg.Timestamp { + return ms.last + 1, nil + } + + smallest := int64(-1) + // This will require disk access. + for _, slice := range ms.files { + if err := ms.lockIndexFile(slice); err != nil { + return 0, err + } + mindex := ms.getMsgIndex(slice, slice.firstSeq) + if timestamp >= mindex.timestamp { + mindex = ms.getMsgIndex(slice, slice.lastSeq) + if timestamp <= mindex.timestamp { + // Could do binary search, but will be probably more efficient + // to do sequential disk reads. The index records are small, + // so read of a record will probably bring many consecutive ones + // in the system's disk cache, resulting in memory-only access + // for the following indexes... + for seq := slice.firstSeq + 1; seq < slice.lastSeq; seq++ { + mindex = ms.getMsgIndex(slice, seq) + if mindex.timestamp >= timestamp { + ms.unlockIndexFile(slice) + return seq, nil + } + } + } + } else if smallest == -1 || mindex.timestamp < smallest { + smallest = mindex.timestamp + } + ms.unlockIndexFile(slice) + } + if timestamp < smallest { + return ms.first, nil + } + return ms.last + 1, nil +} + +// initCache initializes the message cache +func (ms *FileMsgStore) initCache() { + ms.cache = &msgsCache{ + seqMaps: make(map[uint64]*cachedMsg), + } +} + +// add adds a message to the cache. 
+// Store write lock is assumed held on entry +func (c *msgsCache) add(seq uint64, msg *pb.MsgProto, isNew bool) { + exp := cacheTTL + if isNew { + exp += msg.Timestamp + } else { + exp += time.Now().UnixNano() + } + cMsg := &cachedMsg{ + expiration: exp, + msg: msg, + } + if c.tail == nil { + c.head = cMsg + } else { + c.tail.next = cMsg + // Ensure last expiration is at least >= previous one. + if cMsg.expiration < c.tail.expiration { + cMsg.expiration = c.tail.expiration + } + } + cMsg.prev = c.tail + c.tail = cMsg + c.seqMaps[seq] = cMsg + if len(c.seqMaps) == 1 { + atomic.StoreInt32(&c.tryEvict, 1) + } +} + +// get returns a message if available in the cache. +// Store write lock is assumed held on entry +func (c *msgsCache) get(seq uint64) *pb.MsgProto { + cMsg := c.seqMaps[seq] + if cMsg == nil { + return nil + } + // Bump the expiration + cMsg.expiration = time.Now().UnixNano() + cacheTTL + // If not already at the tail of the list, move it there + if cMsg != c.tail { + if cMsg.prev != nil { + cMsg.prev.next = cMsg.next + } + if cMsg.next != nil { + cMsg.next.prev = cMsg.prev + } + if cMsg == c.head { + c.head = cMsg.next + } + cMsg.prev = c.tail + c.tail.next = cMsg + cMsg.next = nil + // Ensure last expiration is at least >= previous one. + if cMsg.expiration < c.tail.expiration { + cMsg.expiration = c.tail.expiration + } + c.tail = cMsg + } + return cMsg.msg +} + +// evict move down the cache maps, evicting the last one. +// Store write lock is assumed held on entry +func (c *msgsCache) evict(now int64) { + if c.head == nil { + return + } + if now >= c.tail.expiration { + // Bulk remove + c.seqMaps = make(map[uint64]*cachedMsg) + c.head, c.tail, c.tryEvict = nil, nil, 0 + return + } + cMsg := c.head + for cMsg != nil && cMsg.expiration <= now { + delete(c.seqMaps, cMsg.msg.Sequence) + cMsg = cMsg.next + } + if cMsg != c.head { + // There should be at least one left, otherwise, they + // would all have been bulk removed at top of this function. 
+ cMsg.prev = nil + c.head = cMsg + } +} + +// empty empties the cache +func (c *msgsCache) empty() { + atomic.StoreInt32(&c.tryEvict, 0) + c.head, c.tail = nil, nil + c.seqMaps = make(map[uint64]*cachedMsg) +} + +// Close closes the store. +func (ms *FileMsgStore) Close() error { + ms.Lock() + if ms.closed { + ms.Unlock() + return nil + } + + ms.closed = true + + // Signal the background tasks go-routine to exit + ms.bkgTasksDone <- true + + ms.Unlock() + + // Wait on go routines/timers to finish + ms.allDone.Wait() + + ms.Lock() + var err error + if ms.writeSlice != nil { + // Flush current file slice where writes happen + ms.lockFiles(ms.writeSlice) + err = ms.flush(ms.writeSlice) + ms.unlockFiles(ms.writeSlice) + } + // Remove/close all file slices + for _, slice := range ms.files { + ms.fm.remove(slice.file) + ms.fm.remove(slice.idxFile) + if slice.file.handle != nil { + err = util.CloseFile(err, slice.file.handle) + } + if slice.idxFile.handle != nil { + err = util.CloseFile(err, slice.idxFile.handle) + } + } + ms.Unlock() + + return err +} + +func (ms *FileMsgStore) flush(fslice *fileSlice) error { + if ms.bw != nil && ms.bw.buf != nil && ms.bw.buf.Buffered() > 0 { + if err := ms.bw.buf.Flush(); err != nil { + return err + } + } + // This used to be inside the above `if` statement, but now it has + // to be separate because the data file may have been closed + // (and therefore the buffer flushed) and we could still have + // buffered messages that need to be processed. + if len(ms.bufferedMsgs) > 0 { + if err := ms.processBufferedMsgs(fslice); err != nil { + return err + } + } + if ms.fstore.opts.DoSync { + if err := fslice.file.handle.Sync(); err != nil { + return err + } + if err := fslice.idxFile.handle.Sync(); err != nil { + return err + } + } + return nil +} + +// Flush flushes outstanding data into the store. 
+func (ms *FileMsgStore) Flush() error { + ms.Lock() + var err error + if ms.writeSlice != nil { + err = ms.lockFiles(ms.writeSlice) + if err == nil { + err = ms.flush(ms.writeSlice) + ms.unlockFiles(ms.writeSlice) + } + } + ms.Unlock() + return err +} + +// Empty implements the MsgStore interface +func (ms *FileMsgStore) Empty() error { + ms.Lock() + defer ms.Unlock() + + var err error + // Remove/close all file slices + for sliceID, slice := range ms.files { + ms.fm.remove(slice.file) + ms.fm.remove(slice.idxFile) + if slice.file.handle != nil { + err = util.CloseFile(err, slice.file.handle) + } + if lerr := os.Remove(slice.file.name); lerr != nil && err == nil { + err = lerr + } + if slice.idxFile.handle != nil { + err = util.CloseFile(err, slice.idxFile.handle) + } + if lerr := os.Remove(slice.idxFile.name); lerr != nil && err == nil { + err = lerr + } + delete(ms.files, sliceID) + } + // Reset generic counters + ms.empty() + // FileMsgStore specific + ms.writer = nil + ms.writeSlice = nil + ms.cache.empty() + ms.wOffset = 0 + ms.firstMsg, ms.lastMsg = nil, nil + ms.expiration = 0 + ms.firstFSlSeq, ms.lastFSlSeq = 0, 0 + // If we are running in buffered mode... + if ms.bw != nil { + ms.bw = newBufferWriter(msgBufMinShrinkSize, ms.fstore.opts.BufferSize) + ms.bufferedSeqs = make([]uint64, 0, 1) + ms.bufferedMsgs = make(map[uint64]*bufferedMsg) + } + return err +} + +//////////////////////////////////////////////////////////////////////////// +// FileSubStore methods +//////////////////////////////////////////////////////////////////////////// + +// newFileSubStore returns a new instace of a file SubStore. 
+func (fs *FileStore) newFileSubStore(channel string, limits *SubStoreLimits, doRecover bool) (*FileSubStore, error) { + ss := &FileSubStore{ + fstore: fs, + fm: fs.fm, + opts: &fs.opts, + crcTable: fs.crcTable, + } + ss.init(fs.log, limits) + // Convert the CompactInterval in time.Duration + ss.compactItvl = time.Duration(ss.opts.CompactInterval) * time.Second + + var err error + + fileName := filepath.Join(channel, subsFileName) + ss.file, err = fs.fm.createFile(fileName, defaultFileFlags, func() error { + ss.writer = nil + return ss.flush() + }) + if err != nil { + return nil, err + } + maxBufSize := ss.opts.BufferSize + ss.writer = ss.file.handle + // If we allow buffering, then create the buffered writer and + // set ss's writer to that buffer. + if maxBufSize > 0 { + ss.bw = newBufferWriter(subBufMinShrinkSize, maxBufSize) + ss.writer = ss.bw.createNewWriter(ss.file.handle) + } + if doRecover { + if err := ss.recoverSubscriptions(); err != nil { + fs.fm.unlockFile(ss.file) + ss.Close() + return nil, fmt.Errorf("unable to recover subscription store for [%s]: %v", channel, err) + } + } + // Do not attempt to shrink unless the option is greater than the + // minimum shrinkable size. + if maxBufSize > subBufMinShrinkSize { + // Use lock to avoid RACE report between setting shrinkTimer and + // execution of the callback itself. + ss.Lock() + ss.allDone.Add(1) + ss.shrinkTimer = time.AfterFunc(bufShrinkInterval, func() { + ss.shrinkBuffer(true) + }) + ss.Unlock() + } + if doRecover { + fs.fm.closeLockedFile(ss.file) + } else { + fs.fm.unlockFile(ss.file) + } + return ss, nil +} + +// getFile ensures that the store's file handle is valid, opening +// the file if needed. If file needs to be opened, the store's writer +// is set to either the bare file or the buffered writer (based on +// store's configuration). 
+func (ss *FileSubStore) lockFile() error { + wasOpened, err := ss.fm.lockFile(ss.file) + if err != nil { + return err + } + // If file was not opened, we need to reset ss.writer + if !wasOpened { + if ss.bw != nil { + ss.writer = ss.bw.createNewWriter(ss.file.handle) + } else { + ss.writer = ss.file.handle + } + } + return nil +} + +// shrinkBuffer is a timer callback that shrinks the buffer writer when possible. +// Since this function is called directly in tests, the boolean `fromTimer` is +// used to indicate if this function is invoked from the timer callback (in which +// case, the timer need to be Reset()) or not. Reseting a timer while timer fires +// can lead to unexpected behavior. +func (ss *FileSubStore) shrinkBuffer(fromTimer bool) { + ss.Lock() + defer ss.Unlock() + + if ss.closed { + ss.allDone.Done() + return + } + // Fire again + if fromTimer { + ss.shrinkTimer.Reset(bufShrinkInterval) + } + + // If file currently opened, lock it, otherwise we are done for now. + if !ss.fm.lockFileIfOpened(ss.file) { + return + } + // If error, the buffer (in bufio) memorizes the error + // so any other write/flush on that buffer will fail. We will get the + // error at the next "synchronous" operation where we can report back + // to the user. + ss.writer, _ = ss.bw.tryShrinkBuffer(ss.file.handle) + ss.fm.unlockFile(ss.file) +} + +// recoverSubscriptions recovers subscriptions state for this store. 
+func (ss *FileSubStore) recoverSubscriptions() error { + var err error + var recType recordType + + recSize := 0 + offset := int64(4) + + // Create a buffered reader to speed-up recovery + br := bufio.NewReaderSize(ss.file.handle, defaultBufSize) + + for { + ss.tmpSubBuf, recSize, recType, err = readRecord(br, ss.tmpSubBuf, true, ss.crcTable, ss.opts.DoCRC) + if err != nil { + switch err { + case io.EOF: + // We are done, reset err + err = nil + case errNeedRewind: + err = ss.fm.truncateFile(ss.file, offset) + default: + err = ss.fstore.handleUnexpectedEOF(err, ss.file, offset, true) + } + if err == nil { + break + } + return err + } + readBytes := int64(recSize + recordHeaderSize) + offset += readBytes + ss.fileSize += readBytes + // Based on record type... + switch recType { + case subRecNew: + newSub := &spb.SubState{} + if err := newSub.Unmarshal(ss.tmpSubBuf[:recSize]); err != nil { + return err + } + sub := &subscription{ + sub: newSub, + seqnos: make(map[uint64]struct{}), + } + ss.subs[newSub.ID] = sub + // Keep track of max subscription ID found. + if newSub.ID > ss.maxSubID { + ss.maxSubID = newSub.ID + } + ss.numRecs++ + case subRecUpdate: + modifiedSub := &spb.SubState{} + if err := modifiedSub.Unmarshal(ss.tmpSubBuf[:recSize]); err != nil { + return err + } + // Search if the create has been recovered. + subi, exists := ss.subs[modifiedSub.ID] + if exists { + sub := subi.(*subscription) + sub.sub = modifiedSub + // An update means that the previous version is free space. + ss.delRecs++ + } else { + sub := &subscription{ + sub: modifiedSub, + seqnos: make(map[uint64]struct{}), + } + ss.subs[modifiedSub.ID] = sub + } + // Keep track of max subscription ID found. 
+ if modifiedSub.ID > ss.maxSubID { + ss.maxSubID = modifiedSub.ID + } + ss.numRecs++ + case subRecDel: + delSub := spb.SubStateDelete{} + if err := delSub.Unmarshal(ss.tmpSubBuf[:recSize]); err != nil { + return err + } + if si, exists := ss.subs[delSub.ID]; exists { + s := si.(*subscription) + delete(ss.subs, delSub.ID) + // Delete and count all non-ack'ed messages free space. + ss.delRecs++ + ss.delRecs += len(s.seqnos) + } + // Keep track of max subscription ID found. + if delSub.ID > ss.maxSubID { + ss.maxSubID = delSub.ID + } + case subRecMsg: + updateSub := spb.SubStateUpdate{} + if err := updateSub.Unmarshal(ss.tmpSubBuf[:recSize]); err != nil { + return err + } + if subi, exists := ss.subs[updateSub.ID]; exists { + sub := subi.(*subscription) + seqno := updateSub.Seqno + // Same seqno/ack can appear several times for the same sub. + // See queue subscribers redelivery. + if seqno > sub.sub.LastSent { + sub.sub.LastSent = seqno + } + sub.seqnos[seqno] = struct{}{} + ss.numRecs++ + } + case subRecAck: + updateSub := spb.SubStateUpdate{} + if err := updateSub.Unmarshal(ss.tmpSubBuf[:recSize]); err != nil { + return err + } + if subi, exists := ss.subs[updateSub.ID]; exists { + sub := subi.(*subscription) + delete(sub.seqnos, updateSub.Seqno) + // A message is ack'ed + ss.delRecs++ + } + default: + return fmt.Errorf("unexpected record type: %v", recType) + } + } + return nil +} + +// CreateSub records a new subscription represented by SubState. On success, +// it returns an id that is used by the other methods. +func (ss *FileSubStore) CreateSub(sub *spb.SubState) error { + // Check if we can create the subscription (check limits and update + // subscription count) + ss.Lock() + defer ss.Unlock() + if err := ss.createSub(sub); err != nil { + return err + } + if err := ss.writeRecord(nil, subRecNew, sub); err != nil { + delete(ss.subs, sub.ID) + return err + } + // We need to get a copy of the passed sub, we can't hold a reference + // to it. 
+ csub := *sub + s := &subscription{sub: &csub, seqnos: make(map[uint64]struct{})} + ss.subs[sub.ID] = s + return nil +} + +// UpdateSub updates a given subscription represented by SubState. +func (ss *FileSubStore) UpdateSub(sub *spb.SubState) error { + ss.Lock() + defer ss.Unlock() + if err := ss.writeRecord(nil, subRecUpdate, sub); err != nil { + return err + } + // We need to get a copy of the passed sub, we can't hold a reference + // to it. + csub := *sub + si := ss.subs[sub.ID] + if si != nil { + s := si.(*subscription) + s.sub = &csub + } else { + s := &subscription{sub: &csub, seqnos: make(map[uint64]struct{})} + ss.subs[sub.ID] = s + } + return nil +} + +// DeleteSub invalidates this subscription. +func (ss *FileSubStore) DeleteSub(subid uint64) error { + ss.Lock() + ss.delSub.ID = subid + err := ss.writeRecord(nil, subRecDel, &ss.delSub) + // Even if there is an error, continue with cleanup. If later + // a compact is successful, the sub won't be present in the compacted file. + if si, exists := ss.subs[subid]; exists { + s := si.(*subscription) + delete(ss.subs, subid) + // writeRecord has already accounted for the count of the + // delete record. We add to this the number of pending messages + ss.delRecs += len(s.seqnos) + // Check if this triggers a need for compaction + if ss.shouldCompact() { + ss.fm.closeFileIfOpened(ss.file) + ss.compact(ss.file.name) + } + } + ss.Unlock() + return err +} + +// shouldCompact returns a boolean indicating if we should compact +// Lock is held by caller +func (ss *FileSubStore) shouldCompact() bool { + // Gobal switch + if !ss.opts.CompactEnabled { + return false + } + // Check that if minimum file size is set, the client file + // is at least at the minimum. 
+ if ss.opts.CompactMinFileSize > 0 && ss.fileSize < ss.opts.CompactMinFileSize { + return false + } + // Check fragmentation + frag := 0 + if ss.numRecs == 0 { + frag = 100 + } else { + frag = ss.delRecs * 100 / ss.numRecs + } + if frag < ss.opts.CompactFragmentation { + return false + } + // Check that we don't compact too often + if time.Since(ss.compactTS) < ss.compactItvl { + return false + } + return true +} + +// AddSeqPending adds the given message seqno to the given subscription. +func (ss *FileSubStore) AddSeqPending(subid, seqno uint64) error { + ss.Lock() + ss.updateSub.ID, ss.updateSub.Seqno = subid, seqno + if err := ss.writeRecord(nil, subRecMsg, &ss.updateSub); err != nil { + ss.Unlock() + return err + } + si := ss.subs[subid] + if si != nil { + s := si.(*subscription) + if seqno > s.sub.LastSent { + s.sub.LastSent = seqno + } + s.seqnos[seqno] = struct{}{} + } + ss.Unlock() + return nil +} + +// AckSeqPending records that the given message seqno has been acknowledged +// by the given subscription. +func (ss *FileSubStore) AckSeqPending(subid, seqno uint64) error { + ss.Lock() + ss.updateSub.ID, ss.updateSub.Seqno = subid, seqno + if err := ss.writeRecord(nil, subRecAck, &ss.updateSub); err != nil { + ss.Unlock() + return err + } + si := ss.subs[subid] + if si != nil { + s := si.(*subscription) + delete(s.seqnos, seqno) + // Test if we should compact + if ss.shouldCompact() { + ss.fm.closeFileIfOpened(ss.file) + ss.compact(ss.file.name) + } + } + ss.Unlock() + return nil +} + +// compact rewrites all subscriptions on a temporary file, reducing the size +// since we get rid of deleted subscriptions and message sequences that have +// been acknowledged. On success, the subscriptions file is replaced by this +// temporary file. 
+// Lock is held by caller +func (ss *FileSubStore) compact(orgFileName string) error { + tmpFile, err := getTempFile(ss.fm.rootDir, "subs") + if err != nil { + return err + } + tmpBW := bufio.NewWriterSize(tmpFile, defaultBufSize) + // Save values in case of failed compaction + savedNumRecs := ss.numRecs + savedDelRecs := ss.delRecs + savedFileSize := ss.fileSize + // Cleanup in case of error during compact + defer func() { + if tmpFile != nil { + tmpFile.Close() + os.Remove(tmpFile.Name()) + // Since we failed compaction, restore values + ss.numRecs = savedNumRecs + ss.delRecs = savedDelRecs + ss.fileSize = savedFileSize + } + }() + // Reset to 0 since writeRecord() is updating the values. + ss.numRecs = 0 + ss.delRecs = 0 + ss.fileSize = 0 + for _, subi := range ss.subs { + sub := subi.(*subscription) + err = ss.writeRecord(tmpBW, subRecNew, sub.sub) + if err != nil { + return err + } + ss.updateSub.ID = sub.sub.ID + for seqno := range sub.seqnos { + ss.updateSub.Seqno = seqno + err = ss.writeRecord(tmpBW, subRecMsg, &ss.updateSub) + if err != nil { + return err + } + } + } + // Flush and sync the temporary file + err = tmpBW.Flush() + if err != nil { + return err + } + err = tmpFile.Sync() + if err != nil { + return err + } + // Start by closing the temporary file. + if err := tmpFile.Close(); err != nil { + return err + } + // Rename the tmp file to original file name + if err := os.Rename(tmpFile.Name(), orgFileName); err != nil { + return err + } + // Prevent cleanup on success + tmpFile = nil + // Update the timestamp of this last successful compact + ss.compactTS = time.Now() + return nil +} + +// writes a record in the subscriptions file. +// store's lock is held on entry. 
+func (ss *FileSubStore) writeRecord(w io.Writer, recType recordType, rec record) error { + var err error + totalSize := 0 + recSize := rec.Size() + + var bwBuf *bufio.Writer + needsUnlock := false + + if w == nil { + if err := ss.lockFile(); err != nil { + return err + } + needsUnlock = true + if ss.bw != nil { + bwBuf = ss.bw.buf + // If we are using the buffer writer on this call, and the buffer is + // not already at the max size... + if bwBuf != nil && ss.bw.bufSize != ss.opts.BufferSize { + // Check if record fits + required := recSize + recordHeaderSize + if required > bwBuf.Available() { + ss.writer, err = ss.bw.expand(ss.file.handle, required) + if err != nil { + ss.fm.unlockFile(ss.file) + return err + } + bwBuf = ss.bw.buf + } + } + } + w = ss.writer + } + ss.tmpSubBuf, totalSize, err = writeRecord(w, ss.tmpSubBuf, recType, rec, recSize, ss.crcTable) + if err != nil { + if needsUnlock { + ss.fm.unlockFile(ss.file) + } + return err + } + if bwBuf != nil && ss.bw.shrinkReq { + ss.bw.checkShrinkRequest() + } + // Indicate that we wrote something to the buffer/file + ss.activity = true + switch recType { + case subRecNew: + ss.numRecs++ + case subRecMsg: + ss.numRecs++ + case subRecAck: + // An ack makes the message record free space + ss.delRecs++ + case subRecUpdate: + ss.numRecs++ + // An update makes the old record free space + ss.delRecs++ + case subRecDel: + ss.delRecs++ + default: + panic(fmt.Errorf("Record type %v unknown", recType)) + } + ss.fileSize += int64(totalSize) + if needsUnlock { + ss.fm.unlockFile(ss.file) + } + return nil +} + +func (ss *FileSubStore) flush() error { + // Skip this if nothing was written since the last flush + if !ss.activity { + return nil + } + // Reset this now + ss.activity = false + if ss.bw != nil && ss.bw.buf.Buffered() > 0 { + if err := ss.bw.buf.Flush(); err != nil { + return err + } + } + if ss.opts.DoSync { + return ss.file.handle.Sync() + } + return nil +} + +// Flush persists buffered operations to disk. 
+func (ss *FileSubStore) Flush() error { + ss.Lock() + err := ss.lockFile() + if err == nil { + err = ss.flush() + ss.fm.unlockFile(ss.file) + } + ss.Unlock() + return err +} + +// Close closes this store +func (ss *FileSubStore) Close() error { + ss.Lock() + if ss.closed { + ss.Unlock() + return nil + } + + ss.closed = true + + if ss.shrinkTimer != nil { + if ss.shrinkTimer.Stop() { + // If we can stop, timer callback won't fire, + // so we need to decrement the wait group. + ss.allDone.Done() + } + } + ss.Unlock() + + // Wait on timers/callbacks + ss.allDone.Wait() + + ss.Lock() + var err error + if ss.fm.remove(ss.file) { + if ss.file.handle != nil { + err = ss.flush() + err = util.CloseFile(err, ss.file.handle) + } + } + ss.Unlock() + + return err +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/stores/limits.go b/vendor/github.com/nats-io/nats-streaming-server/stores/limits.go new file mode 100644 index 00000000000..7baf7da4627 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/stores/limits.go @@ -0,0 +1,328 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package stores + +import ( + "fmt" + "time" + + "github.com/nats-io/nats-streaming-server/util" +) + +// Used for display of limits +const ( + limitCount = iota + limitBytes + limitDuration +) + +// Clone returns a copy of the store limits +func (sl *StoreLimits) Clone() *StoreLimits { + cloned := *sl + cloned.PerChannel = sl.ClonePerChannelMap() + return &cloned +} + +// ClonePerChannelMap returns a deep copy of the StoreLimits's PerChannel map +func (sl *StoreLimits) ClonePerChannelMap() map[string]*ChannelLimits { + if sl.PerChannel == nil { + return nil + } + clone := make(map[string]*ChannelLimits, len(sl.PerChannel)) + for k, v := range sl.PerChannel { + copyVal := *v + clone[k] = ©Val + } + return clone +} + +// AddPerChannel stores limits for the given channel `name` in the StoreLimits. +// Inheritance (that is, specifying 0 for a limit means that the global limit +// should be used) is not applied in this call. This is done in StoreLimits.Build +// along with some validation. +func (sl *StoreLimits) AddPerChannel(name string, cl *ChannelLimits) { + if sl.PerChannel == nil { + sl.PerChannel = make(map[string]*ChannelLimits) + } + sl.PerChannel[name] = cl +} + +type channelLimitInfo struct { + name string + limits *ChannelLimits + isLiteral bool + isProcessed bool +} + +// Build sets the global limits into per-channel limits that are set +// to zero. This call also validates the limits. An error is returned if: +// * any global limit is set to a negative value. +// * the number of per-channel is higher than StoreLimits.MaxChannels. +// * a per-channel name is invalid +func (sl *StoreLimits) Build() error { + // Check that there is no negative value + if err := sl.checkGlobalLimits(); err != nil { + return err + } + // If there is no per-channel, we are done. 
+ if len(sl.PerChannel) == 0 { + return nil + } + literals := 0 + sublist := util.NewSublist() + for cn, cl := range sl.PerChannel { + if !util.IsChannelNameValid(cn, true) { + return fmt.Errorf("invalid channel name %q", cn) + } + isLiteral := util.IsChannelNameLiteral(cn) + if isLiteral { + literals++ + if sl.MaxChannels > 0 && literals > sl.MaxChannels { + return fmt.Errorf("too many channels defined (%v). The max channels limit is set to %v", + literals, sl.MaxChannels) + } + } + cli := &channelLimitInfo{ + name: cn, + limits: cl, + isLiteral: isLiteral, + } + sublist.Insert(cn, cli) + } + // If we are here, it means that there was no error, + // so we now apply inheritance. + sl.applyInheritance(sublist) + return nil +} + +func (sl *StoreLimits) applyInheritance(sublist *util.Sublist) { + // Get the subjects from the sublist. This ensure that they are ordered + // from the widest to the narrowest of subjects. + channels := sublist.Subjects() + for _, cn := range channels { + r := sublist.Match(cn) + // There has to be at least 1 match (the current channel name we + // are trying to match). 
+ channel := r[0].(*channelLimitInfo) + if channel.isLiteral && channel.isProcessed { + continue + } + if !channel.isProcessed { + sl.inheritLimits(channel, &sl.ChannelLimits) + } + prev := channel + for i := 1; i < len(r); i++ { + channel = r[i].(*channelLimitInfo) + if !channel.isProcessed { + sl.inheritLimits(channel, prev.limits) + } + prev = channel + } + } +} + +func (sl *StoreLimits) inheritLimits(channel *channelLimitInfo, parentLimits *ChannelLimits) { + cl := channel.limits + if cl.MaxSubscriptions < 0 { + cl.MaxSubscriptions = 0 + } else if cl.MaxSubscriptions == 0 { + cl.MaxSubscriptions = parentLimits.MaxSubscriptions + } + if cl.MaxMsgs < 0 { + cl.MaxMsgs = 0 + } else if cl.MaxMsgs == 0 { + cl.MaxMsgs = parentLimits.MaxMsgs + } + if cl.MaxBytes < 0 { + cl.MaxBytes = 0 + } else if cl.MaxBytes == 0 { + cl.MaxBytes = parentLimits.MaxBytes + } + if cl.MaxAge < 0 { + cl.MaxAge = 0 + } else if cl.MaxAge == 0 { + cl.MaxAge = parentLimits.MaxAge + } + if cl.MaxInactivity < 0 { + cl.MaxInactivity = 0 + } else if cl.MaxInactivity == 0 { + cl.MaxInactivity = parentLimits.MaxInactivity + } + channel.isProcessed = true +} + +func (sl *StoreLimits) checkGlobalLimits() error { + if sl.MaxChannels < 0 { + return fmt.Errorf("max channels limit cannot be negative (%v)", sl.MaxChannels) + } + if sl.MaxSubscriptions < 0 { + return fmt.Errorf("max subscriptions limit cannot be negative (%v)", sl.MaxSubscriptions) + } + if sl.MaxMsgs < 0 { + return fmt.Errorf("max messages limit cannot be negative (%v)", sl.MaxMsgs) + } + if sl.MaxBytes < 0 { + return fmt.Errorf("max bytes limit cannot be negative (%v)", sl.MaxBytes) + } + if sl.MaxAge < 0 { + return fmt.Errorf("max age limit cannot be negative (%v)", sl.MaxAge) + } + if sl.MaxInactivity < 0 { + return fmt.Errorf("max inactivity limit cannot be negative (%v)", sl.MaxInactivity) + } + return nil +} + +// Print returns an array of strings suitable for printing the store limits. 
+func (sl *StoreLimits) Print() []string { + sublist := util.NewSublist() + for cn, cl := range sl.PerChannel { + sublist.Insert(cn, &channelLimitInfo{ + name: cn, + limits: cl, + isLiteral: util.IsChannelNameLiteral(cn), + }) + } + maxLevels := sublist.NumLevels() + txt := []string{} + title := "---------- Store Limits ----------" + txt = append(txt, title) + txt = append(txt, fmt.Sprintf("Channels: %s", + getLimitStr(true, int64(sl.MaxChannels), + int64(DefaultStoreLimits.MaxChannels), + limitCount))) + maxLen := len(title) + txt = append(txt, "--------- Channels Limits --------") + txt = append(txt, getGlobalLimitsPrintLines(&sl.ChannelLimits)...) + if len(sl.PerChannel) > 0 { + channels := sublist.Subjects() + channelLines := []string{} + for _, cn := range channels { + r := sublist.Match(cn) + var prev *channelLimitInfo + for i := 0; i < len(r); i++ { + channel := r[i].(*channelLimitInfo) + if channel.name == cn { + var parentLimits *ChannelLimits + if prev == nil { + parentLimits = &sl.ChannelLimits + } else { + parentLimits = prev.limits + } + channelLines = append(channelLines, + getChannelLimitsPrintLines(i, maxLevels, &maxLen, channel.name, channel.limits, parentLimits)...) + break + } + prev = channel + } + } + title := " List of Channels " + numberDashesLeft := (maxLen - len(title)) / 2 + numberDashesRight := maxLen - len(title) - numberDashesLeft + title = fmt.Sprintf("%s%s%s", + repeatChar("-", numberDashesLeft), + title, + repeatChar("-", numberDashesRight)) + txt = append(txt, title) + txt = append(txt, channelLines...) 
+ } + txt = append(txt, repeatChar("-", maxLen)) + return txt +} + +func getLimitStr(isGlobal bool, val, parentVal int64, limitType int) string { + valStr := "" + inherited := "" + if !isGlobal && (val == parentVal) { + return "" + } + if val == parentVal { + inherited = " *" + } + if val == 0 { + valStr = "unlimited" + } else { + switch limitType { + case limitBytes: + valStr = util.FriendlyBytes(val) + case limitDuration: + valStr = fmt.Sprintf("%v", time.Duration(val)) + default: + valStr = fmt.Sprintf("%v", val) + } + } + return fmt.Sprintf("%13s%s", valStr, inherited) +} + +func getGlobalLimitsPrintLines(limits *ChannelLimits) []string { + defaultLimits := &DefaultStoreLimits + defMaxSubs := int64(defaultLimits.MaxSubscriptions) + defMaxMsgs := int64(defaultLimits.MaxMsgs) + defMaxBytes := defaultLimits.MaxBytes + defMaxAge := defaultLimits.MaxAge + defMaxInactivity := defaultLimits.MaxInactivity + txt := []string{} + txt = append(txt, fmt.Sprintf(" Subscriptions: %s", getLimitStr(true, int64(limits.MaxSubscriptions), defMaxSubs, limitCount))) + txt = append(txt, fmt.Sprintf(" Messages : %s", getLimitStr(true, int64(limits.MaxMsgs), defMaxMsgs, limitCount))) + txt = append(txt, fmt.Sprintf(" Bytes : %s", getLimitStr(true, limits.MaxBytes, defMaxBytes, limitBytes))) + txt = append(txt, fmt.Sprintf(" Age : %s", getLimitStr(true, int64(limits.MaxAge), int64(defMaxAge), limitDuration))) + txt = append(txt, fmt.Sprintf(" Inactivity : %s", getLimitStr(true, int64(limits.MaxInactivity), int64(defMaxInactivity), limitDuration))) + return txt +} + +func getChannelLimitsPrintLines(level, maxLevels int, maxLen *int, channelName string, limits, parentLimits *ChannelLimits) []string { + plMaxSubs := int64(parentLimits.MaxSubscriptions) + plMaxMsgs := int64(parentLimits.MaxMsgs) + plMaxBytes := parentLimits.MaxBytes + plMaxAge := parentLimits.MaxAge + plMaxInactivity := parentLimits.MaxInactivity + maxSubsOverride := getLimitStr(false, int64(limits.MaxSubscriptions), 
plMaxSubs, limitCount) + maxMsgsOverride := getLimitStr(false, int64(limits.MaxMsgs), plMaxMsgs, limitCount) + maxBytesOverride := getLimitStr(false, limits.MaxBytes, plMaxBytes, limitBytes) + maxAgeOverride := getLimitStr(false, int64(limits.MaxAge), int64(plMaxAge), limitDuration) + MaxInactivityOverride := getLimitStr(false, int64(limits.MaxInactivity), int64(plMaxInactivity), limitDuration) + paddingLeft := repeatChar(" ", level) + paddingRight := repeatChar(" ", maxLevels-level) + txt := []string{} + txt = append(txt, fmt.Sprintf("%s%s", paddingLeft, channelName)) + if maxSubsOverride != "" { + txt = append(txt, fmt.Sprintf("%s |-> Subscriptions %s%s", paddingLeft, paddingRight, maxSubsOverride)) + } + if maxMsgsOverride != "" { + txt = append(txt, fmt.Sprintf("%s |-> Messages %s%s", paddingLeft, paddingRight, maxMsgsOverride)) + } + if maxBytesOverride != "" { + txt = append(txt, fmt.Sprintf("%s |-> Bytes %s%s", paddingLeft, paddingRight, maxBytesOverride)) + } + if maxAgeOverride != "" { + txt = append(txt, fmt.Sprintf("%s |-> Age %s%s", paddingLeft, paddingRight, maxAgeOverride)) + } + if MaxInactivityOverride != "" { + txt = append(txt, fmt.Sprintf("%s |-> Inactivity %s%s", paddingLeft, paddingRight, MaxInactivityOverride)) + } + for _, l := range txt { + if len(l) > *maxLen { + *maxLen = len(l) + } + } + return txt +} + +func repeatChar(char string, len int) string { + res := "" + for i := 0; i < len; i++ { + res += char + } + return res +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/stores/memstore.go b/vendor/github.com/nats-io/nats-streaming-server/stores/memstore.go new file mode 100644 index 00000000000..c6ba6171909 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/stores/memstore.go @@ -0,0 +1,281 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stores + +import ( + "sort" + "sync" + "time" + + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nats-streaming-server/logger" + "github.com/nats-io/nats-streaming-server/util" +) + +// MemoryStore is a factory for message and subscription stores. +type MemoryStore struct { + genericStore +} + +// MemorySubStore is a subscription store in memory +type MemorySubStore struct { + genericSubStore +} + +// MemoryMsgStore is a per channel message store in memory +type MemoryMsgStore struct { + genericMsgStore + msgs map[uint64]*pb.MsgProto + ageTimer *time.Timer + wg sync.WaitGroup +} + +//////////////////////////////////////////////////////////////////////////// +// MemoryStore methods +//////////////////////////////////////////////////////////////////////////// + +// NewMemoryStore returns a factory for stores held in memory. +// If no limits are provided, the store will be created with +// DefaultStoreLimits. 
+func NewMemoryStore(log logger.Logger, limits *StoreLimits) (*MemoryStore, error) { + ms := &MemoryStore{} + if err := ms.init(TypeMemory, log, limits); err != nil { + return nil, err + } + return ms, nil +} + +// CreateChannel implements the Store interface +func (ms *MemoryStore) CreateChannel(channel string) (*Channel, error) { + ms.Lock() + defer ms.Unlock() + + // Verify that it does not already exist or that we did not hit the limits + if err := ms.canAddChannel(channel); err != nil { + return nil, err + } + + channelLimits := ms.genericStore.getChannelLimits(channel) + + msgStore := &MemoryMsgStore{msgs: make(map[uint64]*pb.MsgProto, 64)} + msgStore.init(channel, ms.log, &channelLimits.MsgStoreLimits) + + subStore := &MemorySubStore{} + subStore.init(ms.log, &channelLimits.SubStoreLimits) + + c := &Channel{ + Subs: subStore, + Msgs: msgStore, + } + ms.channels[channel] = c + + return c, nil +} + +//////////////////////////////////////////////////////////////////////////// +// MemoryMsgStore methods +//////////////////////////////////////////////////////////////////////////// + +// Store a given message. +func (ms *MemoryMsgStore) Store(m *pb.MsgProto) (uint64, error) { + ms.Lock() + defer ms.Unlock() + + if m.Sequence <= ms.last { + // We've already seen this message. 
+ return m.Sequence, nil + } + + if ms.first == 0 { + ms.first = m.Sequence + } + ms.last = m.Sequence + ms.msgs[ms.last] = m + ms.totalCount++ + ms.totalBytes += uint64(m.Size()) + // If there is an age limit and no timer yet created, do so now + if ms.limits.MaxAge > time.Duration(0) && ms.ageTimer == nil { + ms.wg.Add(1) + ms.ageTimer = time.AfterFunc(ms.limits.MaxAge, ms.expireMsgs) + } + + // Check if we need to remove any (but leave at least the last added) + maxMsgs := ms.limits.MaxMsgs + maxBytes := ms.limits.MaxBytes + if maxMsgs > 0 || maxBytes > 0 { + for ms.totalCount > 1 && + ((maxMsgs > 0 && ms.totalCount > maxMsgs) || + (maxBytes > 0 && (ms.totalBytes > uint64(maxBytes)))) { + ms.removeFirstMsg() + if !ms.hitLimit { + ms.hitLimit = true + ms.log.Noticef(droppingMsgsFmt, ms.subject, ms.totalCount, ms.limits.MaxMsgs, + util.FriendlyBytes(int64(ms.totalBytes)), util.FriendlyBytes(ms.limits.MaxBytes)) + } + } + } + + return ms.last, nil +} + +// Lookup returns the stored message with given sequence number. +func (ms *MemoryMsgStore) Lookup(seq uint64) (*pb.MsgProto, error) { + ms.RLock() + m := ms.msgs[seq] + ms.RUnlock() + return m, nil +} + +// FirstMsg returns the first message stored. +func (ms *MemoryMsgStore) FirstMsg() (*pb.MsgProto, error) { + ms.RLock() + m := ms.msgs[ms.first] + ms.RUnlock() + return m, nil +} + +// LastMsg returns the last message stored. +func (ms *MemoryMsgStore) LastMsg() (*pb.MsgProto, error) { + ms.RLock() + m := ms.msgs[ms.last] + ms.RUnlock() + return m, nil +} + +// GetSequenceFromTimestamp returns the sequence of the first message whose +// timestamp is greater or equal to given timestamp. 
+func (ms *MemoryMsgStore) GetSequenceFromTimestamp(timestamp int64) (uint64, error) { + ms.RLock() + defer ms.RUnlock() + + // No message ever stored + if ms.first == 0 { + return 0, nil + } + // All messages have expired + if ms.first > ms.last { + return ms.last + 1, nil + } + if ms.msgs[ms.first].Timestamp >= timestamp { + return ms.first, nil + } + if timestamp >= ms.msgs[ms.last].Timestamp { + return ms.last + 1, nil + } + + index := sort.Search(len(ms.msgs), func(i int) bool { + return ms.msgs[uint64(i)+ms.first].Timestamp >= timestamp + }) + + return uint64(index) + ms.first, nil +} + +// expireMsgs ensures that messages don't stay in the log longer than the +// limit's MaxAge. +func (ms *MemoryMsgStore) expireMsgs() { + ms.Lock() + defer ms.Unlock() + if ms.closed { + ms.wg.Done() + return + } + + now := time.Now().UnixNano() + maxAge := int64(ms.limits.MaxAge) + for { + m, ok := ms.msgs[ms.first] + if !ok { + if ms.first < ms.last { + ms.first++ + continue + } + ms.ageTimer = nil + ms.wg.Done() + return + } + elapsed := now - m.Timestamp + if elapsed >= maxAge { + ms.removeFirstMsg() + } else { + if elapsed < 0 { + ms.ageTimer.Reset(time.Duration(m.Timestamp - now + maxAge)) + } else { + ms.ageTimer.Reset(time.Duration(maxAge - elapsed)) + } + return + } + } +} + +// removeFirstMsg removes the first message and updates totals. 
+func (ms *MemoryMsgStore) removeFirstMsg() { + firstMsg := ms.msgs[ms.first] + ms.totalBytes -= uint64(firstMsg.Size()) + ms.totalCount-- + delete(ms.msgs, ms.first) + ms.first++ +} + +// Empty implements the MsgStore interface +func (ms *MemoryMsgStore) Empty() error { + ms.Lock() + if ms.ageTimer != nil { + if ms.ageTimer.Stop() { + ms.wg.Done() + } + ms.ageTimer = nil + } + ms.empty() + ms.msgs = make(map[uint64]*pb.MsgProto) + ms.Unlock() + return nil +} + +// Close implements the MsgStore interface +func (ms *MemoryMsgStore) Close() error { + ms.Lock() + if ms.closed { + ms.Unlock() + return nil + } + ms.closed = true + if ms.ageTimer != nil { + if ms.ageTimer.Stop() { + ms.wg.Done() + } + } + ms.Unlock() + + ms.wg.Wait() + return nil +} + +//////////////////////////////////////////////////////////////////////////// +// MemorySubStore methods +//////////////////////////////////////////////////////////////////////////// + +// AddSeqPending adds the given message seqno to the given subscription. +func (*MemorySubStore) AddSeqPending(subid, seqno uint64) error { + // Overrides in case genericSubStore does something. For the memory + // based store, we want to minimize the cost of this to a minimum. + return nil +} + +// AckSeqPending records that the given message seqno has been acknowledged +// by the given subscription. +func (*MemorySubStore) AckSeqPending(subid, seqno uint64) error { + // Overrides in case genericSubStore does something. For the memory + // based store, we want to minimize the cost of this to a minimum. 
+ return nil +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/stores/raftstore.go b/vendor/github.com/nats-io/nats-streaming-server/stores/raftstore.go new file mode 100644 index 00000000000..0cd444f59ec --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/stores/raftstore.go @@ -0,0 +1,113 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stores + +import ( + "sync" + + "github.com/nats-io/nats-streaming-server/spb" +) + +// RaftStore is a hybrid store for server running in clustering mode. +// This store persists/recovers ServerInfo and messages, but is a no-op +// for clients and subscriptions since we rely on raft log for that. +// It still creates/deletes subscriptions so that on recovery we +// can ensure that we don't reuse any subscription ID. 
+type RaftStore struct { + sync.Mutex + Store +} + +// RaftSubStore implements the SubStore interface +type RaftSubStore struct { + SubStore +} + +// NewRaftStore returns an instance of a RaftStore +func NewRaftStore(s Store) *RaftStore { + return &RaftStore{Store: s} +} + +//////////////////////////////////////////////////////////////////////////// +// RaftStore methods +//////////////////////////////////////////////////////////////////////////// + +// CreateChannel implements the Store interface +func (s *RaftStore) CreateChannel(channel string) (*Channel, error) { + s.Lock() + defer s.Unlock() + c, err := s.Store.CreateChannel(channel) + if err != nil { + return nil, err + } + c.Subs = &RaftSubStore{SubStore: c.Subs} + return c, nil +} + +// Name implements the Store interface +func (s *RaftStore) Name() string { + return TypeRaft + "_" + s.Store.Name() +} + +// Recover implements the Store interface +func (s *RaftStore) Recover() (*RecoveredState, error) { + s.Lock() + defer s.Unlock() + state, err := s.Store.Recover() + if err != nil { + return nil, err + } + if state != nil { + for _, rc := range state.Channels { + rc.Subscriptions = nil + } + state.Clients = nil + } + return state, nil +} + +// AddClient implements the Store interface +func (s *RaftStore) AddClient(info *spb.ClientInfo) (*Client, error) { + // No need for storage + return &Client{*info}, nil +} + +// DeleteClient implements the Store interface +func (s *RaftStore) DeleteClient(clientID string) error { + // Make this a no-op + return nil +} + +//////////////////////////////////////////////////////////////////////////// +// RaftSubStore methods +//////////////////////////////////////////////////////////////////////////// + +// UpdateSub implements the SubStore interface +func (ss *RaftSubStore) UpdateSub(*spb.SubState) error { + // Make this a no-op + return nil +} + +// AddSeqPending adds the given message 'seqno' to the subscription 'subid'. 
+func (ss *RaftSubStore) AddSeqPending(subid, seqno uint64) error { + // Make this a no-op + return nil +} + +// AckSeqPending records that the given message 'seqno' has been acknowledged +// by the subscription 'subid'. +func (ss *RaftSubStore) AckSeqPending(subid, seqno uint64) error { + // Make this a no-op + return nil +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/stores/sqlstore.go b/vendor/github.com/nats-io/nats-streaming-server/stores/sqlstore.go new file mode 100644 index 00000000000..e521dcbd0da --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/stores/sqlstore.go @@ -0,0 +1,2132 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package stores + +import ( + "database/sql" + "encoding/json" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nats-streaming-server/logger" + "github.com/nats-io/nats-streaming-server/spb" + "github.com/nats-io/nats-streaming-server/util" + "github.com/nats-io/nuid" +) + +const ( + driverMySQL = "mysql" + driverPostgres = "postgres" +) + +const ( + sqlDBLockSelect = iota + sqlDBLockInsert + sqlDBLockUpdate + sqlHasServerInfoRow + sqlUpdateServerInfo + sqlAddServerInfo + sqlAddClient + sqlDeleteClient + sqlAddChannel + sqlStoreMsg + sqlLookupMsg + sqlGetSequenceFromTimestamp + sqlUpdateChannelMaxSeq + sqlGetExpiredMessages + sqlGetFirstMsgTimestamp + sqlDeletedMsgsWithSeqLowerThan + sqlGetSizeOfMessage + sqlDeleteMessage + sqlCheckMaxSubs + sqlCreateSub + sqlUpdateSub + sqlMarkSubscriptionAsDeleted + sqlDeleteSubscription + sqlDeleteSubMarkedAsDeleted + sqlDeleteSubPendingMessages + sqlSubUpdateLastSent + sqlSubAddPending + sqlSubAddPendingRow + sqlSubDeletePending + sqlSubDeletePendingRow + sqlRecoverServerInfo + sqlRecoverClients + sqlRecoverMaxChannelID + sqlRecoverMaxSubID + sqlRecoverChannelsList + sqlRecoverChannelMsgs + sqlRecoverChannelSubs + sqlRecoverDoPurgeSubsPending + sqlRecoverSubPending + sqlRecoverGetChannelLimits + sqlRecoverDoExpireMsgs + sqlRecoverGetMessagesCount + sqlRecoverGetSeqFloorForMaxMsgs + sqlRecoverGetChannelTotalSize + sqlRecoverGetSeqFloorForMaxBytes + sqlRecoverUpdateChannelLimits + sqlDeleteChannelFast + sqlDeleteChannelGetSubIds + sqlDeleteChannelDelSubsPending + sqlDeleteChannelDelSubscriptions + sqlDeleteChannelGetSomeMessagesSeq + sqlDeleteChannelDelSomeMessages + sqlDeleteChannelDelChannel + sqlGetLastSeq +) + +var sqlStmts = []string{ + "SELECT id, tick from StoreLock FOR UPDATE", // sqlDBLockSelect + "INSERT INTO StoreLock (id, tick) VALUES (?, ?)", // sqlDBLockInsert + "UPDATE StoreLock SET id=?, tick=?", // sqlDBLockUpdate + "SELECT 
COUNT(uniquerow) FROM ServerInfo", // sqlHasServerInfoRow + "UPDATE ServerInfo SET id=?, proto=?, version=? WHERE uniquerow=1", // sqlUpdateServerInfo + "INSERT INTO ServerInfo (id, proto, version) VALUES (?, ?, ?)", // sqlAddServerInfo + "INSERT INTO Clients (id, hbinbox, proto) VALUES (?, ?, ?)", // sqlAddClient + "DELETE FROM Clients WHERE id=?", // sqlDeleteClient + "INSERT INTO Channels (id, name, maxmsgs, maxbytes, maxage) VALUES (?, ?, ?, ?, ?)", // sqlAddChannel + "INSERT INTO Messages VALUES (?, ?, ?, ?, ?)", // sqlStoreMsg + "SELECT timestamp, data FROM Messages WHERE id=? AND seq=?", // sqlLookupMsg + "SELECT seq FROM Messages WHERE id=? AND timestamp>=? LIMIT 1", // sqlGetSequenceFromTimestamp + "UPDATE Channels SET maxseq=? WHERE id=?", // sqlUpdateChannelMaxSeq + "SELECT COUNT(seq), COALESCE(MAX(seq), 0), COALESCE(SUM(size), 0) FROM Messages WHERE id=? AND timestamp<=?", // sqlGetExpiredMessages + "SELECT timestamp FROM Messages WHERE id=? AND seq>=? LIMIT 1", // sqlGetFirstMsgTimestamp + "DELETE FROM Messages WHERE id=? AND seq<=?", // sqlDeletedMsgsWithSeqLowerThan + "SELECT size FROM Messages WHERE id=? AND seq=?", // sqlGetSizeOfMessage + "DELETE FROM Messages WHERE id=? AND seq=?", // sqlDeleteMessage + "SELECT COUNT(subid) FROM Subscriptions WHERE id=? AND deleted=FALSE", // sqlCheckMaxSubs + "INSERT INTO Subscriptions (id, subid, proto) VALUES (?, ?, ?)", // sqlCreateSub + "UPDATE Subscriptions SET proto=? WHERE id=? AND subid=?", // sqlUpdateSub + "UPDATE Subscriptions SET deleted=TRUE WHERE id=? AND subid=?", // sqlMarkSubscriptionAsDeleted + "DELETE FROM Subscriptions WHERE id=? AND subid=?", // sqlDeleteSubscription + "DELETE FROM Subscriptions WHERE id=? AND deleted=TRUE", // sqlDeleteSubMarkedAsDeleted + "DELETE FROM SubsPending WHERE subid=?", // sqlDeleteSubPendingMessages + "UPDATE Subscriptions SET lastsent=? WHERE id=? 
AND subid=?", // sqlSubUpdateLastSent + "INSERT INTO SubsPending (subid, `row`, seq) VALUES (?, ?, ?)", // sqlSubAddPending + "INSERT INTO SubsPending (subid, `row`, lastsent, pending, acks) VALUES (?, ?, ?, ?, ?)", // sqlSubAddPendingRow + "DELETE FROM SubsPending WHERE subid=? AND seq=?", // sqlSubDeletePending + "DELETE FROM SubsPending WHERE subid=? AND `row`=?", // sqlSubDeletePendingRow + "SELECT id, proto, version FROM ServerInfo WHERE uniquerow=1", // sqlRecoverServerInfo + "SELECT id, hbinbox, proto FROM Clients", // sqlRecoverClients + "SELECT COALESCE(MAX(id), 0) FROM Channels", // sqlRecoverMaxChannelID + "SELECT COALESCE(MAX(subid), 0) FROM Subscriptions", // sqlRecoverMaxSubID + "SELECT id, name, maxseq FROM Channels WHERE deleted=FALSE", // sqlRecoverChannelsList + "SELECT COUNT(seq), COALESCE(MIN(seq), 0), COALESCE(MAX(seq), 0), COALESCE(SUM(size), 0), COALESCE(MAX(timestamp), 0) FROM Messages WHERE id=?", // sqlRecoverChannelMsgs + "SELECT lastsent, proto FROM Subscriptions WHERE id=? AND deleted=FALSE", // sqlRecoverChannelSubs + "DELETE FROM SubsPending WHERE subid=? AND (seq > 0 AND seq sub.LastSent { + sub.LastSent = lastSent + } + if s.opts.NoCaching { + // We can remove entries for sequence that are below the smallest + // sequence that was found in Messages. 
+ if _, err := s.preparedStmts[sqlRecoverDoPurgeSubsPending].Exec(sub.ID, msgStore.first); err != nil { + return nil, sqlStmtError(sqlRecoverDoPurgeSubsPending, err) + } + } else { + ap = subStore.getOrCreateAcksPending(sub.ID, 0) + } + rows, err := s.preparedStmts[sqlRecoverSubPending].Query(sub.ID) + if err != nil { + return nil, sqlStmtError(sqlRecoverSubPending, err) + } + defer rows.Close() + pendingAcks := make(PendingAcks) + var gcedRows map[uint64]struct{} + if !s.opts.NoCaching { + gcedRows = make(map[uint64]struct{}) + } + for rows.Next() { + if err := subStore.recoverPendingRow(rows, sub, ap, pendingAcks, gcedRows); err != nil { + return nil, err + } + } + rows.Close() + + if s.opts.NoCaching { + // Update the in-memory map tracking last sent + subStore.subLastSent[sub.ID] = sub.LastSent + } else { + // Go over garbage collected rows and delete them + for rowID := range gcedRows { + if err := subStore.deleteSubPendingRow(sub.ID, rowID); err != nil { + return nil, err + } + } + } + + // Add to the recovered subscriptions + subscriptions = append(subscriptions, &RecoveredSubscription{Sub: sub, Pending: pendingAcks}) + } + } + subRows.Close() + + if !s.opts.NoCaching { + // Clear but also allow scheduling now that the recovery is complete. + subStore.cache.needsFlush = false + } + + rc := &RecoveredChannel{ + Channel: &Channel{ + Msgs: msgStore, + Subs: subStore, + }, + Subscriptions: subscriptions, + } + if channels == nil { + channels = make(map[string]*RecoveredChannel) + } + channels[name] = rc + s.channels[name] = rc.Channel + } + channelRows.Close() + + // Set channels into recovered state + rs.Channels = channels + + return rs, nil +} + +func (s *SQLStore) applyLimitsOnRecovery(ms *SQLMsgStore) error { + // These are the current limits set on restart. + limits := &ms.limits + maxAge := int64(limits.MaxAge) + // We need to check the ones that were stored in the DB. 
+ var ( + storedMsgsLimit int + storedBytesLimit int64 + storedAgeLimit int64 + ) + r := s.preparedStmts[sqlRecoverGetChannelLimits].QueryRow(ms.channelID) + if err := r.Scan(&storedMsgsLimit, &storedBytesLimit, &storedAgeLimit); err != nil { + return sqlStmtError(sqlRecoverGetChannelLimits, err) + } + // If any of the limits is different than what was stored, we will + // need to update the channel at the end of this function. + needUpdate := storedMsgsLimit != limits.MaxMsgs || storedBytesLimit != limits.MaxBytes || storedAgeLimit != maxAge + + // Let's reduce the number of messages if there is an age limit and messages + // should have expired. + if maxAge > 0 { + expiredTimestamp := time.Now().UnixNano() - int64(limits.MaxAge) + if _, err := s.preparedStmts[sqlRecoverDoExpireMsgs].Exec(ms.channelID, expiredTimestamp); err != nil { + return sqlStmtError(sqlRecoverDoExpireMsgs, err) + } + } + // For MaxMsgs and MaxBytes we are interested only if the new limit is + // lower than the old one (since messages are removed during runtime, + // if the limit has not been lowered, we should be good). 
+ if limits.MaxMsgs > 0 && limits.MaxMsgs < storedMsgsLimit { + count := 0 + r := s.preparedStmts[sqlRecoverGetMessagesCount].QueryRow(ms.channelID) + if err := r.Scan(&count); err != nil { + return sqlStmtError(sqlRecoverGetMessagesCount, err) + } + // We leave at least 1 message + if count > 1 && count > limits.MaxMsgs { + seq := uint64(0) + r = s.preparedStmts[sqlRecoverGetSeqFloorForMaxMsgs].QueryRow(ms.channelID, limits.MaxMsgs) + if err := r.Scan(&seq); err != nil { + return sqlStmtError(sqlRecoverGetSeqFloorForMaxMsgs, err) + } + if _, err := s.preparedStmts[sqlDeletedMsgsWithSeqLowerThan].Exec(ms.channelID, seq-1); err != nil { + return sqlStmtError(sqlDeletedMsgsWithSeqLowerThan, err) + } + } + } + if limits.MaxBytes > 0 && limits.MaxBytes < storedBytesLimit { + currentBytes := uint64(0) + r := s.preparedStmts[sqlRecoverGetChannelTotalSize].QueryRow(ms.channelID) + if err := r.Scan(¤tBytes); err != nil { + return sqlStmtError(sqlRecoverGetChannelTotalSize, err) + } + if currentBytes > uint64(limits.MaxBytes) { + seq := 0 + // This query finds the first seq (inclusive) for which the running total + // size is <= max bytes. + r := s.preparedStmts[sqlRecoverGetSeqFloorForMaxBytes].QueryRow(ms.channelID, uint64(limits.MaxBytes)) + if err := r.Scan(&seq); err != nil { + return sqlStmtError(sqlRecoverGetSeqFloorForMaxBytes, err) + } + // If 0, it could mean that the very last message is bigger than maxBytes, + // but then we should try to delete anything before the last (keep at least + // one). 
+ if seq == 0 { + r = s.preparedStmts[sqlGetLastSeq].QueryRow(ms.channelID) + if err := r.Scan(&seq); err != nil { + return sqlStmtError(sqlGetLastSeq, err) + } + } + // Delete at seq-1 + if seq > 0 { + seq-- + } + if seq > 0 { + if _, err := s.preparedStmts[sqlDeletedMsgsWithSeqLowerThan].Exec(ms.channelID, seq); err != nil { + return sqlStmtError(sqlDeletedMsgsWithSeqLowerThan, err) + } + } + } + } + // If limits were changed compared to last run, we need to update the + // Channels table. + if needUpdate { + if _, err := s.preparedStmts[sqlRecoverUpdateChannelLimits].Exec( + limits.MaxMsgs, limits.MaxBytes, maxAge, ms.channelID); err != nil { + return sqlStmtError(sqlRecoverUpdateChannelLimits, err) + } + } + return nil +} + +// CreateChannel implements the Store interface +func (s *SQLStore) CreateChannel(channel string) (*Channel, error) { + s.Lock() + defer s.Unlock() + + // Verify that it does not already exist or that we did not hit the limits + if err := s.canAddChannel(channel); err != nil { + return nil, err + } + + channelLimits := s.genericStore.getChannelLimits(channel) + + cid := s.maxChannelID + 1 + if _, err := s.preparedStmts[sqlAddChannel].Exec(cid, channel, + channelLimits.MaxMsgs, channelLimits.MaxBytes, int64(channelLimits.MaxAge)); err != nil { + return nil, sqlStmtError(sqlAddChannel, err) + } + s.maxChannelID = cid + + msgStore := s.newSQLMsgStore(channel, cid, &channelLimits.MsgStoreLimits) + subStore := s.newSQLSubStore(cid, &channelLimits.SubStoreLimits) + + c := &Channel{ + Subs: subStore, + Msgs: msgStore, + } + s.channels[channel] = c + + return c, nil +} + +// DeleteChannel implements the Store interface +func (s *SQLStore) DeleteChannel(channel string) error { + s.Lock() + defer s.Unlock() + c := s.channels[channel] + if c == nil { + return ErrNotFound + } + // Get the channel ID from Msgs store + cid := c.Msgs.(*SQLMsgStore).channelID + // Fast delete just marks the channel row as deleted + if _, err := 
s.preparedStmts[sqlDeleteChannelFast].Exec(cid); err != nil { + return err + } + + // If that succeeds, proceed with deletion of channel + delete(s.channels, channel) + + // Close the messages and subs stores + c.Msgs.Close() + c.Subs.Close() + + // Now trigger in a go routine the longer deletion of entries + // in all other tables. + s.wg.Add(1) + go func() { + defer s.wg.Done() + + if err := s.deepChannelDelete(cid); err != nil { + s.log.Errorf("Unable to completely delete channel %q: %v", channel, err) + } + }() + + return nil +} + +// This function is called after a channel has been marked +// as deleted. It will do a "deep" delete of the channel, +// which means removing all rows from any table that has +// a reference to the deleted channel. It is executed in +// a separate go-routine (as to not block DeleteChannel() +// call). It will run to completion possibly delaying +// the closing of the store. +func (s *SQLStore) deepChannelDelete(channelID int64) error { + // On Store.Close(), the prepared statements and DB + // won't be closed until after this call returns, + // so we don't need explicit store locking. + + // We start by removing from SubsPending. + limit := 1000 + for { + // This will get us a set of subscription ids. We need + // to repeat since we have a limit in the query + rows, err := s.preparedStmts[sqlDeleteChannelGetSubIds].Query(channelID, limit) + + // If no more row, we are done, continue with other tables. + if err == sql.ErrNoRows { + break + } + if err != nil { + return err + } + defer rows.Close() + + count := 0 + for rows.Next() { + var subid uint64 + if err := rows.Scan(&subid); err != nil { + return err + } + _, err := s.preparedStmts[sqlDeleteChannelDelSubsPending].Exec(subid) + if err != nil { + return err + } + count++ + } + rows.Close() + if count < limit { + break + } + } + // Same for messages, we will get a certain number of messages + // to delete and repeat the operation. 
+ for { + var maxSeq uint64 + + row := s.preparedStmts[sqlDeleteChannelGetSomeMessagesSeq].QueryRow(channelID, limit) + if err := row.Scan(&maxSeq); err != nil { + return err + } + if maxSeq == 0 { + break + } + _, err := s.preparedStmts[sqlDeleteChannelDelSomeMessages].Exec(channelID, maxSeq) + if err != nil { + return err + } + } + // Now with the subscriptions and channel + _, err := s.preparedStmts[sqlDeleteChannelDelSubscriptions].Exec(channelID) + if err == nil { + _, err = s.preparedStmts[sqlDeleteChannelDelChannel].Exec(channelID) + } + return err +} + +// AddClient implements the Store interface +func (s *SQLStore) AddClient(info *spb.ClientInfo) (*Client, error) { + s.Lock() + defer s.Unlock() + var ( + protoBytes []byte + err error + ) + protoBytes, err = info.Marshal() + if err != nil { + return nil, err + } + client := &Client{*info} + for i := 0; i < 2; i++ { + _, err = s.preparedStmts[sqlAddClient].Exec(client.ID, client.HbInbox, protoBytes) + if err == nil { + break + } + // We stop if this is the second AddClient failed attempt. + if i > 0 { + err = sqlStmtError(sqlAddClient, err) + break + } + // This is the first AddClient failed attempt. It could be because + // client was already in db, so delete now and try again. + _, err = s.preparedStmts[sqlDeleteClient].Exec(client.ID) + if err != nil { + err = sqlStmtError(sqlDeleteClient, err) + break + } + } + if err != nil { + return nil, err + } + return client, nil +} + +// DeleteClient implements the Store interface +func (s *SQLStore) DeleteClient(clientID string) error { + s.Lock() + _, err := s.preparedStmts[sqlDeleteClient].Exec(clientID) + if err != nil { + err = sqlStmtError(sqlDeleteClient, err) + } + s.Unlock() + return err +} + +// timeTick updates the store's time in nanosecond at regular +// interval. The time is used in Lookup() to compensate for possible +// delay in expiring messages. The Lookup() will check the message's +// expiration time against the time captured here. 
If it is expired +// even though it is still in the database, Lookup() will return nil. +func (s *SQLStore) timeTick() { + defer s.wg.Done() + timer := time.NewTicker(sqlTimeTickInterval) + for { + select { + case <-s.doneCh: + timer.Stop() + return + case <-timer.C: + atomic.StoreInt64(&s.nowInNano, time.Now().UnixNano()) + } + } +} + +// Close implements the Store interface +func (s *SQLStore) Close() error { + s.Lock() + if s.closed { + s.Unlock() + return nil + } + s.closed = true + // This will cause MsgStore's and SubStore's to be closed. + err := s.close() + db := s.db + wg := &s.wg + // Signal background go-routines to quit + if s.doneCh != nil { + close(s.doneCh) + } + s.Unlock() + + // Wait for go routine(s) to finish + wg.Wait() + + s.Lock() + for _, ps := range s.preparedStmts { + if lerr := ps.Close(); lerr != nil && err == nil { + err = lerr + } + } + if db != nil { + if s.dbLock != nil { + s.releaseDBLockIfOwner() + } + if lerr := db.Close(); lerr != nil && err == nil { + err = lerr + } + } + s.Unlock() + return err +} + +//////////////////////////////////////////////////////////////////////////// +// SQLMsgStore methods +//////////////////////////////////////////////////////////////////////////// + +func (mc *sqlMsgsCache) add(msg *pb.MsgProto, data []byte) { + cachedMsg := mc.free + if cachedMsg != nil { + mc.free = cachedMsg.next + cachedMsg.next = nil + // Remove old message from the map + delete(mc.msgs, cachedMsg.msg.Sequence) + } else { + cachedMsg = &sqlCachedMsg{} + } + cachedMsg.msg = msg + cachedMsg.data = data + mc.msgs[msg.Sequence] = cachedMsg + if mc.head == nil { + mc.head = cachedMsg + } else { + mc.tail.next = cachedMsg + } + mc.tail = cachedMsg +} + +func (mc *sqlMsgsCache) transferToFreeList() { + if mc.tail != nil { + mc.tail.next = mc.free + mc.free = mc.head + } + mc.head = nil + mc.tail = nil +} + +func (mc *sqlMsgsCache) pop() *sqlCachedMsg { + cm := mc.head + if cm != nil { + delete(mc.msgs, cm.msg.Sequence) + mc.head = 
cm.next + if mc.head == nil { + mc.tail = nil + } + } + return cm +} + +// Store implements the MsgStore interface +func (ms *SQLMsgStore) Store(m *pb.MsgProto) (uint64, error) { + ms.Lock() + defer ms.Unlock() + + if m.Sequence <= ms.last { + // We've already seen this message. + return m.Sequence, nil + } + + seq := m.Sequence + msgBytes, _ := m.Marshal() + + dataLen := uint64(len(msgBytes)) + + useCache := !ms.sqlStore.opts.NoCaching + if useCache { + ms.writeCache.add(m, msgBytes) + } else { + if _, err := ms.sqlStore.preparedStmts[sqlStoreMsg].Exec(ms.channelID, seq, m.Timestamp, dataLen, msgBytes); err != nil { + return 0, sqlStmtError(sqlStoreMsg, err) + } + } + if ms.first == 0 { + ms.first = seq + } + ms.last = seq + ms.totalCount++ + ms.totalBytes += dataLen + + // Check if we need to remove any (but leave at least the last added) + maxMsgs := ms.limits.MaxMsgs + maxBytes := ms.limits.MaxBytes + if maxMsgs > 0 || maxBytes > 0 { + for ms.totalCount > 1 && + ((maxMsgs > 0 && ms.totalCount > maxMsgs) || + (maxBytes > 0 && (ms.totalBytes > uint64(maxBytes)))) { + + didSQL := false + delBytes := uint64(0) + + if useCache && ms.writeCache.head.msg.Sequence == ms.first { + firstCachedMsg := ms.writeCache.pop() + delBytes = uint64(len(firstCachedMsg.data)) + } else { + r := ms.sqlStore.preparedStmts[sqlGetSizeOfMessage].QueryRow(ms.channelID, ms.first) + if err := r.Scan(&delBytes); err != nil && err != sql.ErrNoRows { + return 0, sqlStmtError(sqlGetSizeOfMessage, err) + } + didSQL = true + } + if delBytes > 0 { + if didSQL { + if _, err := ms.sqlStore.preparedStmts[sqlDeleteMessage].Exec(ms.channelID, ms.first); err != nil { + return 0, sqlStmtError(sqlDeleteMessage, err) + } + } + ms.totalCount-- + ms.totalBytes -= delBytes + ms.first++ + } + if !ms.hitLimit { + ms.hitLimit = true + ms.log.Noticef(droppingMsgsFmt, ms.subject, ms.totalCount, ms.limits.MaxMsgs, + util.FriendlyBytes(int64(ms.totalBytes)), util.FriendlyBytes(ms.limits.MaxBytes)) + } + } + } + + if 
!useCache && ms.limits.MaxAge > 0 && ms.expireTimer == nil { + ms.createExpireTimer() + } + return seq, nil +} + +func (ms *SQLMsgStore) createExpireTimer() { + ms.wg.Add(1) + ms.expireTimer = time.AfterFunc(ms.limits.MaxAge, ms.expireMsgs) +} + +// Lookup implements the MsgStore interface +func (ms *SQLMsgStore) Lookup(seq uint64) (*pb.MsgProto, error) { + ms.Lock() + msg, err := ms.lookup(seq) + ms.Unlock() + return msg, err +} + +func (ms *SQLMsgStore) lookup(seq uint64) (*pb.MsgProto, error) { + var ( + timestamp int64 + data []byte + msg *pb.MsgProto + ) + if seq < ms.first || seq > ms.last { + return nil, nil + } + if !ms.sqlStore.opts.NoCaching { + cm := ms.writeCache.msgs[seq] + if cm != nil { + msg = cm.msg + timestamp = msg.Timestamp + } + } + if msg == nil { + r := ms.sqlStore.preparedStmts[sqlLookupMsg].QueryRow(ms.channelID, seq) + err := r.Scan(×tamp, &data) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, sqlStmtError(sqlLookupMsg, err) + } + } + if maxAge := int64(ms.limits.MaxAge); maxAge > 0 && atomic.LoadInt64(&ms.sqlStore.nowInNano) > timestamp+maxAge { + return nil, nil + } + if msg == nil { + msg = &pb.MsgProto{} + msg.Unmarshal(data) + } + return msg, nil +} + +// GetSequenceFromTimestamp implements the MsgStore interface +func (ms *SQLMsgStore) GetSequenceFromTimestamp(timestamp int64) (uint64, error) { + ms.Lock() + defer ms.Unlock() + // No message ever stored + if ms.first == 0 { + return 0, nil + } + // All messages have expired + if ms.first > ms.last { + return ms.last + 1, nil + } + r := ms.sqlStore.preparedStmts[sqlGetSequenceFromTimestamp].QueryRow(ms.channelID, timestamp) + seq := uint64(0) + err := r.Scan(&seq) + if err == sql.ErrNoRows { + return ms.last + 1, nil + } + if err != nil { + return 0, sqlStmtError(sqlGetSequenceFromTimestamp, err) + } + return seq, nil +} + +// FirstMsg implements the MsgStore interface +func (ms *SQLMsgStore) FirstMsg() (*pb.MsgProto, error) { + ms.Lock() + msg, err 
:= ms.lookup(ms.first) + ms.Unlock() + return msg, err +} + +// LastMsg implements the MsgStore interface +func (ms *SQLMsgStore) LastMsg() (*pb.MsgProto, error) { + ms.Lock() + msg, err := ms.lookup(ms.last) + ms.Unlock() + return msg, err +} + +// expireMsgsLocked removes all messages that have expired in this channel. +// Store lock is assumed held on entry +func (ms *SQLMsgStore) expireMsgs() { + ms.Lock() + defer ms.Unlock() + + if ms.closed { + ms.wg.Done() + return + } + + var ( + count int + maxSeq uint64 + totalSize uint64 + timestamp int64 + ) + processErr := func(errCode int, err error) { + ms.log.Errorf("Unable to perform expiration for channel %q: %v", ms.subject, sqlStmtError(errCode, err)) + ms.expireTimer.Reset(sqlExpirationIntervalOnError) + } + for { + expiredTimestamp := time.Now().UnixNano() - int64(ms.limits.MaxAge) + r := ms.sqlStore.preparedStmts[sqlGetExpiredMessages].QueryRow(ms.channelID, expiredTimestamp) + if err := r.Scan(&count, &maxSeq, &totalSize); err != nil { + processErr(sqlGetExpiredMessages, err) + return + } + // It could be that messages that should have expired have been + // removed due to count/size limit. We still need to adjust the + // expiration timer based on the first message that need to expire. + if count > 0 { + if maxSeq == ms.last { + if _, err := ms.sqlStore.preparedStmts[sqlUpdateChannelMaxSeq].Exec(maxSeq, ms.channelID); err != nil { + processErr(sqlUpdateChannelMaxSeq, err) + return + } + } + if _, err := ms.sqlStore.preparedStmts[sqlDeletedMsgsWithSeqLowerThan].Exec(ms.channelID, maxSeq); err != nil { + processErr(sqlDeletedMsgsWithSeqLowerThan, err) + return + } + ms.first = maxSeq + 1 + ms.totalCount -= count + ms.totalBytes -= totalSize + } + // Reset since we are in a loop + timestamp = 0 + // If there is any message left in the channel, find out what the expiration + // timer needs to be set to. 
+ if ms.totalCount > 0 { + r = ms.sqlStore.preparedStmts[sqlGetFirstMsgTimestamp].QueryRow(ms.channelID, ms.first) + if err := r.Scan(×tamp); err != nil { + processErr(sqlGetFirstMsgTimestamp, err) + return + } + } + // No message left or no message to expire. The timer will be recreated when + // a new message is added to the channel. + if timestamp == 0 { + ms.wg.Done() + ms.expireTimer = nil + return + } + elapsed := time.Duration(time.Now().UnixNano() - timestamp) + if elapsed < ms.limits.MaxAge { + ms.expireTimer.Reset(ms.limits.MaxAge - elapsed) + // Done with the for loop + return + } + } +} + +func (ms *SQLMsgStore) flush() error { + if ms.sqlStore.opts.NoCaching { + return nil + } + if ms.writeCache.head == nil { + return nil + } + var ( + tx *sql.Tx + ps *sql.Stmt + ) + defer func() { + ms.writeCache.transferToFreeList() + if ps != nil { + ps.Close() + } + if tx != nil { + tx.Rollback() + } + }() + tx, err := ms.sqlStore.db.Begin() + if err != nil { + return err + } + ps, err = tx.Prepare(sqlStmts[sqlStoreMsg]) + if err != nil { + return err + } + // Iterate through the cache, but do not remove elements from the list. + // They are needed in transferToFreeList(). 
+ for cm := ms.writeCache.head; cm != nil; cm = cm.next { + if _, err := ps.Exec(ms.channelID, cm.msg.Sequence, cm.msg.Timestamp, len(cm.data), cm.data); err != nil { + return err + } + } + if err := ps.Close(); err != nil { + return err + } + ps = nil + if err := tx.Commit(); err != nil { + return err + } + tx = nil + if ms.limits.MaxAge > 0 && ms.expireTimer == nil { + ms.createExpireTimer() + } + return nil +} + +// Empty implements the MsgStore interface +func (ms *SQLMsgStore) Empty() error { + ms.Lock() + tx, err := ms.sqlStore.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(sqlStmts[sqlDeletedMsgsWithSeqLowerThan], ms.channelID, ms.last); err != nil { + return err + } + if _, err := tx.Exec(sqlStmts[sqlUpdateChannelMaxSeq], 0, ms.channelID); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return err + } + ms.empty() + if ms.expireTimer != nil { + if ms.expireTimer.Stop() { + ms.wg.Done() + } + ms.expireTimer = nil + } + if ms.writeCache != nil { + ms.writeCache.transferToFreeList() + } + ms.Unlock() + return err +} + +// Flush implements the MsgStore interface +func (ms *SQLMsgStore) Flush() error { + ms.Lock() + err := ms.flush() + ms.Unlock() + return err +} + +// Close implements the MsgStore interface +func (ms *SQLMsgStore) Close() error { + ms.Lock() + if ms.closed { + ms.Unlock() + return nil + } + // Flush before switching the state to closed + err := ms.flush() + ms.closed = true + if ms.expireTimer != nil { + if ms.expireTimer.Stop() { + ms.wg.Done() + } + } + ms.Unlock() + + ms.wg.Wait() + return err +} + +//////////////////////////////////////////////////////////////////////////// +// SQLSubStore methods +//////////////////////////////////////////////////////////////////////////// + +// CreateSub implements the SubStore interface +func (ss *SQLSubStore) CreateSub(sub *spb.SubState) error { + ss.Lock() + defer ss.Unlock() + // Check limits only if needed + if 
ss.limits.MaxSubscriptions > 0 { + r := ss.sqlStore.preparedStmts[sqlCheckMaxSubs].QueryRow(ss.channelID) + count := 0 + if err := r.Scan(&count); err != nil { + return sqlStmtError(sqlCheckMaxSubs, err) + } + if count >= ss.limits.MaxSubscriptions { + return ErrTooManySubs + } + } + sub.ID = atomic.AddUint64(ss.maxSubID, 1) + subBytes, _ := sub.Marshal() + if _, err := ss.sqlStore.preparedStmts[sqlCreateSub].Exec(ss.channelID, sub.ID, subBytes); err != nil { + sub.ID = 0 + return sqlStmtError(sqlCreateSub, err) + } + if ss.hasMarkedAsDel { + if _, err := ss.sqlStore.preparedStmts[sqlDeleteSubMarkedAsDeleted].Exec(ss.channelID); err != nil { + return sqlStmtError(sqlDeleteSubMarkedAsDeleted, err) + } + ss.hasMarkedAsDel = false + } + return nil +} + +// UpdateSub implements the SubStore interface +func (ss *SQLSubStore) UpdateSub(sub *spb.SubState) error { + ss.Lock() + defer ss.Unlock() + subBytes, _ := sub.Marshal() + r, err := ss.sqlStore.preparedStmts[sqlUpdateSub].Exec(subBytes, ss.channelID, sub.ID) + if err != nil { + return sqlStmtError(sqlUpdateSub, err) + } + // FileSubStoe supports updating a subscription for which there was no CreateSub. + // Not sure if this is necessary, since I think server would never do that. + // Stay consistent. 
+ c, err := r.RowsAffected() + if err != nil { + return err + } + if c == 0 { + if _, err := ss.sqlStore.preparedStmts[sqlCreateSub].Exec(ss.channelID, sub.ID, subBytes); err != nil { + return sqlStmtError(sqlCreateSub, err) + } + } + return nil +} + +// DeleteSub implements the SubStore interface +func (ss *SQLSubStore) DeleteSub(subid uint64) error { + ss.Lock() + defer ss.Unlock() + if subid == atomic.LoadUint64(ss.maxSubID) { + if _, err := ss.sqlStore.preparedStmts[sqlMarkSubscriptionAsDeleted].Exec(ss.channelID, subid); err != nil { + return sqlStmtError(sqlMarkSubscriptionAsDeleted, err) + } + ss.hasMarkedAsDel = true + } else { + if _, err := ss.sqlStore.preparedStmts[sqlDeleteSubscription].Exec(ss.channelID, subid); err != nil { + return sqlStmtError(sqlDeleteSubscription, err) + } + } + if ss.cache != nil { + delete(ss.cache.subs, subid) + } else { + delete(ss.subLastSent, subid) + } + // Ignore error on this since subscription would not be recovered + // if above executed ok. + ss.sqlStore.preparedStmts[sqlDeleteSubPendingMessages].Exec(subid) + return nil +} + +// This returns the structure responsible to keep track of +// pending messages and acks for a given subscription ID. +func (ss *SQLSubStore) getOrCreateAcksPending(subid, seqno uint64) *sqlSubAcksPending { + if !ss.cache.needsFlush { + ss.cache.needsFlush = true + ss.sqlStore.scheduleSubStoreFlush(ss) + } + ap := ss.cache.subs[subid] + if ap == nil { + ap = &sqlSubAcksPending{ + msgToRow: make(map[uint64]*sqlSubsPendingRow), + ackToRow: make(map[uint64]*sqlSubsPendingRow), + msgs: make(map[uint64]struct{}), + acks: make(map[uint64]struct{}), + } + ss.cache.subs[subid] = ap + } + if seqno > ap.lastSent { + ap.lastSent = seqno + } + return ap +} + +// Adds the given sequence to the list of pending messages. +// Returns true if the number of pending messages has +// reached a certain threshold, indicating that the +// store should be flushed. 
+func (ss *SQLSubStore) addSeq(subid, seqno uint64) bool { + ap := ss.getOrCreateAcksPending(subid, seqno) + ap.msgs[seqno] = struct{}{} + return len(ap.msgs) >= sqlMaxPendingAcks +} + +// Adds the given sequence to the list of acks and possibly +// delete rows that have all their pending messages acknowledged. +// Returns true if the number of acks has reached a certain threshold, +// indicating that the store should be flushed. +func (ss *SQLSubStore) ackSeq(subid, seqno uint64) (bool, error) { + ap := ss.getOrCreateAcksPending(subid, seqno) + // If still in cache and not persisted into a row, + // then simply remove from map and do not persist the ack. + if _, exists := ap.msgs[seqno]; exists { + delete(ap.msgs, seqno) + } else if row := ap.msgToRow[seqno]; row != nil { + ap.acks[seqno] = struct{}{} + // This is an ack for a pending msg that was persisted + // in a row. Update the row's msgRef count. + delete(ap.msgToRow, seqno) + row.msgsRefs-- + // If all pending messages in that row have been ack'ed + if row.msgsRefs == 0 { + // and if all acks on that row are no longer needed + // (or there was none) + if row.acksRefs == 0 { + // then this row can be deleted. + if err := ss.deleteSubPendingRow(subid, row.ID); err != nil { + return false, err + } + // If there is no error, we don't even need + // to persist this ack. + delete(ap.acks, seqno) + } + // Since there is no pending message left in this + // row, let's find all the corresponding acks' rows + // for these sequences and update their acksRefs + for seq := range row.msgs { + delete(row.msgs, seq) + ackRow := ap.ackToRow[seq] + if ackRow != nil { + // We found the row for the ack of this sequence, + // remove from map and update reference count. + // delete(ap.ackToRow, seq) + ackRow.acksRefs-- + // If all acks for that row are no longer needed and + // that row has also no pending messages, then ok to + // delete. 
+ if ackRow.acksRefs == 0 && ackRow.msgsRefs == 0 { + if err := ss.deleteSubPendingRow(subid, ackRow.ID); err != nil { + return false, err + } + } + } else { + // That means the ack is in current cache so we won't + // need to persist it. + delete(ap.acks, seq) + } + } + sqlSeqMapPool.Put(row.msgs) + row.msgs = nil + } + } + return len(ap.acks) >= sqlMaxPendingAcks, nil +} + +// AddSeqPending implements the SubStore interface +func (ss *SQLSubStore) AddSeqPending(subid, seqno uint64) error { + var err error + ss.Lock() + if !ss.closed { + if ss.cache != nil { + if isFull := ss.addSeq(subid, seqno); isFull { + err = ss.flush() + } + } else { + ls := ss.subLastSent[subid] + if seqno > ls { + ss.subLastSent[subid] = seqno + } + ss.curRow++ + _, err = ss.sqlStore.preparedStmts[sqlSubAddPending].Exec(subid, ss.curRow, seqno) + if err != nil { + err = sqlStmtError(sqlSubAddPending, err) + } + } + } + ss.Unlock() + return err +} + +// AckSeqPending implements the SubStore interface +func (ss *SQLSubStore) AckSeqPending(subid, seqno uint64) error { + var err error + ss.Lock() + if !ss.closed { + if ss.cache != nil { + var isFull bool + isFull, err = ss.ackSeq(subid, seqno) + if err == nil && isFull { + err = ss.flush() + } + } else { + updateLastSent := false + ls := ss.subLastSent[subid] + if seqno >= ls { + if seqno > ls { + ss.subLastSent[subid] = seqno + } + updateLastSent = true + } + if updateLastSent { + if _, err := ss.sqlStore.preparedStmts[sqlSubUpdateLastSent].Exec(seqno, ss.channelID, subid); err != nil { + ss.Unlock() + return sqlStmtError(sqlSubUpdateLastSent, err) + } + } + _, err = ss.sqlStore.preparedStmts[sqlSubDeletePending].Exec(subid, seqno) + if err != nil { + err = sqlStmtError(sqlSubDeletePending, err) + } + } + } + ss.Unlock() + return err +} + +func (ss *SQLSubStore) deleteSubPendingRow(subid, rowid uint64) error { + if _, err := ss.sqlStore.preparedStmts[sqlSubDeletePendingRow].Exec(subid, rowid); err != nil { + return 
sqlStmtError(sqlSubDeletePendingRow, err) + } + return nil +} + +func (ss *SQLSubStore) recoverPendingRow(rows *sql.Rows, sub *spb.SubState, ap *sqlSubAcksPending, pendingAcks PendingAcks, + gcedRows map[uint64]struct{}) error { + var ( + seq, lastSent uint64 + pendingBytes, acksBytes []byte + ) + if err := rows.Scan(&ss.curRow, &seq, &lastSent, &pendingBytes, &acksBytes); err != nil && err != sql.ErrNoRows { + return err + } + // If seq is non zero, this was created from a non-buffered run. + if seq > 0 { + if seq > sub.LastSent { + sub.LastSent = seq + } + pendingAcks[seq] = struct{}{} + } else { + var row *sqlSubsPendingRow + if ap != nil { + row = &sqlSubsPendingRow{ + ID: ss.curRow, + msgs: sqlSeqMapPool.Get().(map[uint64]struct{}), + } + ap.lastSent = lastSent + ap.prevLastSent = lastSent + } + + if lastSent > sub.LastSent { + sub.LastSent = lastSent + } + if len(pendingBytes) > 0 { + if err := sqlDecodeSeqs(pendingBytes, func(seq uint64) { + pendingAcks[seq] = struct{}{} + if ap != nil { + row.msgsRefs++ + row.msgs[seq] = struct{}{} + ap.msgToRow[seq] = row + } + }); err != nil { + return err + } + } + if len(acksBytes) > 0 { + if err := sqlDecodeSeqs(acksBytes, func(seq uint64) { + if _, exists := pendingAcks[seq]; exists { + delete(pendingAcks, seq) + if ap != nil { + row.acksRefs++ + ap.ackToRow[seq] = row + + seqRow := ap.msgToRow[seq] + if seqRow != nil { + delete(ap.msgToRow, seq) + seqRow.msgsRefs-- + if seqRow.msgsRefs == 0 && seqRow.acksRefs == 0 { + gcedRows[seqRow.ID] = struct{}{} + } + } + } + } + }); err != nil { + return err + } + } + } + return nil +} + +// Flush implements the SubStore interface +func (ss *SQLSubStore) Flush() error { + ss.Lock() + err := ss.flush() + ss.Unlock() + return err +} + +func (ss *SQLSubStore) flush() error { + if ss.cache == nil || !ss.cache.needsFlush || ss.closed { + return nil + } + var ( + tx *sql.Tx + ps *sql.Stmt + err error + ) + defer func() { + if ps != nil { + ps.Close() + } + if tx != nil { + 
tx.Rollback() + } + }() + tx, err = ss.sqlStore.db.Begin() + if err != nil { + return err + } + ps, err = tx.Prepare(sqlStmts[sqlSubAddPendingRow]) + if err != nil { + return err + } + for subid, ap := range ss.cache.subs { + if len(ap.msgs) == 0 && len(ap.acks) == 0 { + // Update subscription's lastSent column if it has changed. + if ap.lastSent != ap.prevLastSent { + if _, err := tx.Exec(sqlStmts[sqlSubUpdateLastSent], ap.lastSent, ss.channelID, subid); err != nil { + return err + } + ap.prevLastSent = ap.lastSent + } + // Since there was no pending nor ack for this sub, simply continue + // with the next subscription. + continue + } + var ( + pendingBytes []byte + acksBytes []byte + ) + ss.curRow++ + row := &sqlSubsPendingRow{ID: ss.curRow} + if len(ap.msgs) > 0 { + pendingBytes, err = sqlEncodeSeqs(ap.msgs, func(seqno uint64) { + row.msgsRefs++ + ap.msgToRow[seqno] = row + }) + if err != nil { + return err + } + row.msgs = ap.msgs + ap.msgs = sqlSeqMapPool.Get().(map[uint64]struct{}) + } + if len(ap.acks) > 0 { + acksBytes, err = sqlEncodeSeqs(ap.acks, func(seqno uint64) { + delete(ap.acks, seqno) + row.acksRefs++ + ap.ackToRow[seqno] = row + }) + if err != nil { + return err + } + } + if _, err := ps.Exec(subid, ss.curRow, ap.lastSent, pendingBytes, acksBytes); err != nil { + return err + } + } + if err := ps.Close(); err != nil { + return err + } + ps = nil + if err := tx.Commit(); err != nil { + return err + } + tx = nil + ss.cache.needsFlush = false + return nil +} + +func sqlEncodeSeqs(m map[uint64]struct{}, f func(seq uint64)) ([]byte, error) { + // We store as a pointer in the sync pool. 
+ pseqarray := sqlSeqArrayPool.Get().(*[]uint64) + seqarray := *pseqarray + for seqno := range m { + f(seqno) + seqarray = append(seqarray, seqno) + } + b, err := json.Marshal(seqarray) + if err != nil { + return nil, err + } + seqarray = seqarray[:0] + sqlSeqArrayPool.Put(&seqarray) + return b, nil +} + +func sqlDecodeSeqs(data []byte, f func(seq uint64)) error { + var seqarray []uint64 + if err := json.Unmarshal(data, &seqarray); err != nil { + return err + } + if seqarray != nil { + for _, seq := range seqarray { + f(seq) + } + seqarray = seqarray[:0] + sqlSeqArrayPool.Put(&seqarray) + } + return nil +} + +// Close implements the SubStore interface +func (ss *SQLSubStore) Close() error { + ss.Lock() + if ss.closed { + ss.Unlock() + return nil + } + // Flush before switching the state to closed. + err := ss.flush() + ss.closed = true + ss.Unlock() + return err +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/stores/store.go b/vendor/github.com/nats-io/nats-streaming-server/stores/store.go new file mode 100644 index 00000000000..a77bef54817 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/stores/store.go @@ -0,0 +1,303 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package stores + +import ( + "errors" + "time" + + "github.com/nats-io/go-nats-streaming/pb" + "github.com/nats-io/nats-streaming-server/spb" +) + +const ( + // TypeMemory is the store type name for memory based stores + TypeMemory = "MEMORY" + // TypeFile is the store type name for file based stores + TypeFile = "FILE" + // TypeSQL is the store type name for sql based stores + TypeSQL = "SQL" + // TypeRaft is the store type name for the raft stores + TypeRaft = "RAFT" +) + +// Errors. +var ( + ErrTooManyChannels = errors.New("too many channels") + ErrTooManySubs = errors.New("too many subscriptions per channel") + ErrNotSupported = errors.New("not supported") + ErrAlreadyExists = errors.New("already exists") + ErrNotFound = errors.New("not found") +) + +// StoreLimits define limits for a store. +type StoreLimits struct { + // How many channels are allowed. + MaxChannels int `json:"max_channels"` + // Global limits. Any 0 value means that the limit is ignored (unlimited). + ChannelLimits + // Per-channel limits. Special values for limits in this map: + // - == 0 means that the corresponding global limit is used. + // - < 0 means that limit is ignored (unlimited). + PerChannel map[string]*ChannelLimits `json:"channels,omitempty"` +} + +// ChannelLimits defines limits for a given channel +type ChannelLimits struct { + // Limits for message stores + MsgStoreLimits + // Limits for subscriptions stores + SubStoreLimits + // How long without any active subscription and no new message + // before this channel can be deleted. + MaxInactivity time.Duration +} + +// MsgStoreLimits defines limits for a MsgStore. +// For global limits, a value of 0 means "unlimited". +// For per-channel limits, it means that the corresponding global +// limit is used. +type MsgStoreLimits struct { + // How many messages are allowed. + MaxMsgs int `json:"max_msgs"` + // How many bytes are allowed. 
+ MaxBytes int64 `json:"max_bytes"` + // How long messages are kept in the log (unit is seconds) + MaxAge time.Duration `json:"max_age"` +} + +// SubStoreLimits defines limits for a SubStore +type SubStoreLimits struct { + // How many subscriptions are allowed. + MaxSubscriptions int `json:"max_subscriptions"` +} + +// DefaultStoreLimits are the limits that a Store must +// use when none are specified to the Store constructor. +// Store limits can be changed with the Store.SetLimits() method. +var DefaultStoreLimits = StoreLimits{ + 100, + ChannelLimits{ + MsgStoreLimits{ + MaxMsgs: 1000000, + MaxBytes: 1000000 * 1024, + }, + SubStoreLimits{ + MaxSubscriptions: 1000, + }, + 0, + }, + nil, +} + +// RecoveredState allows the server to reconstruct its state after a restart. +type RecoveredState struct { + Info *spb.ServerInfo + Clients []*Client + Channels map[string]*RecoveredChannel +} + +// RecoveredChannel represents a channel that has been recovered, with all its subscriptions +type RecoveredChannel struct { + Channel *Channel + Subscriptions []*RecoveredSubscription +} + +// PendingAcks is a set of message sequences waiting to be acknowledged. +type PendingAcks map[uint64]struct{} + +// RecoveredSubscription represents a recovered Subscription with a map +// of pending messages. +type RecoveredSubscription struct { + Sub *spb.SubState + Pending PendingAcks +} + +// Client represents a client with ID and Heartbeat Inbox. +type Client struct { + spb.ClientInfo +} + +// Channel contains a reference to both Subscription and Message stores. +type Channel struct { + // Subs is the Subscriptions Store. + Subs SubStore + // Msgs is the Messages Store. + Msgs MsgStore +} + +// Store is the storage interface for NATS Streaming servers. +// +// If an implementation has a Store constructor with StoreLimits, it should be +// noted that the limits don't apply to any state being recovered, for Store +// implementations supporting recovery. 
+// +type Store interface { + // GetExclusiveLock is an advisory lock to prevent concurrent + // access to the store from multiple instances. + // This is not to protect individual API calls, instead, it + // is meant to protect the store for the entire duration the + // store is being used. This is why there is no `Unlock` API. + // The lock should be released when the store is closed. + // + // If an exclusive lock can be immediately acquired (that is, + // it should not block waiting for the lock to be acquired), + // this call will return `true` with no error. Once a store + // instance has acquired an exclusive lock, calling this + // function has no effect and `true` with no error will again + // be returned. + // + // If the lock cannot be acquired, this call will return + // `false` with no error: the caller can try again later. + // + // If, however, the lock cannot be acquired due to a fatal + // error, this call should return `false` and the error. + // + // It is important to note that the implementation should + // make an effort to distinguish error conditions deemed + // fatal (and therefore trying again would invariably result + // in the same error) and those deemed transient, in which + // case no error should be returned to indicate that the + // caller could try later. + // + // Implementations that do not support exclusive locks should + // return `false` and `ErrNotSupported`. + GetExclusiveLock() (bool, error) + + // Init can be used to initialize the store with server's information. + Init(info *spb.ServerInfo) error + + // Name returns the name type of this store (e.g: MEMORY, FILESTORE, etc...). + Name() string + + // Recover returns the recovered state. + // Implementations that do not persist state and therefore cannot + // recover from a previous run MUST return nil, not an error. + // However, an error must be returned for implementations that are + // attempting to recover the state but fail to do so. 
+ Recover() (*RecoveredState, error) + + // SetLimits sets limits for this store. The action is not expected + // to be retroactive. + // The store implementation should make a deep copy as to not change + // the content of the structure passed by the caller. + // This call may return an error due to limits validation errors. + SetLimits(limits *StoreLimits) error + + // GetChannelLimits returns the limit for this channel. If the channel + // does not exist, returns nil. + GetChannelLimits(name string) *ChannelLimits + + // CreateChannel creates a Channel. + // Implementations should return ErrAlreadyExists if the channel was + // already created. + // Limits defined for this channel in StoreLimits.PeChannel map, if present, + // will apply. Otherwise, the global limits in StoreLimits will apply. + CreateChannel(channel string) (*Channel, error) + + // DeleteChannel deletes a Channel. + // Implementations should make sure that if no error is returned, the + // channel would not be recovered after a restart, unless CreateChannel() + // with the same channel is invoked. + // If processing is expecting to be time consuming, work should be done + // in the background as long as the above condition is guaranteed. + // It is also acceptable for an implementation to have CreateChannel() + // return an error if background deletion is still happening for a + // channel of the same name. + DeleteChannel(channel string) error + + // AddClient stores information about the client identified by `clientID`. + AddClient(info *spb.ClientInfo) (*Client, error) + + // DeleteClient removes the client identified by `clientID` from the store. + DeleteClient(clientID string) error + + // Close closes this store (including all MsgStore and SubStore). + // If an exclusive lock was acquired, the lock shall be released. + Close() error +} + +// SubStore is the interface for storage of Subscriptions on a given channel. 
+// +// Implementations of this interface should not attempt to validate that +// a subscription is valid (that is, has not been deleted) when processing +// updates. +type SubStore interface { + // CreateSub records a new subscription represented by SubState. On success, + // it records the subscription's ID in SubState.ID. This ID is to be used + // by the other SubStore methods. + CreateSub(*spb.SubState) error + + // UpdateSub updates a given subscription represented by SubState. + UpdateSub(*spb.SubState) error + + // DeleteSub invalidates the subscription 'subid'. + DeleteSub(subid uint64) error + + // AddSeqPending adds the given message 'seqno' to the subscription 'subid'. + AddSeqPending(subid, seqno uint64) error + + // AckSeqPending records that the given message 'seqno' has been acknowledged + // by the subscription 'subid'. + AckSeqPending(subid, seqno uint64) error + + // Flush is for stores that may buffer operations and need them to be persisted. + Flush() error + + // Close closes the subscriptions store. + Close() error +} + +// MsgStore is the interface for storage of Messages on a given channel. +type MsgStore interface { + // State returns some statistics related to this store. + State() (numMessages int, byteSize uint64, err error) + + // Store stores a message and returns the message sequence. + Store(msg *pb.MsgProto) (uint64, error) + + // Lookup returns the stored message with given sequence number. + Lookup(seq uint64) (*pb.MsgProto, error) + + // FirstSequence returns sequence for first message stored, 0 if no + // message is stored. + FirstSequence() (uint64, error) + + // LastSequence returns sequence for last message stored, 0 if no + // message is stored. + LastSequence() (uint64, error) + + // FirstAndLastSequence returns sequences for the first and last messages stored, + // 0 if no message is stored. 
+ FirstAndLastSequence() (uint64, uint64, error) + + // GetSequenceFromTimestamp returns the sequence of the first message whose + // timestamp is greater or equal to given timestamp. + GetSequenceFromTimestamp(timestamp int64) (uint64, error) + + // FirstMsg returns the first message stored. + FirstMsg() (*pb.MsgProto, error) + + // LastMsg returns the last message stored. + LastMsg() (*pb.MsgProto, error) + + // Flush is for stores that may buffer operations and need them to be persisted. + Flush() error + + // Empty removes all messages from the store + Empty() error + + // Close closes the store. + Close() error +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/util/channels.go b/vendor/github.com/nats-io/nats-streaming-server/util/channels.go new file mode 100644 index 00000000000..9428b552d5d --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/util/channels.go @@ -0,0 +1,130 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "errors" + "fmt" + + "github.com/nats-io/go-nats" + + "github.com/nats-io/nats-streaming-server/spb" +) + +// Number of bytes used to encode a channel name +const encodedChannelLen = 2 + +// SendsChannelsList sends the list of channels to the given subject, possibly +// splitting the list in several requests if it cannot fit in a single message. 
+func SendChannelsList(channels []string, sendInbox, replyInbox string, nc *nats.Conn, serverID string) error { + // Since the NATS message payload is limited, we need to repeat + // requests if all channels can't fit in a request. + maxPayload := int(nc.MaxPayload()) + // Reuse this request object to send the (possibly many) protocol message(s). + header := &spb.CtrlMsg{ + ServerID: serverID, + MsgType: spb.CtrlMsg_Partitioning, + } + // The Data field (a byte array) will require 1+len(array)+(encoded size of array). + // To be conservative, let's just use a 8 bytes integer + headerSize := header.Size() + 1 + 8 + var ( + bytes []byte // Reused buffer in which the request is to marshal info + n int // Size of the serialized request in the above buffer + count int // Number of channels added to the request + ) + for start := 0; start != len(channels); start += count { + bytes, n, count = encodeChannelsRequest(header, channels, bytes, headerSize, maxPayload, start) + if count == 0 { + return errors.New("message payload too small to send channels list") + } + if err := nc.PublishRequest(sendInbox, replyInbox, bytes[:n]); err != nil { + return err + } + } + return nc.Flush() +} + +// DecodeChannels decodes from the given byte array the list of channel names +// and return them as an array of strings. 
+func DecodeChannels(data []byte) ([]string, error) { + channels := []string{} + pos := 0 + for pos < len(data) { + if pos+2 > len(data) { + return nil, fmt.Errorf("unable to decode size, pos=%v len=%v", pos, len(data)) + } + cl := int(ByteOrder.Uint16(data[pos:])) + pos += encodedChannelLen + end := pos + cl + if end > len(data) { + return nil, fmt.Errorf("unable to decode channel, pos=%v len=%v max=%v (string=%v)", + pos, cl, len(data), string(data[pos:])) + } + c := string(data[pos:end]) + channels = append(channels, c) + pos = end + } + return channels, nil +} + +// Adds as much channels as possible (based on the NATS max message payload) and +// returns a serialized request. The buffer `reqBytes` is passed (and returned) so +// that it can be reused if more than one request is needed. This call will +// expand the size as needed. The number of bytes used in this buffer is returned +// along with the number of encoded channels. +func encodeChannelsRequest(request *spb.CtrlMsg, channels []string, reqBytes []byte, + headerSize, maxPayload, start int) ([]byte, int, int) { + + // Each string will be encoded in the form: + // - length (2 bytes) + // - string as a byte array. + var _encodedSize = [encodedChannelLen]byte{} + encodedSize := _encodedSize[:] + // We are going to encode the channels in this buffer + chanBuf := make([]byte, 0, maxPayload) + var ( + count int // Number of encoded channels + estimatedSize = headerSize // This is not an overestimation of the total size + numBytes int // This is what is returned by MarshalTo + ) + for i := start; i < len(channels); i++ { + c := []byte(channels[i]) + cl := len(c) + needed := encodedChannelLen + cl + // Check if adding this channel to current buffer makes us go over + if estimatedSize+needed > maxPayload { + // Special case if we cannot even encode 1 channel + if count == 0 { + return reqBytes, 0, 0 + } + break + } + // Encoding the channel here. First the size, then the channel name as byte array. 
+ ByteOrder.PutUint16(encodedSize, uint16(cl)) + chanBuf = append(chanBuf, encodedSize...) + chanBuf = append(chanBuf, c...) + count++ + estimatedSize += needed + } + if count > 0 { + request.Data = chanBuf + reqBytes = EnsureBufBigEnough(reqBytes, estimatedSize) + numBytes, _ = request.MarshalTo(reqBytes) + if numBytes > maxPayload { + panic(fmt.Errorf("request size is %v (max payload is %v)", numBytes, maxPayload)) + } + } + return reqBytes, numBytes, count +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/util/lockfile_unix.go b/vendor/github.com/nats-io/nats-streaming-server/util/lockfile_unix.go new file mode 100644 index 00000000000..f917b725acb --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/util/lockfile_unix.go @@ -0,0 +1,97 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package util + +import ( + "io" + "os" + "sync" + "syscall" +) + +type lockFile struct { + sync.Mutex + f *os.File +} + +// CreateLockFile attempt to lock the given file, creating it +// if necessary. On success, the file is returned, otherwise +// an error is returned. +// The file returned should be closed to release the lock +// quicker than if left to the operating system. +func CreateLockFile(file string) (LockFile, error) { + f, err := os.Create(file) + if err != nil { + // Consider those fatal, others may be considered transient + // (for instance FD limit reached, etc...) 
+ if os.IsNotExist(err) || os.IsPermission(err) { + return nil, err + } + return nil, ErrUnableToLockNow + } + spec := &syscall.Flock_t{ + Type: syscall.F_WRLCK, + Whence: int16(io.SeekStart), + Start: 0, + Len: 0, // 0 means to lock the entire file. + } + if err := syscall.FcntlFlock(f.Fd(), syscall.F_SETLK, spec); err != nil { + // Try to gather all errors that we deem transient and return + // ErrUnableToLockNow in this case to indicate the caller that + // the lock could not be acquired at this time but it could + // try later. + // Basing this from possible ERRORS from this page: + // http://pubs.opengroup.org/onlinepubs/009695399/functions/fcntl.html + if err == syscall.EAGAIN || err == syscall.EACCES || + err == syscall.EINTR || err == syscall.ENOLCK { + err = ErrUnableToLockNow + } + // TODO: If error is not ErrUnableToLockNow, it may mean that + // the call is not supported on that platform, etc... + // We should have another level of verification, for instance + // check content of the lockfile is not being updated by the + // owner of the file, etc... + f.Close() + return nil, err + } + return &lockFile{f: f}, nil +} + +// Close implements the LockFile interface +func (lf *lockFile) Close() error { + lf.Lock() + defer lf.Unlock() + if lf.f == nil { + return nil + } + spec := &syscall.Flock_t{ + Type: syscall.F_UNLCK, + Whence: int16(io.SeekStart), + Start: 0, + Len: 0, // 0 means to lock the entire file. 
+ } + err := syscall.FcntlFlock(lf.f.Fd(), syscall.F_SETLK, spec) + err = CloseFile(err, lf.f) + lf.f = nil + return err +} + +// IsClosed implements the LockFile interface +func (lf *lockFile) IsClosed() bool { + lf.Lock() + defer lf.Unlock() + return lf.f == nil +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/util/lockfile_win.go b/vendor/github.com/nats-io/nats-streaming-server/util/lockfile_win.go new file mode 100644 index 00000000000..da6fe96ccc0 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/util/lockfile_win.go @@ -0,0 +1,77 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build windows + +package util + +import ( + "strings" + "sync" + "syscall" +) + +type lockFile struct { + sync.Mutex + f syscall.Handle +} + +// CreateLockFile attempt to lock the given file, creating it +// if necessary. On success, the file is returned, otherwise +// an error is returned. +// The file returned should be closed to release the lock +// quicker than if left to the operating system. +func CreateLockFile(file string) (LockFile, error) { + fname, err := syscall.UTF16PtrFromString(file) + if err != nil { + return nil, err + } + f, err := syscall.CreateFile(fname, + syscall.GENERIC_READ|syscall.GENERIC_WRITE, + 0, // dwShareMode: 0 means "Prevents other processes from opening a file or device if they request delete, read, or write access." 
+ nil, + syscall.CREATE_ALWAYS, + syscall.FILE_ATTRIBUTE_NORMAL, + 0, + ) + if err != nil { + // TODO: There HAS to be a better way, but I can't seem to + // find how to get Windows error codes (also syscall.GetLastError() + // returns nil here). + if strings.Contains(err.Error(), "used by another process") { + err = ErrUnableToLockNow + } + syscall.CloseHandle(f) + return nil, err + } + return &lockFile{f: f}, nil +} + +// Close implements the LockFile interface +func (lf *lockFile) Close() error { + lf.Lock() + defer lf.Unlock() + if lf.f == syscall.InvalidHandle { + return nil + } + err := syscall.CloseHandle(lf.f) + lf.f = syscall.InvalidHandle + return err +} + +// IsClosed implements the LockFile interface +func (lf *lockFile) IsClosed() bool { + lf.Lock() + defer lf.Unlock() + return lf.f == syscall.InvalidHandle +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/util/no_race.go b/vendor/github.com/nats-io/nats-streaming-server/util/no_race.go new file mode 100644 index 00000000000..0031b9c27ec --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/util/no_race.go @@ -0,0 +1,21 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race + +package util + +// RaceEnabled indicates that program/tests are running with race detection +// enabled or not. Some tests may chose to skip execution when race +// detection is on. 
+const RaceEnabled = false diff --git a/vendor/github.com/nats-io/nats-streaming-server/util/race.go b/vendor/github.com/nats-io/nats-streaming-server/util/race.go new file mode 100644 index 00000000000..e7404c750cb --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/util/race.go @@ -0,0 +1,21 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build race + +package util + +// RaceEnabled indicates that program/tests are running with race detection +// enabled or not. Some tests may chose to skip execution when race +// detection is on. +const RaceEnabled = true diff --git a/vendor/github.com/nats-io/nats-streaming-server/util/sublist.go b/vendor/github.com/nats-io/nats-streaming-server/util/sublist.go new file mode 100644 index 00000000000..3d2b5292fa7 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/util/sublist.go @@ -0,0 +1,521 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "errors" + "sync" +) + +// This is taken from NATS Server's sublist and modified to use +// an interface{} instead of a *subscription which is relevant +// only in NATS Server. + +// Common byte variables for wildcards and token separator. +const ( + pwc = '*' + spwc = "*" + fwc = '>' + sfwc = ">" + tsep = "." + btsep = '.' +) + +// cacheMax is used to bound limit the frontend cache +const slCacheMax = 1024 + +// Sublist related errors +var ( + ErrInvalidSubject = errors.New("sublist: invalid subject") + ErrNotFound = errors.New("sublist: no match found") +) + +// A Sublist stores and efficiently retrieves subscriptions. +type Sublist struct { + sync.RWMutex + root *level + cache map[string][]interface{} + count uint32 +} + +// A node contains subscriptions and a pointer to the next level. +type node struct { + next *level + elements []interface{} +} + +// A level represents a group of nodes and special pointers to +// wildcard nodes. +type level struct { + nodes map[string]*node + pwc, fwc *node +} + +// Create a new default node. +func newNode() *node { + return &node{elements: make([]interface{}, 0, 4)} +} + +// Create a new default level. 
+func newLevel() *level { + return &level{nodes: make(map[string]*node)} +} + +// NewSublist creates a default sublist +func NewSublist() *Sublist { + return &Sublist{root: newLevel(), cache: make(map[string][]interface{})} +} + +// Insert adds a subscription into the sublist +func (s *Sublist) Insert(subject string, element interface{}) error { + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + s.Lock() + + sfwc := false + l := s.root + var n *node + + for _, t := range tokens { + if len(t) == 0 || sfwc { + s.Unlock() + return ErrInvalidSubject + } + + switch t[0] { + case pwc: + n = l.pwc + case fwc: + n = l.fwc + sfwc = true + default: + n = l.nodes[t] + } + if n == nil { + n = newNode() + switch t[0] { + case pwc: + l.pwc = n + case fwc: + l.fwc = n + default: + l.nodes[t] = n + } + } + if n.next == nil { + n.next = newLevel() + } + l = n.next + } + n.elements = append(n.elements, element) + s.addToCache(subject, element) + s.count++ + s.Unlock() + return nil +} + +// addToCache will add the new entry to existing cache +// entries if needed. Assumes write lock is held. +func (s *Sublist) addToCache(subject string, element interface{}) { + for k, r := range s.cache { + if matchLiteral(k, subject) { + // Copy since others may have a reference. + nr := append([]interface{}(nil), r...) + nr = append(nr, element) + s.cache[k] = nr + } + } +} + +// removeFromCache will remove any active cache entries on that subject. +// Assumes write lock is held. +func (s *Sublist) removeFromCache(subject string) { + for k := range s.cache { + if !matchLiteral(k, subject) { + continue + } + // Since someone else may be referencing, can't modify the list + // safely, just let it re-populate. + delete(s.cache, k) + } +} + +// Match will match all entries to the literal subject. 
+// It will return a set of results. +func (s *Sublist) Match(subject string) []interface{} { + s.RLock() + rc, ok := s.cache[subject] + s.RUnlock() + if ok { + return rc + } + + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + result := make([]interface{}, 0, 4) + + s.Lock() + matchLevel(s.root, tokens, &result) + + // Add to our cache + s.cache[subject] = result + // Bound the number of entries to sublistMaxCache + if len(s.cache) > slCacheMax { + for k := range s.cache { + delete(s.cache, k) + break + } + } + s.Unlock() + + return result +} + +// matchLevel is used to recursively descend into the trie. +func matchLevel(l *level, toks []string, results *[]interface{}) { + var pwc, n *node + for i, t := range toks { + if l == nil { + return + } + if l.fwc != nil { + *results = append(*results, l.fwc.elements...) + } + if pwc = l.pwc; pwc != nil { + matchLevel(pwc.next, toks[i+1:], results) + } + n = l.nodes[t] + if n != nil { + l = n.next + } else { + l = nil + } + } + if pwc != nil { + *results = append(*results, pwc.elements...) + } + if n != nil { + *results = append(*results, n.elements...) + } +} + +// lnt is used to track descent into levels for a removal for pruning. +type lnt struct { + l *level + n *node + t string +} + +// Remove will remove an element from the sublist. 
+func (s *Sublist) Remove(subject string, element interface{}) error { + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + s.Lock() + defer s.Unlock() + + sfwc := false + l := s.root + var n *node + + // Track levels for pruning + var lnts [32]lnt + levels := lnts[:0] + + for _, t := range tokens { + if len(t) == 0 || sfwc { + return ErrInvalidSubject + } + if l == nil { + return ErrNotFound + } + switch t[0] { + case pwc: + n = l.pwc + case fwc: + n = l.fwc + sfwc = true + default: + n = l.nodes[t] + } + if n != nil { + levels = append(levels, lnt{l, n, t}) + l = n.next + } else { + l = nil + } + } + if !s.removeFromNode(n, element) { + return ErrNotFound + } + s.count-- + for i := len(levels) - 1; i >= 0; i-- { + l, n, t := levels[i].l, levels[i].n, levels[i].t + if n.isEmpty() { + l.pruneNode(n, t) + } + } + s.removeFromCache(subject) + return nil +} + +// pruneNode is used to prune an empty node from the tree. +func (l *level) pruneNode(n *node, t string) { + if n == nil { + return + } + if n == l.fwc { + l.fwc = nil + } else if n == l.pwc { + l.pwc = nil + } else { + delete(l.nodes, t) + } +} + +// isEmpty will test if the node has any entries. Used +// in pruning. +func (n *node) isEmpty() bool { + if len(n.elements) == 0 { + if n.next == nil || n.next.numNodes() == 0 { + return true + } + } + return false +} + +// Return the number of nodes for the given level. +func (l *level) numNodes() int { + num := len(l.nodes) + if l.pwc != nil { + num++ + } + if l.fwc != nil { + num++ + } + return num +} + +// Removes an element from a list. 
+func removeFromList(element interface{}, l []interface{}) ([]interface{}, bool) { + for i := 0; i < len(l); i++ { + if l[i] == element { + last := len(l) - 1 + l[i] = l[last] + l[last] = nil + l = l[:last] + return shrinkAsNeeded(l), true + } + } + return l, false +} + +// Remove the sub for the given node. +func (s *Sublist) removeFromNode(n *node, element interface{}) (found bool) { + if n == nil { + return false + } + n.elements, found = removeFromList(element, n.elements) + return found +} + +// Checks if we need to do a resize. This is for very large growth then +// subsequent return to a more normal size from unsubscribe. +func shrinkAsNeeded(l []interface{}) []interface{} { + ll := len(l) + cl := cap(l) + // Don't bother if list not too big + if cl <= 8 { + return l + } + pFree := float32(cl-ll) / float32(cl) + if pFree > 0.50 { + return append([]interface{}(nil), l...) + } + return l +} + +// Count returns the number of subscriptions. +func (s *Sublist) Count() uint32 { + s.RLock() + defer s.RUnlock() + return s.count +} + +// CacheCount returns the number of result sets in the cache. +func (s *Sublist) CacheCount() int { + s.RLock() + defer s.RUnlock() + return len(s.cache) +} + +// matchLiteral is used to test literal subjects, those that do not have any +// wildcards, with a target subject. This is used in the cache layer. +func matchLiteral(literal, subject string) bool { + li := 0 + ll := len(literal) + for i := 0; i < len(subject); i++ { + if li >= ll { + return false + } + b := subject[i] + switch b { + case pwc: + // Skip token in literal + ll := len(literal) + for { + if li >= ll || literal[li] == btsep { + li-- + break + } + li++ + } + case fwc: + return true + default: + if b != literal[li] { + return false + } + } + li++ + } + // Make sure we have processed all of the literal's chars.. + return li >= ll +} + +// NumLevels returns the maximum number of levels in the sublist. 
+func (s *Sublist) NumLevels() int { + return visitLevel(s.root, 0) +} + +// visitLevel is used to descend the Sublist tree structure +// recursively. +func visitLevel(l *level, depth int) int { + if l == nil || l.numNodes() == 0 { + return depth + } + + depth++ + maxDepth := depth + + for _, n := range l.nodes { + if n == nil { + continue + } + newDepth := visitLevel(n.next, depth) + if newDepth > maxDepth { + maxDepth = newDepth + } + } + if l.pwc != nil { + pwcDepth := visitLevel(l.pwc.next, depth) + if pwcDepth > maxDepth { + maxDepth = pwcDepth + } + } + if l.fwc != nil { + fwcDepth := visitLevel(l.fwc.next, depth) + if fwcDepth > maxDepth { + maxDepth = fwcDepth + } + } + return maxDepth +} + +// Subjects returns an array of all subjects in this sublist +// ordered from the widest to the narrowest of subjects. +// Order between non wildcard tokens in a given level is +// random though. +// +// For instance, if the sublist contains (in any inserted order): +// +// *.*, foo.>, *.>, foo.*.>, >, bar.>, foo.bar.>, bar.baz +// +// the returned array will be one of the two possibilities: +// +// >, *.>, *.*, foo.>, foo.*.>, foo.bar.>, bar.>, bar.baz +// +// or +// +// >, *.>, *.*, bar.>, bar.baz, foo.>, foo.*.>, foo.bar.> +// +// For a given level, the order will still always be from +// wider to narrower, that is, foo.> comes before foo.*.> +// which comes before foo.bar.>, and bar.> always comes +// before bar.baz, but all the "bar" subjects may be +// before or after all the "foo" subjects. 
+func (s *Sublist) Subjects() []string { + s.RLock() + defer s.RUnlock() + subjects := make([]string, 0, s.count) + getSubjects(s.root, "", &subjects) + return subjects +} + +func getSubjects(l *level, subject string, res *[]string) { + if l == nil || l.numNodes() == 0 { + *res = append(*res, subject) + return + } + var fs string + if l.fwc != nil { + if subject != "" { + fs = subject + tsep + sfwc + } else { + fs = sfwc + } + getSubjects(l.fwc.next, fs, res) + } + if l.pwc != nil { + if subject != "" { + fs = subject + tsep + spwc + } else { + fs = spwc + } + getSubjects(l.pwc.next, fs, res) + } + for s, n := range l.nodes { + if subject != "" { + fs = subject + tsep + s + } else { + fs = s + } + getSubjects(n.next, fs, res) + } +} diff --git a/vendor/github.com/nats-io/nats-streaming-server/util/util.go b/vendor/github.com/nats-io/nats-streaming-server/util/util.go new file mode 100644 index 00000000000..bc9b0d1d347 --- /dev/null +++ b/vendor/github.com/nats-io/nats-streaming-server/util/util.go @@ -0,0 +1,217 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "time" +) + +// ErrUnableToLockNow is used to indicate that a lock cannot be +// immediately acquired. +var ErrUnableToLockNow = errors.New("unable to acquire the lock at the moment") + +// LockFile is an interface for lock files utility. 
+type LockFile interface { + io.Closer + IsClosed() bool +} + +// ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit +// unsigned integers. +var ByteOrder binary.ByteOrder + +func init() { + ByteOrder = binary.LittleEndian +} + +// BackoffTimeCheck allows to execute some code, but not too often. +type BackoffTimeCheck struct { + nextTime time.Time + frequency time.Duration + minFrequency time.Duration + maxFrequency time.Duration + factor int +} + +// NewBackoffTimeCheck creates an instance of BackoffTimeCheck. +// The `minFrequency` indicates how frequently BackoffTimeCheck.Ok() can return true. +// When Ok() returns true, the allowed frequency is multiplied by `factor`. The +// resulting frequency is capped by `maxFrequency`. +func NewBackoffTimeCheck(minFrequency time.Duration, factor int, maxFrequency time.Duration) (*BackoffTimeCheck, error) { + if minFrequency <= 0 || factor < 1 || maxFrequency < minFrequency { + return nil, fmt.Errorf("minFrequency must be positive, factor at least 1 and maxFrequency at least equal to minFrequency, got %v - %v - %v", + minFrequency, factor, maxFrequency) + } + return &BackoffTimeCheck{ + frequency: minFrequency, + minFrequency: minFrequency, + maxFrequency: maxFrequency, + factor: factor, + }, nil +} + +// Ok returns true for the first time it is invoked after creation of the object +// or call to Reset(), or after an amount of time (based on the last success +// and the allowed frequency) has elapsed. +// When at the maximum frequency, if this call is made after a delay at least +// equal to 3x the max frequency (or in other words, 2x after what was the target +// for the next print), then the object is auto-reset. 
+func (bp *BackoffTimeCheck) Ok() bool { + if bp.nextTime.IsZero() { + bp.nextTime = time.Now().Add(bp.minFrequency) + return true + } + now := time.Now() + if now.Before(bp.nextTime) { + return false + } + // If we are already at the max frequency and this call + // is made after 2x the max frequency, then auto-reset. + if bp.frequency == bp.maxFrequency && + now.Sub(bp.nextTime) >= 2*bp.maxFrequency { + bp.Reset() + return true + } + if bp.frequency < bp.maxFrequency { + bp.frequency *= time.Duration(bp.factor) + if bp.frequency > bp.maxFrequency { + bp.frequency = bp.maxFrequency + } + } + bp.nextTime = now.Add(bp.frequency) + return true +} + +// Reset the state so that next call to BackoffPrint.Ok() will return true. +func (bp *BackoffTimeCheck) Reset() { + bp.nextTime = time.Time{} + bp.frequency = bp.minFrequency +} + +// EnsureBufBigEnough checks that given buffer is big enough to hold 'needed' +// bytes, otherwise returns a buffer of a size of at least 'needed' bytes. +func EnsureBufBigEnough(buf []byte, needed int) []byte { + if buf == nil { + return make([]byte, needed) + } else if needed > len(buf) { + return make([]byte, int(float32(needed)*1.1)) + } + return buf +} + +// WriteInt writes an int (4 bytes) to the given writer using ByteOrder. +func WriteInt(w io.Writer, v int) error { + var b [4]byte + + bs := b[:4] + + ByteOrder.PutUint32(bs, uint32(v)) + _, err := w.Write(bs) + return err +} + +// ReadInt reads an int (4 bytes) from the reader using ByteOrder. +func ReadInt(r io.Reader) (int, error) { + var b [4]byte + + bs := b[:4] + + _, err := io.ReadFull(r, bs) + if err != nil { + return 0, err + } + return int(ByteOrder.Uint32(bs)), nil +} + +// CloseFile closes the given file and report the possible error only +// if the given error `err` is not already set. 
+func CloseFile(err error, f io.Closer) error { + if lerr := f.Close(); lerr != nil && err == nil { + err = lerr + } + return err +} + +// IsChannelNameValid returns false if any of these conditions for +// the channel name apply: +// - is empty +// - contains the `/` character +// - token separator `.` is first or last +// - there are two consecutives token separators `.` +// if wildcardsAllowed is false: +// - contains wildcards `*` or `>` +// if wildcardsAllowed is true: +// - '*' or '>' are not a token in their own +// - `>` is not the last token +func IsChannelNameValid(channel string, wildcardsAllowed bool) bool { + if channel == "" || channel[0] == btsep { + return false + } + for i := 0; i < len(channel); i++ { + c := channel[i] + if c == '/' { + return false + } + if (c == btsep) && (i == len(channel)-1 || channel[i+1] == btsep) { + return false + } + if !wildcardsAllowed { + if c == pwc || c == fwc { + return false + } + } else if c == pwc || c == fwc { + if i > 0 && channel[i-1] != btsep { + return false + } + if c == fwc && i != len(channel)-1 { + return false + } + if i < len(channel)-1 && channel[i+1] != btsep { + return false + } + } + } + return true +} + +// IsChannelNameLiteral returns true if the channel name is a literal (that is, +// it does not contain any wildcard). +// The channel name is assumed to be valid. +func IsChannelNameLiteral(channel string) bool { + for i := 0; i < len(channel); i++ { + if channel[i] == pwc || channel[i] == fwc { + return false + } + } + return true +} + +// FriendlyBytes returns a string with the given bytes int64 +// represented as a size, such as 1KB, 10MB, etc... 
+func FriendlyBytes(bytes int64) string { + fbytes := float64(bytes) + base := 1024 + pre := []string{"K", "M", "G", "T", "P", "E"} + if fbytes < float64(base) { + return fmt.Sprintf("%v B", fbytes) + } + exp := int(math.Log(fbytes) / math.Log(float64(base))) + index := exp - 1 + return fmt.Sprintf("%.2f %sB", fbytes/math.Pow(float64(base), float64(exp)), pre[index]) +} diff --git a/vendor/github.com/nats-io/nuid/LICENSE b/vendor/github.com/nats-io/nuid/LICENSE new file mode 100644 index 00000000000..cadc3a496c8 --- /dev/null +++ b/vendor/github.com/nats-io/nuid/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2012-2016 Apcera Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/nats-io/nuid/nuid.go b/vendor/github.com/nats-io/nuid/nuid.go new file mode 100644 index 00000000000..1fda3770761 --- /dev/null +++ b/vendor/github.com/nats-io/nuid/nuid.go @@ -0,0 +1,124 @@ +// Copyright 2016 Apcera Inc. All rights reserved. 
+ +// A unique identifier generator that is high performance, very fast, and tries to be entropy pool friendly. +package nuid + +import ( + "crypto/rand" + "fmt" + "math" + "math/big" + "sync" + "time" + + prand "math/rand" +) + +// NUID needs to be very fast to generate and truly unique, all while being entropy pool friendly. +// We will use 12 bytes of crypto generated data (entropy draining), and 10 bytes of sequential data +// that is started at a pseudo random number and increments with a pseudo-random increment. +// Total is 22 bytes of base 62 ascii text :) + +// Version of the library +const Version = "1.0.0" + +const ( + digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + base = 62 + preLen = 12 + seqLen = 10 + maxSeq = int64(839299365868340224) // base^seqLen == 62^10 + minInc = int64(33) + maxInc = int64(333) + totalLen = preLen + seqLen +) + +type NUID struct { + pre []byte + seq int64 + inc int64 +} + +type lockedNUID struct { + sync.Mutex + *NUID +} + +// Global NUID +var globalNUID *lockedNUID + +// Seed sequential random with crypto or math/random and current time +// and generate crypto prefix. +func init() { + r, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + prand.Seed(time.Now().UnixNano()) + } else { + prand.Seed(r.Int64()) + } + globalNUID = &lockedNUID{NUID: New()} + globalNUID.RandomizePrefix() +} + +// New will generate a new NUID and properly initialize the prefix, sequential start, and sequential increment. +func New() *NUID { + n := &NUID{ + seq: prand.Int63n(maxSeq), + inc: minInc + prand.Int63n(maxInc-minInc), + pre: make([]byte, preLen), + } + n.RandomizePrefix() + return n +} + +// Generate the next NUID string from the global locked NUID instance. +func Next() string { + globalNUID.Lock() + nuid := globalNUID.Next() + globalNUID.Unlock() + return nuid +} + +// Generate the next NUID string. +func (n *NUID) Next() string { + // Increment and capture. 
+ n.seq += n.inc + if n.seq >= maxSeq { + n.RandomizePrefix() + n.resetSequential() + } + seq := n.seq + + // Copy prefix + var b [totalLen]byte + bs := b[:preLen] + copy(bs, n.pre) + + // copy in the seq in base36. + for i, l := len(b), seq; i > preLen; l /= base { + i -= 1 + b[i] = digits[l%base] + } + return string(b[:]) +} + +// Resets the sequential portion of the NUID. +func (n *NUID) resetSequential() { + n.seq = prand.Int63n(maxSeq) + n.inc = minInc + prand.Int63n(maxInc-minInc) +} + +// Generate a new prefix from crypto/rand. +// This call *can* drain entropy and will be called automatically when we exhaust the sequential range. +// Will panic if it gets an error from rand.Int() +func (n *NUID) RandomizePrefix() { + var cb [preLen]byte + cbs := cb[:] + if nb, err := rand.Read(cbs); nb != preLen || err != nil { + panic(fmt.Sprintf("nuid: failed generating crypto random number: %v\n", err)) + } + + for i := 0; i < preLen; i++ { + n.pre[i] = digits[int(cbs[i])%base] + } +} diff --git a/vendor/github.com/pmezard/go-difflib/LICENSE b/vendor/github.com/pmezard/go-difflib/LICENSE new file mode 100644 index 00000000000..c67dad612a3 --- /dev/null +++ b/vendor/github.com/pmezard/go-difflib/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013, Patrick Mezard +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + The names of its contributors may not be used to endorse or promote +products derived from this software without specific prior written +permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/pmezard/go-difflib/difflib/difflib.go b/vendor/github.com/pmezard/go-difflib/difflib/difflib.go new file mode 100644 index 00000000000..003e99fadb4 --- /dev/null +++ b/vendor/github.com/pmezard/go-difflib/difflib/difflib.go @@ -0,0 +1,772 @@ +// Package difflib is a partial port of Python difflib module. +// +// It provides tools to compare sequences of strings and generate textual diffs. +// +// The following class and functions have been ported: +// +// - SequenceMatcher +// +// - unified_diff +// +// - context_diff +// +// Getting unified diffs was the main goal of the port. Keep in mind this code +// is mostly suitable to output text differences in a human friendly way, there +// are no guarantees generated diffs are consumable by patch(1). 
+package difflib + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" +) + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func calculateRatio(matches, length int) float64 { + if length > 0 { + return 2.0 * float64(matches) / float64(length) + } + return 1.0 +} + +type Match struct { + A int + B int + Size int +} + +type OpCode struct { + Tag byte + I1 int + I2 int + J1 int + J2 int +} + +// SequenceMatcher compares sequence of strings. The basic +// algorithm predates, and is a little fancier than, an algorithm +// published in the late 1980's by Ratcliff and Obershelp under the +// hyperbolic name "gestalt pattern matching". The basic idea is to find +// the longest contiguous matching subsequence that contains no "junk" +// elements (R-O doesn't address junk). The same idea is then applied +// recursively to the pieces of the sequences to the left and to the right +// of the matching subsequence. This does not yield minimal edit +// sequences, but does tend to yield matches that "look right" to people. +// +// SequenceMatcher tries to compute a "human-friendly diff" between two +// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the +// longest *contiguous* & junk-free matching subsequence. That's what +// catches peoples' eyes. The Windows(tm) windiff has another interesting +// notion, pairing up elements that appear uniquely in each sequence. +// That, and the method here, appear to yield more intuitive difference +// reports than does diff. This method appears to be the least vulnerable +// to synching up on blocks of "junk lines", though (like blank lines in +// ordinary text files, or maybe "

" lines in HTML files). That may be +// because this is the only method of the 3 that has a *concept* of +// "junk" . +// +// Timing: Basic R-O is cubic time worst case and quadratic time expected +// case. SequenceMatcher is quadratic time for the worst case and has +// expected-case behavior dependent in a complicated way on how many +// elements the sequences have in common; best case time is linear. +type SequenceMatcher struct { + a []string + b []string + b2j map[string][]int + IsJunk func(string) bool + autoJunk bool + bJunk map[string]struct{} + matchingBlocks []Match + fullBCount map[string]int + bPopular map[string]struct{} + opCodes []OpCode +} + +func NewMatcher(a, b []string) *SequenceMatcher { + m := SequenceMatcher{autoJunk: true} + m.SetSeqs(a, b) + return &m +} + +func NewMatcherWithJunk(a, b []string, autoJunk bool, + isJunk func(string) bool) *SequenceMatcher { + + m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk} + m.SetSeqs(a, b) + return &m +} + +// Set two sequences to be compared. +func (m *SequenceMatcher) SetSeqs(a, b []string) { + m.SetSeq1(a) + m.SetSeq2(b) +} + +// Set the first sequence to be compared. The second sequence to be compared is +// not changed. +// +// SequenceMatcher computes and caches detailed information about the second +// sequence, so if you want to compare one sequence S against many sequences, +// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other +// sequences. +// +// See also SetSeqs() and SetSeq2(). +func (m *SequenceMatcher) SetSeq1(a []string) { + if &a == &m.a { + return + } + m.a = a + m.matchingBlocks = nil + m.opCodes = nil +} + +// Set the second sequence to be compared. The first sequence to be compared is +// not changed. 
+func (m *SequenceMatcher) SetSeq2(b []string) { + if &b == &m.b { + return + } + m.b = b + m.matchingBlocks = nil + m.opCodes = nil + m.fullBCount = nil + m.chainB() +} + +func (m *SequenceMatcher) chainB() { + // Populate line -> index mapping + b2j := map[string][]int{} + for i, s := range m.b { + indices := b2j[s] + indices = append(indices, i) + b2j[s] = indices + } + + // Purge junk elements + m.bJunk = map[string]struct{}{} + if m.IsJunk != nil { + junk := m.bJunk + for s, _ := range b2j { + if m.IsJunk(s) { + junk[s] = struct{}{} + } + } + for s, _ := range junk { + delete(b2j, s) + } + } + + // Purge remaining popular elements + popular := map[string]struct{}{} + n := len(m.b) + if m.autoJunk && n >= 200 { + ntest := n/100 + 1 + for s, indices := range b2j { + if len(indices) > ntest { + popular[s] = struct{}{} + } + } + for s, _ := range popular { + delete(b2j, s) + } + } + m.bPopular = popular + m.b2j = b2j +} + +func (m *SequenceMatcher) isBJunk(s string) bool { + _, ok := m.bJunk[s] + return ok +} + +// Find longest matching block in a[alo:ahi] and b[blo:bhi]. +// +// If IsJunk is not defined: +// +// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where +// alo <= i <= i+k <= ahi +// blo <= j <= j+k <= bhi +// and for all (i',j',k') meeting those conditions, +// k >= k' +// i <= i' +// and if i == i', j <= j' +// +// In other words, of all maximal matching blocks, return one that +// starts earliest in a, and of all those maximal matching blocks that +// start earliest in a, return the one that starts earliest in b. +// +// If IsJunk is defined, first the longest matching block is +// determined as above, but with the additional restriction that no +// junk element appears in the block. Then that block is extended as +// far as possible by matching (only) junk elements on both sides. So +// the resulting block never matches on junk except as identical junk +// happens to be adjacent to an "interesting" match. 
+// +// If no blocks match, return (alo, blo, 0). +func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match { + // CAUTION: stripping common prefix or suffix would be incorrect. + // E.g., + // ab + // acab + // Longest matching block is "ab", but if common prefix is + // stripped, it's "a" (tied with "b"). UNIX(tm) diff does so + // strip, so ends up claiming that ab is changed to acab by + // inserting "ca" in the middle. That's minimal but unintuitive: + // "it's obvious" that someone inserted "ac" at the front. + // Windiff ends up at the same place as diff, but by pairing up + // the unique 'b's and then matching the first two 'a's. + besti, bestj, bestsize := alo, blo, 0 + + // find longest junk-free match + // during an iteration of the loop, j2len[j] = length of longest + // junk-free match ending with a[i-1] and b[j] + j2len := map[int]int{} + for i := alo; i != ahi; i++ { + // look at all instances of a[i] in b; note that because + // b2j has no junk keys, the loop is skipped if a[i] is junk + newj2len := map[int]int{} + for _, j := range m.b2j[m.a[i]] { + // a[i] matches b[j] + if j < blo { + continue + } + if j >= bhi { + break + } + k := j2len[j-1] + 1 + newj2len[j] = k + if k > bestsize { + besti, bestj, bestsize = i-k+1, j-k+1, k + } + } + j2len = newj2len + } + + // Extend the best by non-junk elements on each end. In particular, + // "popular" non-junk elements aren't in b2j, which greatly speeds + // the inner loop above, but also means "the best" match so far + // doesn't contain any junk *or* popular non-junk elements. 
+ for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + !m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize += 1 + } + + // Now that we have a wholly interesting match (albeit possibly + // empty!), we may as well suck up the matching junk on each + // side of it too. Can't think of a good reason not to, and it + // saves post-processing the (possibly considerable) expense of + // figuring out what to do with it. In the case of an empty + // interesting match, this is clearly the right thing to do, + // because no other kind of match is possible in the regions. + for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize += 1 + } + + return Match{A: besti, B: bestj, Size: bestsize} +} + +// Return list of triples describing matching subsequences. +// +// Each triple is of the form (i, j, n), and means that +// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in +// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are +// adjacent triples in the list, and the second is not the last triple in the +// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe +// adjacent equal blocks. +// +// The last triple is a dummy, (len(a), len(b), 0), and is the only +// triple with n==0. 
+func (m *SequenceMatcher) GetMatchingBlocks() []Match { + if m.matchingBlocks != nil { + return m.matchingBlocks + } + + var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match + matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match { + match := m.findLongestMatch(alo, ahi, blo, bhi) + i, j, k := match.A, match.B, match.Size + if match.Size > 0 { + if alo < i && blo < j { + matched = matchBlocks(alo, i, blo, j, matched) + } + matched = append(matched, match) + if i+k < ahi && j+k < bhi { + matched = matchBlocks(i+k, ahi, j+k, bhi, matched) + } + } + return matched + } + matched := matchBlocks(0, len(m.a), 0, len(m.b), nil) + + // It's possible that we have adjacent equal blocks in the + // matching_blocks list now. + nonAdjacent := []Match{} + i1, j1, k1 := 0, 0, 0 + for _, b := range matched { + // Is this block adjacent to i1, j1, k1? + i2, j2, k2 := b.A, b.B, b.Size + if i1+k1 == i2 && j1+k1 == j2 { + // Yes, so collapse them -- this just increases the length of + // the first block by the length of the second, and the first + // block so lengthened remains the block to compare against. + k1 += k2 + } else { + // Not adjacent. Remember the first block (k1==0 means it's + // the dummy we started with), and make the second block the + // new block to compare against. + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + i1, j1, k1 = i2, j2, k2 + } + } + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + + nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0}) + m.matchingBlocks = nonAdjacent + return m.matchingBlocks +} + +// Return list of 5-tuples describing how to turn a into b. +// +// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple +// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the +// tuple preceding it, and likewise for j1 == the previous j2. 
+// +// The tags are characters, with these meanings: +// +// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2] +// +// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case. +// +// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case. +// +// 'e' (equal): a[i1:i2] == b[j1:j2] +func (m *SequenceMatcher) GetOpCodes() []OpCode { + if m.opCodes != nil { + return m.opCodes + } + i, j := 0, 0 + matching := m.GetMatchingBlocks() + opCodes := make([]OpCode, 0, len(matching)) + for _, m := range matching { + // invariant: we've pumped out correct diffs to change + // a[:i] into b[:j], and the next matching block is + // a[ai:ai+size] == b[bj:bj+size]. So we need to pump + // out a diff to change a[i:ai] into b[j:bj], pump out + // the matching block, and move (i,j) beyond the match + ai, bj, size := m.A, m.B, m.Size + tag := byte(0) + if i < ai && j < bj { + tag = 'r' + } else if i < ai { + tag = 'd' + } else if j < bj { + tag = 'i' + } + if tag > 0 { + opCodes = append(opCodes, OpCode{tag, i, ai, j, bj}) + } + i, j = ai+size, bj+size + // the list of matching blocks is terminated by a + // sentinel with size 0 + if size > 0 { + opCodes = append(opCodes, OpCode{'e', ai, i, bj, j}) + } + } + m.opCodes = opCodes + return m.opCodes +} + +// Isolate change clusters by eliminating ranges with no changes. +// +// Return a generator of groups with up to n lines of context. +// Each group is in the same format as returned by GetOpCodes(). +func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode { + if n < 0 { + n = 3 + } + codes := m.GetOpCodes() + if len(codes) == 0 { + codes = []OpCode{OpCode{'e', 0, 1, 0, 1}} + } + // Fixup leading and trailing groups if they show no changes. 
+ if codes[0].Tag == 'e' { + c := codes[0] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2} + } + if codes[len(codes)-1].Tag == 'e' { + c := codes[len(codes)-1] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)} + } + nn := n + n + groups := [][]OpCode{} + group := []OpCode{} + for _, c := range codes { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + // End the current group and start a new one whenever + // there is a large range with no changes. + if c.Tag == 'e' && i2-i1 > nn { + group = append(group, OpCode{c.Tag, i1, min(i2, i1+n), + j1, min(j2, j1+n)}) + groups = append(groups, group) + group = []OpCode{} + i1, j1 = max(i1, i2-n), max(j1, j2-n) + } + group = append(group, OpCode{c.Tag, i1, i2, j1, j2}) + } + if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') { + groups = append(groups, group) + } + return groups +} + +// Return a measure of the sequences' similarity (float in [0,1]). +// +// Where T is the total number of elements in both sequences, and +// M is the number of matches, this is 2.0*M / T. +// Note that this is 1 if the sequences are identical, and 0 if +// they have nothing in common. +// +// .Ratio() is expensive to compute if you haven't already computed +// .GetMatchingBlocks() or .GetOpCodes(), in which case you may +// want to try .QuickRatio() or .RealQuickRation() first to get an +// upper bound. +func (m *SequenceMatcher) Ratio() float64 { + matches := 0 + for _, m := range m.GetMatchingBlocks() { + matches += m.Size + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// Return an upper bound on ratio() relatively quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute. 
+func (m *SequenceMatcher) QuickRatio() float64 { + // viewing a and b as multisets, set matches to the cardinality + // of their intersection; this counts the number of matches + // without regard to order, so is clearly an upper bound + if m.fullBCount == nil { + m.fullBCount = map[string]int{} + for _, s := range m.b { + m.fullBCount[s] = m.fullBCount[s] + 1 + } + } + + // avail[x] is the number of times x appears in 'b' less the + // number of times we've seen it in 'a' so far ... kinda + avail := map[string]int{} + matches := 0 + for _, s := range m.a { + n, ok := avail[s] + if !ok { + n = m.fullBCount[s] + } + avail[s] = n - 1 + if n > 0 { + matches += 1 + } + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// Return an upper bound on ratio() very quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute than either .Ratio() or .QuickRatio(). +func (m *SequenceMatcher) RealQuickRatio() float64 { + la, lb := len(m.a), len(m.b) + return calculateRatio(min(la, lb), la+lb) +} + +// Convert range to the "ed" format +func formatRangeUnified(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 1 { + return fmt.Sprintf("%d", beginning) + } + if length == 0 { + beginning -= 1 // empty ranges begin at line just before the range + } + return fmt.Sprintf("%d,%d", beginning, length) +} + +// Unified diff parameters +type UnifiedDiff struct { + A []string // First sequence lines + FromFile string // First file name + FromDate string // First file time + B []string // Second sequence lines + ToFile string // Second file name + ToDate string // Second file time + Eol string // Headers end of line, defaults to LF + Context int // Number of context lines +} + +// Compare two sequences of lines; generate the delta as a unified diff. 
+// +// Unified diffs are a compact way of showing line changes and a few +// lines of context. The number of context lines is set by 'n' which +// defaults to three. +// +// By default, the diff control lines (those with ---, +++, or @@) are +// created with a trailing newline. This is helpful so that inputs +// created from file.readlines() result in diffs that are suitable for +// file.writelines() since both the inputs and outputs have trailing +// newlines. +// +// For inputs that do not have trailing newlines, set the lineterm +// argument to "" so that the output will be uniformly newline free. +// +// The unidiff format normally has a header for filenames and modification +// times. Any or all of these may be specified using strings for +// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'. +// The modification times are normally expressed in the ISO 8601 format. +func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error { + buf := bufio.NewWriter(writer) + defer buf.Flush() + wf := func(format string, args ...interface{}) error { + _, err := buf.WriteString(fmt.Sprintf(format, args...)) + return err + } + ws := func(s string) error { + _, err := buf.WriteString(s) + return err + } + + if len(diff.Eol) == 0 { + diff.Eol = "\n" + } + + started := false + m := NewMatcher(diff.A, diff.B) + for _, g := range m.GetGroupedOpCodes(diff.Context) { + if !started { + started = true + fromDate := "" + if len(diff.FromDate) > 0 { + fromDate = "\t" + diff.FromDate + } + toDate := "" + if len(diff.ToDate) > 0 { + toDate = "\t" + diff.ToDate + } + if diff.FromFile != "" || diff.ToFile != "" { + err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol) + if err != nil { + return err + } + err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol) + if err != nil { + return err + } + } + } + first, last := g[0], g[len(g)-1] + range1 := formatRangeUnified(first.I1, last.I2) + range2 := formatRangeUnified(first.J1, last.J2) + if err := wf("@@ -%s +%s @@%s", range1, 
range2, diff.Eol); err != nil { + return err + } + for _, c := range g { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + if c.Tag == 'e' { + for _, line := range diff.A[i1:i2] { + if err := ws(" " + line); err != nil { + return err + } + } + continue + } + if c.Tag == 'r' || c.Tag == 'd' { + for _, line := range diff.A[i1:i2] { + if err := ws("-" + line); err != nil { + return err + } + } + } + if c.Tag == 'r' || c.Tag == 'i' { + for _, line := range diff.B[j1:j2] { + if err := ws("+" + line); err != nil { + return err + } + } + } + } + } + return nil +} + +// Like WriteUnifiedDiff but returns the diff a string. +func GetUnifiedDiffString(diff UnifiedDiff) (string, error) { + w := &bytes.Buffer{} + err := WriteUnifiedDiff(w, diff) + return string(w.Bytes()), err +} + +// Convert range to the "ed" format. +func formatRangeContext(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 0 { + beginning -= 1 // empty ranges begin at line just before the range + } + if length <= 1 { + return fmt.Sprintf("%d", beginning) + } + return fmt.Sprintf("%d,%d", beginning, beginning+length-1) +} + +type ContextDiff UnifiedDiff + +// Compare two sequences of lines; generate the delta as a context diff. +// +// Context diffs are a compact way of showing line changes and a few +// lines of context. The number of context lines is set by diff.Context +// which defaults to three. +// +// By default, the diff control lines (those with *** or ---) are +// created with a trailing newline. +// +// For inputs that do not have trailing newlines, set the diff.Eol +// argument to "" so that the output will be uniformly newline free. +// +// The context diff format normally has a header for filenames and +// modification times. Any or all of these may be specified using +// strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate. 
+// The modification times are normally expressed in the ISO 8601 format. +// If not specified, the strings default to blanks. +func WriteContextDiff(writer io.Writer, diff ContextDiff) error { + buf := bufio.NewWriter(writer) + defer buf.Flush() + var diffErr error + wf := func(format string, args ...interface{}) { + _, err := buf.WriteString(fmt.Sprintf(format, args...)) + if diffErr == nil && err != nil { + diffErr = err + } + } + ws := func(s string) { + _, err := buf.WriteString(s) + if diffErr == nil && err != nil { + diffErr = err + } + } + + if len(diff.Eol) == 0 { + diff.Eol = "\n" + } + + prefix := map[byte]string{ + 'i': "+ ", + 'd': "- ", + 'r': "! ", + 'e': " ", + } + + started := false + m := NewMatcher(diff.A, diff.B) + for _, g := range m.GetGroupedOpCodes(diff.Context) { + if !started { + started = true + fromDate := "" + if len(diff.FromDate) > 0 { + fromDate = "\t" + diff.FromDate + } + toDate := "" + if len(diff.ToDate) > 0 { + toDate = "\t" + diff.ToDate + } + if diff.FromFile != "" || diff.ToFile != "" { + wf("*** %s%s%s", diff.FromFile, fromDate, diff.Eol) + wf("--- %s%s%s", diff.ToFile, toDate, diff.Eol) + } + } + + first, last := g[0], g[len(g)-1] + ws("***************" + diff.Eol) + + range1 := formatRangeContext(first.I1, last.I2) + wf("*** %s ****%s", range1, diff.Eol) + for _, c := range g { + if c.Tag == 'r' || c.Tag == 'd' { + for _, cc := range g { + if cc.Tag == 'i' { + continue + } + for _, line := range diff.A[cc.I1:cc.I2] { + ws(prefix[cc.Tag] + line) + } + } + break + } + } + + range2 := formatRangeContext(first.J1, last.J2) + wf("--- %s ----%s", range2, diff.Eol) + for _, c := range g { + if c.Tag == 'r' || c.Tag == 'i' { + for _, cc := range g { + if cc.Tag == 'd' { + continue + } + for _, line := range diff.B[cc.J1:cc.J2] { + ws(prefix[cc.Tag] + line) + } + } + break + } + } + } + return diffErr +} + +// Like WriteContextDiff but returns the diff a string. 
+func GetContextDiffString(diff ContextDiff) (string, error) { + w := &bytes.Buffer{} + err := WriteContextDiff(w, diff) + return string(w.Bytes()), err +} + +// Split a string on "\n" while preserving them. The output can be used +// as input for UnifiedDiff and ContextDiff structures. +func SplitLines(s string) []string { + lines := strings.SplitAfter(s, "\n") + lines[len(lines)-1] += "\n" + return lines +} diff --git a/vendor/golang.org/x/crypto/bcrypt/base64.go b/vendor/golang.org/x/crypto/bcrypt/base64.go new file mode 100644 index 00000000000..fc311609081 --- /dev/null +++ b/vendor/golang.org/x/crypto/bcrypt/base64.go @@ -0,0 +1,35 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bcrypt + +import "encoding/base64" + +const alphabet = "./ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + +var bcEncoding = base64.NewEncoding(alphabet) + +func base64Encode(src []byte) []byte { + n := bcEncoding.EncodedLen(len(src)) + dst := make([]byte, n) + bcEncoding.Encode(dst, src) + for dst[n-1] == '=' { + n-- + } + return dst[:n] +} + +func base64Decode(src []byte) ([]byte, error) { + numOfEquals := 4 - (len(src) % 4) + for i := 0; i < numOfEquals; i++ { + src = append(src, '=') + } + + dst := make([]byte, bcEncoding.DecodedLen(len(src))) + n, err := bcEncoding.Decode(dst, src) + if err != nil { + return nil, err + } + return dst[:n], nil +} diff --git a/vendor/golang.org/x/crypto/bcrypt/bcrypt.go b/vendor/golang.org/x/crypto/bcrypt/bcrypt.go new file mode 100644 index 00000000000..aeb73f81a14 --- /dev/null +++ b/vendor/golang.org/x/crypto/bcrypt/bcrypt.go @@ -0,0 +1,295 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package bcrypt implements Provos and Mazières's bcrypt adaptive hashing +// algorithm. 
See http://www.usenix.org/event/usenix99/provos/provos.pdf +package bcrypt // import "golang.org/x/crypto/bcrypt" + +// The code is a port of Provos and Mazières's C implementation. +import ( + "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "io" + "strconv" + + "golang.org/x/crypto/blowfish" +) + +const ( + MinCost int = 4 // the minimum allowable cost as passed in to GenerateFromPassword + MaxCost int = 31 // the maximum allowable cost as passed in to GenerateFromPassword + DefaultCost int = 10 // the cost that will actually be set if a cost below MinCost is passed into GenerateFromPassword +) + +// The error returned from CompareHashAndPassword when a password and hash do +// not match. +var ErrMismatchedHashAndPassword = errors.New("crypto/bcrypt: hashedPassword is not the hash of the given password") + +// The error returned from CompareHashAndPassword when a hash is too short to +// be a bcrypt hash. +var ErrHashTooShort = errors.New("crypto/bcrypt: hashedSecret too short to be a bcrypted password") + +// The error returned from CompareHashAndPassword when a hash was created with +// a bcrypt algorithm newer than this implementation. 
+type HashVersionTooNewError byte
+
+func (hv HashVersionTooNewError) Error() string {
+	return fmt.Sprintf("crypto/bcrypt: bcrypt algorithm version '%c' requested is newer than current version '%c'", byte(hv), majorVersion)
+}
+
+// The error returned from CompareHashAndPassword when a hash starts with something other than '$'
+type InvalidHashPrefixError byte
+
+func (ih InvalidHashPrefixError) Error() string {
+	return fmt.Sprintf("crypto/bcrypt: bcrypt hashes must start with '$', but hashedSecret started with '%c'", byte(ih))
+}
+
+// The error returned when a cost is outside of the allowed (MinCost, MaxCost) range.
+type InvalidCostError int
+
+func (ic InvalidCostError) Error() string {
+	return fmt.Sprintf("crypto/bcrypt: cost %d is outside allowed range (%d,%d)", int(ic), int(MinCost), int(MaxCost))
+}
+
+const (
+	majorVersion       = '2'
+	minorVersion       = 'a'
+	maxSaltSize        = 16 // raw (unencoded) salt bytes
+	maxCryptedHashSize = 23 // only 23 of the 24 encrypted bytes are kept (C bcrypt compatibility, see bcrypt() below)
+	encodedSaltSize    = 22 // 16 raw bytes after bcrypt base64 encoding
+	encodedHashSize    = 31 // 23 raw bytes after bcrypt base64 encoding
+	minHashSize        = 59 // shortest parseable hash: "$2$" prefix variant + cost + '$' + salt + hash
+)
+
+// magicCipherData is an IV for the 64 Blowfish encryption calls in
+// bcrypt(). It's the string "OrpheanBeholderScryDoubt" in big-endian bytes.
+var magicCipherData = []byte{
+	0x4f, 0x72, 0x70, 0x68,
+	0x65, 0x61, 0x6e, 0x42,
+	0x65, 0x68, 0x6f, 0x6c,
+	0x64, 0x65, 0x72, 0x53,
+	0x63, 0x72, 0x79, 0x44,
+	0x6f, 0x75, 0x62, 0x74,
+}
+
+// hashed is the parsed, in-memory form of a bcrypt hash string.
+type hashed struct {
+	hash  []byte // bcrypt-base64-encoded hash (encodedHashSize bytes)
+	salt  []byte // bcrypt-base64-encoded salt (encodedSaltSize bytes)
+	cost  int    // allowed range is MinCost to MaxCost
+	major byte   // algorithm version, e.g. '2'
+	minor byte   // version revision, e.g. 'a'; 0 when absent
+}
+
+// GenerateFromPassword returns the bcrypt hash of the password at the given
+// cost. If the cost given is less than MinCost, the cost will be set to
+// DefaultCost, instead. Use CompareHashAndPassword, as defined in this package,
+// to compare the returned hashed password with its cleartext version.
+func GenerateFromPassword(password []byte, cost int) ([]byte, error) {
+	p, err := newFromPassword(password, cost)
+	if err != nil {
+		return nil, err
+	}
+	return p.Hash(), nil
+}
+
+// CompareHashAndPassword compares a bcrypt hashed password with its possible
+// plaintext equivalent. Returns nil on success, or an error on failure.
+func CompareHashAndPassword(hashedPassword, password []byte) error {
+	p, err := newFromHash(hashedPassword)
+	if err != nil {
+		return err
+	}
+
+	// Re-derive the hash from the candidate password using the stored
+	// cost and salt, then compare in constant time to avoid leaking
+	// how many leading bytes matched.
+	otherHash, err := bcrypt(password, p.cost, p.salt)
+	if err != nil {
+		return err
+	}
+
+	otherP := &hashed{otherHash, p.salt, p.cost, p.major, p.minor}
+	if subtle.ConstantTimeCompare(p.Hash(), otherP.Hash()) == 1 {
+		return nil
+	}
+
+	return ErrMismatchedHashAndPassword
+}
+
+// Cost returns the hashing cost used to create the given hashed
+// password. When, in the future, the hashing cost of a password system needs
+// to be increased in order to adjust for greater computational power, this
+// function allows one to establish which passwords need to be updated.
+func Cost(hashedPassword []byte) (int, error) {
+	p, err := newFromHash(hashedPassword)
+	if err != nil {
+		return 0, err
+	}
+	return p.cost, nil
+}
+
+// newFromPassword builds a hashed value from a plaintext password and cost,
+// generating a fresh random salt from crypto/rand.
+func newFromPassword(password []byte, cost int) (*hashed, error) {
+	if cost < MinCost {
+		cost = DefaultCost
+	}
+	p := new(hashed)
+	p.major = majorVersion
+	p.minor = minorVersion
+
+	err := checkCost(cost)
+	if err != nil {
+		return nil, err
+	}
+	p.cost = cost
+
+	unencodedSalt := make([]byte, maxSaltSize)
+	_, err = io.ReadFull(rand.Reader, unencodedSalt)
+	if err != nil {
+		return nil, err
+	}
+
+	p.salt = base64Encode(unencodedSalt)
+	hash, err := bcrypt(password, p.cost, p.salt)
+	if err != nil {
+		return nil, err
+	}
+	p.hash = hash
+	return p, err
+}
+
+// newFromHash parses an encoded bcrypt hash string
+// (e.g. "$2a$10$<22-byte salt><31-byte hash>") into its components.
+func newFromHash(hashedSecret []byte) (*hashed, error) {
+	if len(hashedSecret) < minHashSize {
+		return nil, ErrHashTooShort
+	}
+	p := new(hashed)
+	n, err := p.decodeVersion(hashedSecret)
+	if err != nil {
+		return nil, err
+	}
+	hashedSecret = hashedSecret[n:]
+	n, err = p.decodeCost(hashedSecret)
+	if err != nil {
+		return nil, err
+	}
+	hashedSecret = hashedSecret[n:]
+
+	// The "+2" is here because we'll have to append at most 2 '=' to the salt
+	// when base64 decoding it in expensiveBlowfishSetup().
+	p.salt = make([]byte, encodedSaltSize, encodedSaltSize+2)
+	copy(p.salt, hashedSecret[:encodedSaltSize])
+
+	hashedSecret = hashedSecret[encodedSaltSize:]
+	p.hash = make([]byte, len(hashedSecret))
+	copy(p.hash, hashedSecret)
+
+	return p, nil
+}
+
+// bcrypt runs the core algorithm: an expensive Blowfish key schedule over
+// (password, salt), then 64 ECB encryptions of the magic IV, returning the
+// bcrypt-base64 encoding of the result.
+func bcrypt(password []byte, cost int, salt []byte) ([]byte, error) {
+	cipherData := make([]byte, len(magicCipherData))
+	copy(cipherData, magicCipherData)
+
+	c, err := expensiveBlowfishSetup(password, uint32(cost), salt)
+	if err != nil {
+		return nil, err
+	}
+
+	// Encrypt each of the three 8-byte blocks 64 times, in place.
+	for i := 0; i < 24; i += 8 {
+		for j := 0; j < 64; j++ {
+			c.Encrypt(cipherData[i:i+8], cipherData[i:i+8])
+		}
+	}
+
+	// Bug compatibility with C bcrypt implementations. We only encode 23 of
+	// the 24 bytes encrypted.
+	hsh := base64Encode(cipherData[:maxCryptedHashSize])
+	return hsh, nil
+}
+
+// expensiveBlowfishSetup performs the cost-dependent "eksblowfish" key
+// schedule: 2^cost alternating key/salt expansions on top of the salted cipher.
+func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cipher, error) {
+	csalt, err := base64Decode(salt)
+	if err != nil {
+		return nil, err
+	}
+
+	// Bug compatibility with C bcrypt implementations. They use the trailing
+	// NULL in the key string during expansion.
+	// We copy the key to prevent changing the underlying array.
+	// The three-index slice forces append to allocate rather than alias key.
+	ckey := append(key[:len(key):len(key)], 0)
+
+	c, err := blowfish.NewSaltedCipher(ckey, csalt)
+	if err != nil {
+		return nil, err
+	}
+
+	// 1<<cost rounds; cost <= MaxCost so the shift fits in uint64.
+	var i, rounds uint64
+	rounds = 1 << cost
+	for i = 0; i < rounds; i++ {
+		blowfish.ExpandKey(ckey, c)
+		blowfish.ExpandKey(csalt, c)
+	}
+
+	return c, nil
+}
+
+// Hash re-serializes p into the standard modular-crypt form,
+// e.g. "$2a$10$<salt><hash>". 60 = "$2a$NN$" prefix (7) + 22 salt + 31 hash.
+func (p *hashed) Hash() []byte {
+	arr := make([]byte, 60)
+	arr[0] = '$'
+	arr[1] = p.major
+	n := 2
+	if p.minor != 0 {
+		arr[2] = p.minor
+		n = 3
+	}
+	arr[n] = '$'
+	n++
+	// Cost is always rendered as exactly two digits, zero-padded.
+	copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost)))
+	n += 2
+	arr[n] = '$'
+	n++
+	copy(arr[n:], p.salt)
+	n += encodedSaltSize
+	copy(arr[n:], p.hash)
+	n += encodedHashSize
+	return arr[:n]
+}
+
+// decodeVersion parses the leading "$2$" or "$2x$" marker and returns the
+// number of bytes consumed: 3 when no minor revision is present, 4 otherwise.
+// Callers have already checked len(sbytes) >= minHashSize.
+func (p *hashed) decodeVersion(sbytes []byte) (int, error) {
+	if sbytes[0] != '$' {
+		return -1, InvalidHashPrefixError(sbytes[0])
+	}
+	if sbytes[1] > majorVersion {
+		return -1, HashVersionTooNewError(sbytes[1])
+	}
+	p.major = sbytes[1]
+	n := 3
+	if sbytes[2] != '$' {
+		p.minor = sbytes[2]
+		n++
+	}
+	return n, nil
+}
+
+// sbytes should begin where decodeVersion left off.
+// Always consumes 3 bytes: the two cost digits plus the trailing '$'.
+func (p *hashed) decodeCost(sbytes []byte) (int, error) {
+	cost, err := strconv.Atoi(string(sbytes[0:2]))
+	if err != nil {
+		return -1, err
+	}
+	err = checkCost(cost)
+	if err != nil {
+		return -1, err
+	}
+	p.cost = cost
+	return 3, nil
+}
+
+// String renders the parsed fields for debugging (not the wire format).
+func (p *hashed) String() string {
+	return fmt.Sprintf("&{hash: %#v, salt: %#v, cost: %d, major: %c, minor: %c}", string(p.hash), p.salt, p.cost, p.major, p.minor)
+}
+
+// checkCost validates that cost lies in [MinCost, MaxCost].
+func checkCost(cost int) error {
+	if cost < MinCost || cost > MaxCost {
+		return InvalidCostError(cost)
+	}
+	return nil
+}
diff --git a/vendor/golang.org/x/crypto/blowfish/block.go b/vendor/golang.org/x/crypto/blowfish/block.go
new file mode 100644
index 00000000000..9d80f19521b
--- /dev/null
+++ b/vendor/golang.org/x/crypto/blowfish/block.go
@@ -0,0 +1,159 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package blowfish
+
+// getNextWord returns the next big-endian uint32 value from the byte slice
+// at the given position in a circular manner, updating the position.
+func getNextWord(b []byte, pos *int) uint32 {
+	var w uint32
+	j := *pos
+	for i := 0; i < 4; i++ {
+		w = w<<8 | uint32(b[j])
+		j++
+		if j >= len(b) {
+			j = 0
+		}
+	}
+	*pos = j
+	return w
+}
+
+// ExpandKey performs a key expansion on the given *Cipher. Specifically, it
+// performs the Blowfish algorithm's key schedule which sets up the *Cipher's
+// pi and substitution tables for calls to Encrypt. This is used, primarily,
+// by the bcrypt package to reuse the Blowfish key schedule during its
+// set up. It's unlikely that you need to use this directly.
+func ExpandKey(key []byte, c *Cipher) {
+	j := 0
+	for i := 0; i < 18; i++ {
+		// Using inlined getNextWord for performance.
+		var d uint32
+		for k := 0; k < 4; k++ {
+			d = d<<8 | uint32(key[j])
+			j++
+			if j >= len(key) {
+				j = 0
+			}
+		}
+		c.p[i] ^= d
+	}
+
+	// Replace the P-array and all four S-boxes with successive outputs of
+	// the cipher itself, starting from the all-zero block.
+	var l, r uint32
+	for i := 0; i < 18; i += 2 {
+		l, r = encryptBlock(l, r, c)
+		c.p[i], c.p[i+1] = l, r
+	}
+
+	for i := 0; i < 256; i += 2 {
+		l, r = encryptBlock(l, r, c)
+		c.s0[i], c.s0[i+1] = l, r
+	}
+	for i := 0; i < 256; i += 2 {
+		l, r = encryptBlock(l, r, c)
+		c.s1[i], c.s1[i+1] = l, r
+	}
+	for i := 0; i < 256; i += 2 {
+		l, r = encryptBlock(l, r, c)
+		c.s2[i], c.s2[i+1] = l, r
+	}
+	for i := 0; i < 256; i += 2 {
+		l, r = encryptBlock(l, r, c)
+		c.s3[i], c.s3[i+1] = l, r
+	}
+}
+
+// This is similar to ExpandKey, but folds the salt during the key
+// schedule. While ExpandKey is essentially expandKeyWithSalt with an all-zero
+// salt passed in, reusing ExpandKey turns out to be a place of inefficiency
+// and specializing it here is useful.
+func expandKeyWithSalt(key []byte, salt []byte, c *Cipher) {
+	j := 0
+	for i := 0; i < 18; i++ {
+		c.p[i] ^= getNextWord(key, &j)
+	}
+
+	// Unlike ExpandKey, each input block is XORed with the next words of
+	// the salt (read circularly) before being encrypted.
+	j = 0
+	var l, r uint32
+	for i := 0; i < 18; i += 2 {
+		l ^= getNextWord(salt, &j)
+		r ^= getNextWord(salt, &j)
+		l, r = encryptBlock(l, r, c)
+		c.p[i], c.p[i+1] = l, r
+	}
+
+	for i := 0; i < 256; i += 2 {
+		l ^= getNextWord(salt, &j)
+		r ^= getNextWord(salt, &j)
+		l, r = encryptBlock(l, r, c)
+		c.s0[i], c.s0[i+1] = l, r
+	}
+
+	for i := 0; i < 256; i += 2 {
+		l ^= getNextWord(salt, &j)
+		r ^= getNextWord(salt, &j)
+		l, r = encryptBlock(l, r, c)
+		c.s1[i], c.s1[i+1] = l, r
+	}
+
+	for i := 0; i < 256; i += 2 {
+		l ^= getNextWord(salt, &j)
+		r ^= getNextWord(salt, &j)
+		l, r = encryptBlock(l, r, c)
+		c.s2[i], c.s2[i+1] = l, r
+	}
+
+	for i := 0; i < 256; i += 2 {
+		l ^= getNextWord(salt, &j)
+		r ^= getNextWord(salt, &j)
+		l, r = encryptBlock(l, r, c)
+		c.s3[i], c.s3[i+1] = l, r
+	}
+}
+
+// encryptBlock runs the 16-round Blowfish Feistel network on the two
+// 32-bit halves (l, r), fully unrolled. The round function is
+// F(x) = ((s0[a] + s1[b]) ^ s2[c]) + s3[d] over the four bytes of x.
+// The halves are returned as (xr, xl) to undo the final Feistel swap.
+func encryptBlock(l, r uint32, c *Cipher) (uint32, uint32) {
+	xl, xr := l, r
+	xl ^= c.p[0]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[1]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[2]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[3]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[4]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[5]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[6]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[7]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[8]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[9]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[10]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[11]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[12]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[13]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[14]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[15]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[16]
+	xr ^= c.p[17]
+	return xr, xl
+}
+
+// decryptBlock is encryptBlock with the P-array subkeys applied in
+// reverse order (p[17] down to p[0]).
+func decryptBlock(l, r uint32, c *Cipher) (uint32, uint32) {
+	xl, xr := l, r
+	xl ^= c.p[17]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[16]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[15]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[14]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[13]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[12]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[11]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[10]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[9]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[8]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[7]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[6]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[5]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[4]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[3]
+	xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[2]
+	xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[1]
+	xr ^= c.p[0]
+	return xr, xl
+}
diff --git a/vendor/golang.org/x/crypto/blowfish/cipher.go b/vendor/golang.org/x/crypto/blowfish/cipher.go
new file mode 100644
index 00000000000..2641dadd649
--- /dev/null
+++ b/vendor/golang.org/x/crypto/blowfish/cipher.go
@@ -0,0 +1,91 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm.
+package blowfish // import "golang.org/x/crypto/blowfish"
+
+// The code is a port of Bruce Schneier's C implementation.
+// See https://www.schneier.com/blowfish.html.
+
+import "strconv"
+
+// The Blowfish block size in bytes.
+const BlockSize = 8
+
+// A Cipher is an instance of Blowfish encryption using a particular key.
+type Cipher struct {
+	p              [18]uint32
+	s0, s1, s2, s3 [256]uint32
+}
+
+// KeySizeError reports an invalid Blowfish key length.
+type KeySizeError int
+
+func (k KeySizeError) Error() string {
+	return "crypto/blowfish: invalid key size " + strconv.Itoa(int(k))
+}
+
+// NewCipher creates and returns a Cipher.
+// The key argument should be the Blowfish key, from 1 to 56 bytes.
+func NewCipher(key []byte) (*Cipher, error) {
+	var result Cipher
+	if k := len(key); k < 1 || k > 56 {
+		return nil, KeySizeError(k)
+	}
+	initCipher(&result)
+	ExpandKey(key, &result)
+	return &result, nil
+}
+
+// NewSaltedCipher creates and returns a Cipher that folds a salt into its key
+// schedule. For most purposes, NewCipher, instead of NewSaltedCipher, is
+// sufficient and desirable. For bcrypt compatibility, the key can be over 56
+// bytes.
+func NewSaltedCipher(key, salt []byte) (*Cipher, error) {
+	if len(salt) == 0 {
+		return NewCipher(key)
+	}
+	var result Cipher
+	// Note: only the lower bound is enforced here; keys longer than 56
+	// bytes are deliberately accepted for bcrypt (see doc comment above).
+	if k := len(key); k < 1 {
+		return nil, KeySizeError(k)
+	}
+	initCipher(&result)
+	expandKeyWithSalt(key, salt, &result)
+	return &result, nil
+}
+
+// BlockSize returns the Blowfish block size, 8 bytes.
+// It is necessary to satisfy the Block interface in the
+// package "crypto/cipher".
+func (c *Cipher) BlockSize() int { return BlockSize }
+
+// Encrypt encrypts the 8-byte buffer src using the key k
+// and stores the result in dst.
+// Note that for amounts of data larger than a block,
+// it is not safe to just call Encrypt on successive blocks;
+// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
+func (c *Cipher) Encrypt(dst, src []byte) {
+	// Load the two big-endian 32-bit halves of the block.
+	l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
+	r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])
+	l, r = encryptBlock(l, r, c)
+	dst[0], dst[1], dst[2], dst[3] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l)
+	dst[4], dst[5], dst[6], dst[7] = byte(r>>24), byte(r>>16), byte(r>>8), byte(r)
+}
+
+// Decrypt decrypts the 8-byte buffer src using the key k
+// and stores the result in dst.
+func (c *Cipher) Decrypt(dst, src []byte) {
+	l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3])
+	r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7])
+	l, r = decryptBlock(l, r, c)
+	dst[0], dst[1], dst[2], dst[3] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l)
+	dst[4], dst[5], dst[6], dst[7] = byte(r>>24), byte(r>>16), byte(r>>8), byte(r)
+}
+
+// initCipher resets c to the standard initial P-array and S-boxes
+// (the hexadecimal digits of pi, defined in const.go).
+func initCipher(c *Cipher) {
+	copy(c.p[0:], p[0:])
+	copy(c.s0[0:], s0[0:])
+	copy(c.s1[0:], s1[0:])
+	copy(c.s2[0:], s2[0:])
+	copy(c.s3[0:], s3[0:])
+}
diff --git a/vendor/golang.org/x/crypto/blowfish/const.go b/vendor/golang.org/x/crypto/blowfish/const.go
new file mode 100644
index 00000000000..d04077595ab
--- /dev/null
+++ b/vendor/golang.org/x/crypto/blowfish/const.go
@@ -0,0 +1,199 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The startup permutation array and substitution boxes.
+// They are the hexadecimal digits of PI; see:
+// https://www.schneier.com/code/constants.txt.
+ +package blowfish + +var s0 = [256]uint32{ + 0xd1310ba6, 0x98dfb5ac, 0x2ffd72db, 0xd01adfb7, 0xb8e1afed, 0x6a267e96, + 0xba7c9045, 0xf12c7f99, 0x24a19947, 0xb3916cf7, 0x0801f2e2, 0x858efc16, + 0x636920d8, 0x71574e69, 0xa458fea3, 0xf4933d7e, 0x0d95748f, 0x728eb658, + 0x718bcd58, 0x82154aee, 0x7b54a41d, 0xc25a59b5, 0x9c30d539, 0x2af26013, + 0xc5d1b023, 0x286085f0, 0xca417918, 0xb8db38ef, 0x8e79dcb0, 0x603a180e, + 0x6c9e0e8b, 0xb01e8a3e, 0xd71577c1, 0xbd314b27, 0x78af2fda, 0x55605c60, + 0xe65525f3, 0xaa55ab94, 0x57489862, 0x63e81440, 0x55ca396a, 0x2aab10b6, + 0xb4cc5c34, 0x1141e8ce, 0xa15486af, 0x7c72e993, 0xb3ee1411, 0x636fbc2a, + 0x2ba9c55d, 0x741831f6, 0xce5c3e16, 0x9b87931e, 0xafd6ba33, 0x6c24cf5c, + 0x7a325381, 0x28958677, 0x3b8f4898, 0x6b4bb9af, 0xc4bfe81b, 0x66282193, + 0x61d809cc, 0xfb21a991, 0x487cac60, 0x5dec8032, 0xef845d5d, 0xe98575b1, + 0xdc262302, 0xeb651b88, 0x23893e81, 0xd396acc5, 0x0f6d6ff3, 0x83f44239, + 0x2e0b4482, 0xa4842004, 0x69c8f04a, 0x9e1f9b5e, 0x21c66842, 0xf6e96c9a, + 0x670c9c61, 0xabd388f0, 0x6a51a0d2, 0xd8542f68, 0x960fa728, 0xab5133a3, + 0x6eef0b6c, 0x137a3be4, 0xba3bf050, 0x7efb2a98, 0xa1f1651d, 0x39af0176, + 0x66ca593e, 0x82430e88, 0x8cee8619, 0x456f9fb4, 0x7d84a5c3, 0x3b8b5ebe, + 0xe06f75d8, 0x85c12073, 0x401a449f, 0x56c16aa6, 0x4ed3aa62, 0x363f7706, + 0x1bfedf72, 0x429b023d, 0x37d0d724, 0xd00a1248, 0xdb0fead3, 0x49f1c09b, + 0x075372c9, 0x80991b7b, 0x25d479d8, 0xf6e8def7, 0xe3fe501a, 0xb6794c3b, + 0x976ce0bd, 0x04c006ba, 0xc1a94fb6, 0x409f60c4, 0x5e5c9ec2, 0x196a2463, + 0x68fb6faf, 0x3e6c53b5, 0x1339b2eb, 0x3b52ec6f, 0x6dfc511f, 0x9b30952c, + 0xcc814544, 0xaf5ebd09, 0xbee3d004, 0xde334afd, 0x660f2807, 0x192e4bb3, + 0xc0cba857, 0x45c8740f, 0xd20b5f39, 0xb9d3fbdb, 0x5579c0bd, 0x1a60320a, + 0xd6a100c6, 0x402c7279, 0x679f25fe, 0xfb1fa3cc, 0x8ea5e9f8, 0xdb3222f8, + 0x3c7516df, 0xfd616b15, 0x2f501ec8, 0xad0552ab, 0x323db5fa, 0xfd238760, + 0x53317b48, 0x3e00df82, 0x9e5c57bb, 0xca6f8ca0, 0x1a87562e, 0xdf1769db, + 0xd542a8f6, 0x287effc3, 
0xac6732c6, 0x8c4f5573, 0x695b27b0, 0xbbca58c8, + 0xe1ffa35d, 0xb8f011a0, 0x10fa3d98, 0xfd2183b8, 0x4afcb56c, 0x2dd1d35b, + 0x9a53e479, 0xb6f84565, 0xd28e49bc, 0x4bfb9790, 0xe1ddf2da, 0xa4cb7e33, + 0x62fb1341, 0xcee4c6e8, 0xef20cada, 0x36774c01, 0xd07e9efe, 0x2bf11fb4, + 0x95dbda4d, 0xae909198, 0xeaad8e71, 0x6b93d5a0, 0xd08ed1d0, 0xafc725e0, + 0x8e3c5b2f, 0x8e7594b7, 0x8ff6e2fb, 0xf2122b64, 0x8888b812, 0x900df01c, + 0x4fad5ea0, 0x688fc31c, 0xd1cff191, 0xb3a8c1ad, 0x2f2f2218, 0xbe0e1777, + 0xea752dfe, 0x8b021fa1, 0xe5a0cc0f, 0xb56f74e8, 0x18acf3d6, 0xce89e299, + 0xb4a84fe0, 0xfd13e0b7, 0x7cc43b81, 0xd2ada8d9, 0x165fa266, 0x80957705, + 0x93cc7314, 0x211a1477, 0xe6ad2065, 0x77b5fa86, 0xc75442f5, 0xfb9d35cf, + 0xebcdaf0c, 0x7b3e89a0, 0xd6411bd3, 0xae1e7e49, 0x00250e2d, 0x2071b35e, + 0x226800bb, 0x57b8e0af, 0x2464369b, 0xf009b91e, 0x5563911d, 0x59dfa6aa, + 0x78c14389, 0xd95a537f, 0x207d5ba2, 0x02e5b9c5, 0x83260376, 0x6295cfa9, + 0x11c81968, 0x4e734a41, 0xb3472dca, 0x7b14a94a, 0x1b510052, 0x9a532915, + 0xd60f573f, 0xbc9bc6e4, 0x2b60a476, 0x81e67400, 0x08ba6fb5, 0x571be91f, + 0xf296ec6b, 0x2a0dd915, 0xb6636521, 0xe7b9f9b6, 0xff34052e, 0xc5855664, + 0x53b02d5d, 0xa99f8fa1, 0x08ba4799, 0x6e85076a, +} + +var s1 = [256]uint32{ + 0x4b7a70e9, 0xb5b32944, 0xdb75092e, 0xc4192623, 0xad6ea6b0, 0x49a7df7d, + 0x9cee60b8, 0x8fedb266, 0xecaa8c71, 0x699a17ff, 0x5664526c, 0xc2b19ee1, + 0x193602a5, 0x75094c29, 0xa0591340, 0xe4183a3e, 0x3f54989a, 0x5b429d65, + 0x6b8fe4d6, 0x99f73fd6, 0xa1d29c07, 0xefe830f5, 0x4d2d38e6, 0xf0255dc1, + 0x4cdd2086, 0x8470eb26, 0x6382e9c6, 0x021ecc5e, 0x09686b3f, 0x3ebaefc9, + 0x3c971814, 0x6b6a70a1, 0x687f3584, 0x52a0e286, 0xb79c5305, 0xaa500737, + 0x3e07841c, 0x7fdeae5c, 0x8e7d44ec, 0x5716f2b8, 0xb03ada37, 0xf0500c0d, + 0xf01c1f04, 0x0200b3ff, 0xae0cf51a, 0x3cb574b2, 0x25837a58, 0xdc0921bd, + 0xd19113f9, 0x7ca92ff6, 0x94324773, 0x22f54701, 0x3ae5e581, 0x37c2dadc, + 0xc8b57634, 0x9af3dda7, 0xa9446146, 0x0fd0030e, 0xecc8c73e, 0xa4751e41, + 0xe238cd99, 
0x3bea0e2f, 0x3280bba1, 0x183eb331, 0x4e548b38, 0x4f6db908, + 0x6f420d03, 0xf60a04bf, 0x2cb81290, 0x24977c79, 0x5679b072, 0xbcaf89af, + 0xde9a771f, 0xd9930810, 0xb38bae12, 0xdccf3f2e, 0x5512721f, 0x2e6b7124, + 0x501adde6, 0x9f84cd87, 0x7a584718, 0x7408da17, 0xbc9f9abc, 0xe94b7d8c, + 0xec7aec3a, 0xdb851dfa, 0x63094366, 0xc464c3d2, 0xef1c1847, 0x3215d908, + 0xdd433b37, 0x24c2ba16, 0x12a14d43, 0x2a65c451, 0x50940002, 0x133ae4dd, + 0x71dff89e, 0x10314e55, 0x81ac77d6, 0x5f11199b, 0x043556f1, 0xd7a3c76b, + 0x3c11183b, 0x5924a509, 0xf28fe6ed, 0x97f1fbfa, 0x9ebabf2c, 0x1e153c6e, + 0x86e34570, 0xeae96fb1, 0x860e5e0a, 0x5a3e2ab3, 0x771fe71c, 0x4e3d06fa, + 0x2965dcb9, 0x99e71d0f, 0x803e89d6, 0x5266c825, 0x2e4cc978, 0x9c10b36a, + 0xc6150eba, 0x94e2ea78, 0xa5fc3c53, 0x1e0a2df4, 0xf2f74ea7, 0x361d2b3d, + 0x1939260f, 0x19c27960, 0x5223a708, 0xf71312b6, 0xebadfe6e, 0xeac31f66, + 0xe3bc4595, 0xa67bc883, 0xb17f37d1, 0x018cff28, 0xc332ddef, 0xbe6c5aa5, + 0x65582185, 0x68ab9802, 0xeecea50f, 0xdb2f953b, 0x2aef7dad, 0x5b6e2f84, + 0x1521b628, 0x29076170, 0xecdd4775, 0x619f1510, 0x13cca830, 0xeb61bd96, + 0x0334fe1e, 0xaa0363cf, 0xb5735c90, 0x4c70a239, 0xd59e9e0b, 0xcbaade14, + 0xeecc86bc, 0x60622ca7, 0x9cab5cab, 0xb2f3846e, 0x648b1eaf, 0x19bdf0ca, + 0xa02369b9, 0x655abb50, 0x40685a32, 0x3c2ab4b3, 0x319ee9d5, 0xc021b8f7, + 0x9b540b19, 0x875fa099, 0x95f7997e, 0x623d7da8, 0xf837889a, 0x97e32d77, + 0x11ed935f, 0x16681281, 0x0e358829, 0xc7e61fd6, 0x96dedfa1, 0x7858ba99, + 0x57f584a5, 0x1b227263, 0x9b83c3ff, 0x1ac24696, 0xcdb30aeb, 0x532e3054, + 0x8fd948e4, 0x6dbc3128, 0x58ebf2ef, 0x34c6ffea, 0xfe28ed61, 0xee7c3c73, + 0x5d4a14d9, 0xe864b7e3, 0x42105d14, 0x203e13e0, 0x45eee2b6, 0xa3aaabea, + 0xdb6c4f15, 0xfacb4fd0, 0xc742f442, 0xef6abbb5, 0x654f3b1d, 0x41cd2105, + 0xd81e799e, 0x86854dc7, 0xe44b476a, 0x3d816250, 0xcf62a1f2, 0x5b8d2646, + 0xfc8883a0, 0xc1c7b6a3, 0x7f1524c3, 0x69cb7492, 0x47848a0b, 0x5692b285, + 0x095bbf00, 0xad19489d, 0x1462b174, 0x23820e00, 0x58428d2a, 0x0c55f5ea, + 0x1dadf43e, 
0x233f7061, 0x3372f092, 0x8d937e41, 0xd65fecf1, 0x6c223bdb, + 0x7cde3759, 0xcbee7460, 0x4085f2a7, 0xce77326e, 0xa6078084, 0x19f8509e, + 0xe8efd855, 0x61d99735, 0xa969a7aa, 0xc50c06c2, 0x5a04abfc, 0x800bcadc, + 0x9e447a2e, 0xc3453484, 0xfdd56705, 0x0e1e9ec9, 0xdb73dbd3, 0x105588cd, + 0x675fda79, 0xe3674340, 0xc5c43465, 0x713e38d8, 0x3d28f89e, 0xf16dff20, + 0x153e21e7, 0x8fb03d4a, 0xe6e39f2b, 0xdb83adf7, +} + +var s2 = [256]uint32{ + 0xe93d5a68, 0x948140f7, 0xf64c261c, 0x94692934, 0x411520f7, 0x7602d4f7, + 0xbcf46b2e, 0xd4a20068, 0xd4082471, 0x3320f46a, 0x43b7d4b7, 0x500061af, + 0x1e39f62e, 0x97244546, 0x14214f74, 0xbf8b8840, 0x4d95fc1d, 0x96b591af, + 0x70f4ddd3, 0x66a02f45, 0xbfbc09ec, 0x03bd9785, 0x7fac6dd0, 0x31cb8504, + 0x96eb27b3, 0x55fd3941, 0xda2547e6, 0xabca0a9a, 0x28507825, 0x530429f4, + 0x0a2c86da, 0xe9b66dfb, 0x68dc1462, 0xd7486900, 0x680ec0a4, 0x27a18dee, + 0x4f3ffea2, 0xe887ad8c, 0xb58ce006, 0x7af4d6b6, 0xaace1e7c, 0xd3375fec, + 0xce78a399, 0x406b2a42, 0x20fe9e35, 0xd9f385b9, 0xee39d7ab, 0x3b124e8b, + 0x1dc9faf7, 0x4b6d1856, 0x26a36631, 0xeae397b2, 0x3a6efa74, 0xdd5b4332, + 0x6841e7f7, 0xca7820fb, 0xfb0af54e, 0xd8feb397, 0x454056ac, 0xba489527, + 0x55533a3a, 0x20838d87, 0xfe6ba9b7, 0xd096954b, 0x55a867bc, 0xa1159a58, + 0xcca92963, 0x99e1db33, 0xa62a4a56, 0x3f3125f9, 0x5ef47e1c, 0x9029317c, + 0xfdf8e802, 0x04272f70, 0x80bb155c, 0x05282ce3, 0x95c11548, 0xe4c66d22, + 0x48c1133f, 0xc70f86dc, 0x07f9c9ee, 0x41041f0f, 0x404779a4, 0x5d886e17, + 0x325f51eb, 0xd59bc0d1, 0xf2bcc18f, 0x41113564, 0x257b7834, 0x602a9c60, + 0xdff8e8a3, 0x1f636c1b, 0x0e12b4c2, 0x02e1329e, 0xaf664fd1, 0xcad18115, + 0x6b2395e0, 0x333e92e1, 0x3b240b62, 0xeebeb922, 0x85b2a20e, 0xe6ba0d99, + 0xde720c8c, 0x2da2f728, 0xd0127845, 0x95b794fd, 0x647d0862, 0xe7ccf5f0, + 0x5449a36f, 0x877d48fa, 0xc39dfd27, 0xf33e8d1e, 0x0a476341, 0x992eff74, + 0x3a6f6eab, 0xf4f8fd37, 0xa812dc60, 0xa1ebddf8, 0x991be14c, 0xdb6e6b0d, + 0xc67b5510, 0x6d672c37, 0x2765d43b, 0xdcd0e804, 0xf1290dc7, 0xcc00ffa3, + 
0xb5390f92, 0x690fed0b, 0x667b9ffb, 0xcedb7d9c, 0xa091cf0b, 0xd9155ea3, + 0xbb132f88, 0x515bad24, 0x7b9479bf, 0x763bd6eb, 0x37392eb3, 0xcc115979, + 0x8026e297, 0xf42e312d, 0x6842ada7, 0xc66a2b3b, 0x12754ccc, 0x782ef11c, + 0x6a124237, 0xb79251e7, 0x06a1bbe6, 0x4bfb6350, 0x1a6b1018, 0x11caedfa, + 0x3d25bdd8, 0xe2e1c3c9, 0x44421659, 0x0a121386, 0xd90cec6e, 0xd5abea2a, + 0x64af674e, 0xda86a85f, 0xbebfe988, 0x64e4c3fe, 0x9dbc8057, 0xf0f7c086, + 0x60787bf8, 0x6003604d, 0xd1fd8346, 0xf6381fb0, 0x7745ae04, 0xd736fccc, + 0x83426b33, 0xf01eab71, 0xb0804187, 0x3c005e5f, 0x77a057be, 0xbde8ae24, + 0x55464299, 0xbf582e61, 0x4e58f48f, 0xf2ddfda2, 0xf474ef38, 0x8789bdc2, + 0x5366f9c3, 0xc8b38e74, 0xb475f255, 0x46fcd9b9, 0x7aeb2661, 0x8b1ddf84, + 0x846a0e79, 0x915f95e2, 0x466e598e, 0x20b45770, 0x8cd55591, 0xc902de4c, + 0xb90bace1, 0xbb8205d0, 0x11a86248, 0x7574a99e, 0xb77f19b6, 0xe0a9dc09, + 0x662d09a1, 0xc4324633, 0xe85a1f02, 0x09f0be8c, 0x4a99a025, 0x1d6efe10, + 0x1ab93d1d, 0x0ba5a4df, 0xa186f20f, 0x2868f169, 0xdcb7da83, 0x573906fe, + 0xa1e2ce9b, 0x4fcd7f52, 0x50115e01, 0xa70683fa, 0xa002b5c4, 0x0de6d027, + 0x9af88c27, 0x773f8641, 0xc3604c06, 0x61a806b5, 0xf0177a28, 0xc0f586e0, + 0x006058aa, 0x30dc7d62, 0x11e69ed7, 0x2338ea63, 0x53c2dd94, 0xc2c21634, + 0xbbcbee56, 0x90bcb6de, 0xebfc7da1, 0xce591d76, 0x6f05e409, 0x4b7c0188, + 0x39720a3d, 0x7c927c24, 0x86e3725f, 0x724d9db9, 0x1ac15bb4, 0xd39eb8fc, + 0xed545578, 0x08fca5b5, 0xd83d7cd3, 0x4dad0fc4, 0x1e50ef5e, 0xb161e6f8, + 0xa28514d9, 0x6c51133c, 0x6fd5c7e7, 0x56e14ec4, 0x362abfce, 0xddc6c837, + 0xd79a3234, 0x92638212, 0x670efa8e, 0x406000e0, +} + +var s3 = [256]uint32{ + 0x3a39ce37, 0xd3faf5cf, 0xabc27737, 0x5ac52d1b, 0x5cb0679e, 0x4fa33742, + 0xd3822740, 0x99bc9bbe, 0xd5118e9d, 0xbf0f7315, 0xd62d1c7e, 0xc700c47b, + 0xb78c1b6b, 0x21a19045, 0xb26eb1be, 0x6a366eb4, 0x5748ab2f, 0xbc946e79, + 0xc6a376d2, 0x6549c2c8, 0x530ff8ee, 0x468dde7d, 0xd5730a1d, 0x4cd04dc6, + 0x2939bbdb, 0xa9ba4650, 0xac9526e8, 0xbe5ee304, 0xa1fad5f0, 0x6a2d519a, 
+ 0x63ef8ce2, 0x9a86ee22, 0xc089c2b8, 0x43242ef6, 0xa51e03aa, 0x9cf2d0a4, + 0x83c061ba, 0x9be96a4d, 0x8fe51550, 0xba645bd6, 0x2826a2f9, 0xa73a3ae1, + 0x4ba99586, 0xef5562e9, 0xc72fefd3, 0xf752f7da, 0x3f046f69, 0x77fa0a59, + 0x80e4a915, 0x87b08601, 0x9b09e6ad, 0x3b3ee593, 0xe990fd5a, 0x9e34d797, + 0x2cf0b7d9, 0x022b8b51, 0x96d5ac3a, 0x017da67d, 0xd1cf3ed6, 0x7c7d2d28, + 0x1f9f25cf, 0xadf2b89b, 0x5ad6b472, 0x5a88f54c, 0xe029ac71, 0xe019a5e6, + 0x47b0acfd, 0xed93fa9b, 0xe8d3c48d, 0x283b57cc, 0xf8d56629, 0x79132e28, + 0x785f0191, 0xed756055, 0xf7960e44, 0xe3d35e8c, 0x15056dd4, 0x88f46dba, + 0x03a16125, 0x0564f0bd, 0xc3eb9e15, 0x3c9057a2, 0x97271aec, 0xa93a072a, + 0x1b3f6d9b, 0x1e6321f5, 0xf59c66fb, 0x26dcf319, 0x7533d928, 0xb155fdf5, + 0x03563482, 0x8aba3cbb, 0x28517711, 0xc20ad9f8, 0xabcc5167, 0xccad925f, + 0x4de81751, 0x3830dc8e, 0x379d5862, 0x9320f991, 0xea7a90c2, 0xfb3e7bce, + 0x5121ce64, 0x774fbe32, 0xa8b6e37e, 0xc3293d46, 0x48de5369, 0x6413e680, + 0xa2ae0810, 0xdd6db224, 0x69852dfd, 0x09072166, 0xb39a460a, 0x6445c0dd, + 0x586cdecf, 0x1c20c8ae, 0x5bbef7dd, 0x1b588d40, 0xccd2017f, 0x6bb4e3bb, + 0xdda26a7e, 0x3a59ff45, 0x3e350a44, 0xbcb4cdd5, 0x72eacea8, 0xfa6484bb, + 0x8d6612ae, 0xbf3c6f47, 0xd29be463, 0x542f5d9e, 0xaec2771b, 0xf64e6370, + 0x740e0d8d, 0xe75b1357, 0xf8721671, 0xaf537d5d, 0x4040cb08, 0x4eb4e2cc, + 0x34d2466a, 0x0115af84, 0xe1b00428, 0x95983a1d, 0x06b89fb4, 0xce6ea048, + 0x6f3f3b82, 0x3520ab82, 0x011a1d4b, 0x277227f8, 0x611560b1, 0xe7933fdc, + 0xbb3a792b, 0x344525bd, 0xa08839e1, 0x51ce794b, 0x2f32c9b7, 0xa01fbac9, + 0xe01cc87e, 0xbcc7d1f6, 0xcf0111c3, 0xa1e8aac7, 0x1a908749, 0xd44fbd9a, + 0xd0dadecb, 0xd50ada38, 0x0339c32a, 0xc6913667, 0x8df9317c, 0xe0b12b4f, + 0xf79e59b7, 0x43f5bb3a, 0xf2d519ff, 0x27d9459c, 0xbf97222c, 0x15e6fc2a, + 0x0f91fc71, 0x9b941525, 0xfae59361, 0xceb69ceb, 0xc2a86459, 0x12baa8d1, + 0xb6c1075e, 0xe3056a0c, 0x10d25065, 0xcb03a442, 0xe0ec6e0e, 0x1698db3b, + 0x4c98a0be, 0x3278e964, 0x9f1f9532, 0xe0d392df, 0xd3a0342b, 0x8971f21e, + 
0x1b0a7441, 0x4ba3348c, 0xc5be7120, 0xc37632d8, 0xdf359f8d, 0x9b992f2e, + 0xe60b6f47, 0x0fe3f11d, 0xe54cda54, 0x1edad891, 0xce6279cf, 0xcd3e7e6f, + 0x1618b166, 0xfd2c1d05, 0x848fd2c5, 0xf6fb2299, 0xf523f357, 0xa6327623, + 0x93a83531, 0x56cccd02, 0xacf08162, 0x5a75ebb5, 0x6e163697, 0x88d273cc, + 0xde966292, 0x81b949d0, 0x4c50901b, 0x71c65614, 0xe6c6c7bd, 0x327a140a, + 0x45e1d006, 0xc3f27b9a, 0xc9aa53fd, 0x62a80f00, 0xbb25bfe2, 0x35bdd2f6, + 0x71126905, 0xb2040222, 0xb6cbcf7c, 0xcd769c2b, 0x53113ec0, 0x1640e3d3, + 0x38abbd60, 0x2547adf0, 0xba38209c, 0xf746ce76, 0x77afa1c5, 0x20756060, + 0x85cbfe4e, 0x8ae88dd8, 0x7aaaf9b0, 0x4cf9aa7e, 0x1948c25c, 0x02fb8a8c, + 0x01c36ae4, 0xd6ebe1f9, 0x90d4f869, 0xa65cdea0, 0x3f09252d, 0xc208e69f, + 0xb74e6132, 0xce77e25b, 0x578fdfe3, 0x3ac372e6, +} + +var p = [18]uint32{ + 0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, 0xa4093822, 0x299f31d0, + 0x082efa98, 0xec4e6c89, 0x452821e6, 0x38d01377, 0xbe5466cf, 0x34e90c6c, + 0xc0ac29b7, 0xc97c50dd, 0x3f84d5b5, 0xb5470917, 0x9216d5d9, 0x8979fb1b, +} diff --git a/vendor/golang.org/x/sys/windows/registry/key.go b/vendor/golang.org/x/sys/windows/registry/key.go new file mode 100644 index 00000000000..c256483434f --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/key.go @@ -0,0 +1,198 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package registry provides access to the Windows registry. +// +// Here is a simple example, opening a registry key and reading a string value from it. 
+//
+//	k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer k.Close()
+//
+//	s, _, err := k.GetStringValue("SystemRoot")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	fmt.Printf("Windows system root is %q\n", s)
+//
+package registry
+
+import (
+	"io"
+	"syscall"
+	"time"
+)
+
+const (
+	// Registry key security and access rights.
+	// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms724878.aspx
+	// for details.
+	ALL_ACCESS         = 0xf003f
+	CREATE_LINK        = 0x00020
+	CREATE_SUB_KEY     = 0x00004
+	ENUMERATE_SUB_KEYS = 0x00008
+	EXECUTE            = 0x20019
+	NOTIFY             = 0x00010
+	QUERY_VALUE        = 0x00001
+	READ               = 0x20019
+	SET_VALUE          = 0x00002
+	WOW64_32KEY        = 0x00200
+	WOW64_64KEY        = 0x00100
+	WRITE              = 0x20006
+)
+
+// Key is a handle to an open Windows registry key.
+// Keys can be obtained by calling OpenKey; there are
+// also some predefined root keys such as CURRENT_USER.
+// Keys can be used directly in the Windows API.
+type Key syscall.Handle
+
+const (
+	// Windows defines some predefined root keys that are always open.
+	// An application can use these keys as entry points to the registry.
+	// Normally these keys are used in OpenKey to open new keys,
+	// but they can also be used anywhere a Key is required.
+	CLASSES_ROOT     = Key(syscall.HKEY_CLASSES_ROOT)
+	CURRENT_USER     = Key(syscall.HKEY_CURRENT_USER)
+	LOCAL_MACHINE    = Key(syscall.HKEY_LOCAL_MACHINE)
+	USERS            = Key(syscall.HKEY_USERS)
+	CURRENT_CONFIG   = Key(syscall.HKEY_CURRENT_CONFIG)
+	PERFORMANCE_DATA = Key(syscall.HKEY_PERFORMANCE_DATA)
+)
+
+// Close closes open key k.
+func (k Key) Close() error {
+	return syscall.RegCloseKey(syscall.Handle(k))
+}
+
+// OpenKey opens a new key with path name relative to key k.
+// It accepts any open key, including CURRENT_USER and others,
+// and returns the new key and an error.
+// The access parameter specifies desired access rights to the
+// key to be opened.
+func OpenKey(k Key, path string, access uint32) (Key, error) {
+	p, err := syscall.UTF16PtrFromString(path)
+	if err != nil {
+		return 0, err
+	}
+	var subkey syscall.Handle
+	err = syscall.RegOpenKeyEx(syscall.Handle(k), p, 0, access, &subkey)
+	if err != nil {
+		return 0, err
+	}
+	return Key(subkey), nil
+}
+
+// OpenRemoteKey opens a predefined registry key on another
+// computer pcname. The key to be opened is specified by k, but
+// can only be one of LOCAL_MACHINE, PERFORMANCE_DATA or USERS.
+// If pcname is "", OpenRemoteKey returns local computer key.
+func OpenRemoteKey(pcname string, k Key) (Key, error) {
+	var err error
+	var p *uint16
+	if pcname != "" {
+		// RegConnectRegistry expects a UNC-style machine name: \\pcname.
+		p, err = syscall.UTF16PtrFromString(`\\` + pcname)
+		if err != nil {
+			return 0, err
+		}
+	}
+	var remoteKey syscall.Handle
+	err = regConnectRegistry(p, syscall.Handle(k), &remoteKey)
+	if err != nil {
+		return 0, err
+	}
+	return Key(remoteKey), nil
+}
+
+// ReadSubKeyNames returns the names of subkeys of key k.
+// The parameter n controls the number of returned names,
+// analogous to the way os.File.Readdirnames works.
+func (k Key) ReadSubKeyNames(n int) ([]string, error) {
+	names := make([]string, 0)
+	// Registry key size limit is 255 bytes and described there:
+	// https://msdn.microsoft.com/library/windows/desktop/ms724872.aspx
+	// NOTE(review): MSDN states the limit as 255 characters; the buffer
+	// below holds 256 UTF-16 units — confirm against current docs.
+	buf := make([]uint16, 256) //plus extra room for terminating zero byte
+loopItems:
+	for i := uint32(0); ; i++ {
+		if n > 0 {
+			if len(names) == n {
+				return names, nil
+			}
+		}
+		l := uint32(len(buf))
+		for {
+			err := syscall.RegEnumKeyEx(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
+			if err == nil {
+				break
+			}
+			if err == syscall.ERROR_MORE_DATA {
+				// Double buffer size and try again.
+				l = uint32(2 * len(buf))
+				buf = make([]uint16, l)
+				continue
+			}
+			if err == _ERROR_NO_MORE_ITEMS {
+				break loopItems
+			}
+			return names, err
+		}
+		names = append(names, syscall.UTF16ToString(buf[:l]))
+	}
+	// Fewer names than requested: report io.EOF, mirroring
+	// os.File.Readdirnames semantics.
+	if n > len(names) {
+		return names, io.EOF
+	}
+	return names, nil
+}
+
+// CreateKey creates a key named path under open key k.
+// CreateKey returns the new key and a boolean flag that reports
+// whether the key already existed.
+// The access parameter specifies the access rights for the key
+// to be created.
+func CreateKey(k Key, path string, access uint32) (newk Key, openedExisting bool, err error) {
+	var h syscall.Handle
+	var d uint32
+	// NOTE(review): StringToUTF16Ptr panics if path contains a NUL byte,
+	// unlike OpenKey which uses UTF16PtrFromString and returns an error.
+	err = regCreateKeyEx(syscall.Handle(k), syscall.StringToUTF16Ptr(path),
+		0, nil, _REG_OPTION_NON_VOLATILE, access, nil, &h, &d)
+	if err != nil {
+		return 0, false, err
+	}
+	return Key(h), d == _REG_OPENED_EXISTING_KEY, nil
+}
+
+// DeleteKey deletes the subkey path of key k and its values.
+func DeleteKey(k Key, path string) error {
+	return regDeleteKey(syscall.Handle(k), syscall.StringToUTF16Ptr(path))
+}
+
+// A KeyInfo describes the statistics of a key. It is returned by Stat.
+type KeyInfo struct {
+	SubKeyCount     uint32
+	MaxSubKeyLen    uint32 // size of the key's subkey with the longest name, in Unicode characters, not including the terminating zero byte
+	ValueCount      uint32
+	MaxValueNameLen uint32 // size of the key's longest value name, in Unicode characters, not including the terminating zero byte
+	MaxValueLen     uint32 // longest data component among the key's values, in bytes
+	lastWriteTime   syscall.Filetime
+}
+
+// ModTime returns the key's last write time.
+func (ki *KeyInfo) ModTime() time.Time {
+	return time.Unix(0, ki.lastWriteTime.Nanoseconds())
+}
+
+// Stat retrieves information about the open key k.
+func (k Key) Stat() (*KeyInfo, error) {
+	var ki KeyInfo
+	err := syscall.RegQueryInfoKey(syscall.Handle(k), nil, nil, nil,
+		&ki.SubKeyCount, &ki.MaxSubKeyLen, nil, &ki.ValueCount,
+		&ki.MaxValueNameLen, &ki.MaxValueLen, nil, &ki.lastWriteTime)
+	if err != nil {
+		return nil, err
+	}
+	return &ki, nil
+}
diff --git a/vendor/golang.org/x/sys/windows/registry/mksyscall.go b/vendor/golang.org/x/sys/windows/registry/mksyscall.go
new file mode 100644
index 00000000000..0ac95ffe731
--- /dev/null
+++ b/vendor/golang.org/x/sys/windows/registry/mksyscall.go
@@ -0,0 +1,7 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package registry
+
+//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go syscall.go
diff --git a/vendor/golang.org/x/sys/windows/registry/syscall.go b/vendor/golang.org/x/sys/windows/registry/syscall.go
new file mode 100644
index 00000000000..e66643cbaa6
--- /dev/null
+++ b/vendor/golang.org/x/sys/windows/registry/syscall.go
@@ -0,0 +1,32 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+ +// +build windows + +package registry + +import "syscall" + +const ( + _REG_OPTION_NON_VOLATILE = 0 + + _REG_CREATED_NEW_KEY = 1 + _REG_OPENED_EXISTING_KEY = 2 + + _ERROR_NO_MORE_ITEMS syscall.Errno = 259 +) + +func LoadRegLoadMUIString() error { + return procRegLoadMUIStringW.Find() +} + +//sys regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) = advapi32.RegCreateKeyExW +//sys regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) = advapi32.RegDeleteKeyW +//sys regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) = advapi32.RegSetValueExW +//sys regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) = advapi32.RegEnumValueW +//sys regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) = advapi32.RegDeleteValueW +//sys regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) = advapi32.RegLoadMUIStringW +//sys regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) = advapi32.RegConnectRegistryW + +//sys expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) = kernel32.ExpandEnvironmentStringsW diff --git a/vendor/golang.org/x/sys/windows/registry/value.go b/vendor/golang.org/x/sys/windows/registry/value.go new file mode 100644 index 00000000000..71d4e15bab1 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/value.go @@ -0,0 +1,384 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build windows + +package registry + +import ( + "errors" + "io" + "syscall" + "unicode/utf16" + "unsafe" +) + +const ( + // Registry value types. + NONE = 0 + SZ = 1 + EXPAND_SZ = 2 + BINARY = 3 + DWORD = 4 + DWORD_BIG_ENDIAN = 5 + LINK = 6 + MULTI_SZ = 7 + RESOURCE_LIST = 8 + FULL_RESOURCE_DESCRIPTOR = 9 + RESOURCE_REQUIREMENTS_LIST = 10 + QWORD = 11 +) + +var ( + // ErrShortBuffer is returned when the buffer was too short for the operation. + ErrShortBuffer = syscall.ERROR_MORE_DATA + + // ErrNotExist is returned when a registry key or value does not exist. + ErrNotExist = syscall.ERROR_FILE_NOT_FOUND + + // ErrUnexpectedType is returned by Get*Value when the value's type was unexpected. + ErrUnexpectedType = errors.New("unexpected key value type") +) + +// GetValue retrieves the type and data for the specified value associated +// with an open key k. It fills up buffer buf and returns the retrieved +// byte count n. If buf is too small to fit the stored value it returns +// ErrShortBuffer error along with the required buffer size n. +// If no buffer is provided, it returns true and actual buffer size n. +// If no buffer is provided, GetValue returns the value's type only. +// If the value does not exist, the error returned is ErrNotExist. +// +// GetValue is a low level function. If value's type is known, use the appropriate +// Get*Value function instead. 
+func (k Key) GetValue(name string, buf []byte) (n int, valtype uint32, err error) { + pname, err := syscall.UTF16PtrFromString(name) + if err != nil { + return 0, 0, err + } + var pbuf *byte + if len(buf) > 0 { + pbuf = (*byte)(unsafe.Pointer(&buf[0])) + } + l := uint32(len(buf)) + err = syscall.RegQueryValueEx(syscall.Handle(k), pname, nil, &valtype, pbuf, &l) + if err != nil { + return int(l), valtype, err + } + return int(l), valtype, nil +} + +func (k Key) getValue(name string, buf []byte) (date []byte, valtype uint32, err error) { + p, err := syscall.UTF16PtrFromString(name) + if err != nil { + return nil, 0, err + } + var t uint32 + n := uint32(len(buf)) + for { + err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n) + if err == nil { + return buf[:n], t, nil + } + if err != syscall.ERROR_MORE_DATA { + return nil, 0, err + } + if n <= uint32(len(buf)) { + return nil, 0, err + } + buf = make([]byte, n) + } +} + +// GetStringValue retrieves the string value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetStringValue returns ErrNotExist. +// If value is not SZ or EXPAND_SZ, it will return the correct value +// type and ErrUnexpectedType. +func (k Key) GetStringValue(name string) (val string, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 64)) + if err2 != nil { + return "", typ, err2 + } + switch typ { + case SZ, EXPAND_SZ: + default: + return "", typ, ErrUnexpectedType + } + if len(data) == 0 { + return "", typ, nil + } + u := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[:] + return syscall.UTF16ToString(u), typ, nil +} + +// GetMUIStringValue retrieves the localized string value for +// the specified value name associated with an open key k. +// If the value name doesn't exist or the localized string value +// can't be resolved, GetMUIStringValue returns ErrNotExist. 
+// GetMUIStringValue panics if the system doesn't support +// regLoadMUIString; use LoadRegLoadMUIString to check if +// regLoadMUIString is supported before calling this function. +func (k Key) GetMUIStringValue(name string) (string, error) { + pname, err := syscall.UTF16PtrFromString(name) + if err != nil { + return "", err + } + + buf := make([]uint16, 1024) + var buflen uint32 + var pdir *uint16 + + err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir) + if err == syscall.ERROR_FILE_NOT_FOUND { // Try fallback path + + // Try to resolve the string value using the system directory as + // a DLL search path; this assumes the string value is of the form + // @[path]\dllname,-strID but with no path given, e.g. @tzres.dll,-320. + + // This approach works with tzres.dll but may have to be revised + // in the future to allow callers to provide custom search paths. + + var s string + s, err = ExpandString("%SystemRoot%\\system32\\") + if err != nil { + return "", err + } + pdir, err = syscall.UTF16PtrFromString(s) + if err != nil { + return "", err + } + + err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir) + } + + for err == syscall.ERROR_MORE_DATA { // Grow buffer if needed + if buflen <= uint32(len(buf)) { + break // Buffer not growing, assume race; break + } + buf = make([]uint16, buflen) + err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir) + } + + if err != nil { + return "", err + } + + return syscall.UTF16ToString(buf), nil +} + +// ExpandString expands environment-variable strings and replaces +// them with the values defined for the current user. +// Use ExpandString to expand EXPAND_SZ strings. 
+func ExpandString(value string) (string, error) { + if value == "" { + return "", nil + } + p, err := syscall.UTF16PtrFromString(value) + if err != nil { + return "", err + } + r := make([]uint16, 100) + for { + n, err := expandEnvironmentStrings(p, &r[0], uint32(len(r))) + if err != nil { + return "", err + } + if n <= uint32(len(r)) { + u := (*[1 << 29]uint16)(unsafe.Pointer(&r[0]))[:] + return syscall.UTF16ToString(u), nil + } + r = make([]uint16, n) + } +} + +// GetStringsValue retrieves the []string value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetStringsValue returns ErrNotExist. +// If value is not MULTI_SZ, it will return the correct value +// type and ErrUnexpectedType. +func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 64)) + if err2 != nil { + return nil, typ, err2 + } + if typ != MULTI_SZ { + return nil, typ, ErrUnexpectedType + } + if len(data) == 0 { + return nil, typ, nil + } + p := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[:len(data)/2] + if len(p) == 0 { + return nil, typ, nil + } + if p[len(p)-1] == 0 { + p = p[:len(p)-1] // remove terminating null + } + val = make([]string, 0, 5) + from := 0 + for i, c := range p { + if c == 0 { + val = append(val, string(utf16.Decode(p[from:i]))) + from = i + 1 + } + } + return val, typ, nil +} + +// GetIntegerValue retrieves the integer value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetIntegerValue returns ErrNotExist. +// If value is not DWORD or QWORD, it will return the correct value +// type and ErrUnexpectedType. 
+func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 8)) + if err2 != nil { + return 0, typ, err2 + } + switch typ { + case DWORD: + if len(data) != 4 { + return 0, typ, errors.New("DWORD value is not 4 bytes long") + } + return uint64(*(*uint32)(unsafe.Pointer(&data[0]))), DWORD, nil + case QWORD: + if len(data) != 8 { + return 0, typ, errors.New("QWORD value is not 8 bytes long") + } + return uint64(*(*uint64)(unsafe.Pointer(&data[0]))), QWORD, nil + default: + return 0, typ, ErrUnexpectedType + } +} + +// GetBinaryValue retrieves the binary value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetBinaryValue returns ErrNotExist. +// If value is not BINARY, it will return the correct value +// type and ErrUnexpectedType. +func (k Key) GetBinaryValue(name string) (val []byte, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 64)) + if err2 != nil { + return nil, typ, err2 + } + if typ != BINARY { + return nil, typ, ErrUnexpectedType + } + return data, typ, nil +} + +func (k Key) setValue(name string, valtype uint32, data []byte) error { + p, err := syscall.UTF16PtrFromString(name) + if err != nil { + return err + } + if len(data) == 0 { + return regSetValueEx(syscall.Handle(k), p, 0, valtype, nil, 0) + } + return regSetValueEx(syscall.Handle(k), p, 0, valtype, &data[0], uint32(len(data))) +} + +// SetDWordValue sets the data and type of a name value +// under key k to value and DWORD. +func (k Key) SetDWordValue(name string, value uint32) error { + return k.setValue(name, DWORD, (*[4]byte)(unsafe.Pointer(&value))[:]) +} + +// SetQWordValue sets the data and type of a name value +// under key k to value and QWORD. 
+func (k Key) SetQWordValue(name string, value uint64) error { + return k.setValue(name, QWORD, (*[8]byte)(unsafe.Pointer(&value))[:]) +} + +func (k Key) setStringValue(name string, valtype uint32, value string) error { + v, err := syscall.UTF16FromString(value) + if err != nil { + return err + } + buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[:len(v)*2] + return k.setValue(name, valtype, buf) +} + +// SetStringValue sets the data and type of a name value +// under key k to value and SZ. The value must not contain a zero byte. +func (k Key) SetStringValue(name, value string) error { + return k.setStringValue(name, SZ, value) +} + +// SetExpandStringValue sets the data and type of a name value +// under key k to value and EXPAND_SZ. The value must not contain a zero byte. +func (k Key) SetExpandStringValue(name, value string) error { + return k.setStringValue(name, EXPAND_SZ, value) +} + +// SetStringsValue sets the data and type of a name value +// under key k to value and MULTI_SZ. The value strings +// must not contain a zero byte. +func (k Key) SetStringsValue(name string, value []string) error { + ss := "" + for _, s := range value { + for i := 0; i < len(s); i++ { + if s[i] == 0 { + return errors.New("string cannot have 0 inside") + } + } + ss += s + "\x00" + } + v := utf16.Encode([]rune(ss + "\x00")) + buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[:len(v)*2] + return k.setValue(name, MULTI_SZ, buf) +} + +// SetBinaryValue sets the data and type of a name value +// under key k to value and BINARY. +func (k Key) SetBinaryValue(name string, value []byte) error { + return k.setValue(name, BINARY, value) +} + +// DeleteValue removes a named value from the key k. +func (k Key) DeleteValue(name string) error { + return regDeleteValue(syscall.Handle(k), syscall.StringToUTF16Ptr(name)) +} + +// ReadValueNames returns the value names of key k. +// The parameter n controls the number of returned names, +// analogous to the way os.File.Readdirnames works. 
+func (k Key) ReadValueNames(n int) ([]string, error) { + ki, err := k.Stat() + if err != nil { + return nil, err + } + names := make([]string, 0, ki.ValueCount) + buf := make([]uint16, ki.MaxValueNameLen+1) // extra room for terminating null character +loopItems: + for i := uint32(0); ; i++ { + if n > 0 { + if len(names) == n { + return names, nil + } + } + l := uint32(len(buf)) + for { + err := regEnumValue(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil) + if err == nil { + break + } + if err == syscall.ERROR_MORE_DATA { + // Double buffer size and try again. + l = uint32(2 * len(buf)) + buf = make([]uint16, l) + continue + } + if err == _ERROR_NO_MORE_ITEMS { + break loopItems + } + return names, err + } + names = append(names, syscall.UTF16ToString(buf[:l])) + } + if n > len(names) { + return names, io.EOF + } + return names, nil +} diff --git a/vendor/golang.org/x/sys/windows/registry/zsyscall_windows.go b/vendor/golang.org/x/sys/windows/registry/zsyscall_windows.go new file mode 100644 index 00000000000..ceebdd7726d --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/zsyscall_windows.go @@ -0,0 +1,120 @@ +// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT + +package registry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) 
+ return e +} + +var ( + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procRegCreateKeyExW = modadvapi32.NewProc("RegCreateKeyExW") + procRegDeleteKeyW = modadvapi32.NewProc("RegDeleteKeyW") + procRegSetValueExW = modadvapi32.NewProc("RegSetValueExW") + procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW") + procRegDeleteValueW = modadvapi32.NewProc("RegDeleteValueW") + procRegLoadMUIStringW = modadvapi32.NewProc("RegLoadMUIStringW") + procRegConnectRegistryW = modadvapi32.NewProc("RegConnectRegistryW") + procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW") +) + +func regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) { + r0, _, _ := syscall.Syscall9(procRegCreateKeyExW.Addr(), 9, uintptr(key), uintptr(unsafe.Pointer(subkey)), uintptr(reserved), uintptr(unsafe.Pointer(class)), uintptr(options), uintptr(desired), uintptr(unsafe.Pointer(sa)), uintptr(unsafe.Pointer(result)), uintptr(unsafe.Pointer(disposition))) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) { + r0, _, _ := syscall.Syscall(procRegDeleteKeyW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(subkey)), 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) { + r0, _, _ := syscall.Syscall6(procRegSetValueExW.Addr(), 6, uintptr(key), uintptr(unsafe.Pointer(valueName)), uintptr(reserved), uintptr(vtype), uintptr(unsafe.Pointer(buf)), uintptr(bufsize)) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype 
*uint32, buf *byte, buflen *uint32) (regerrno error) { + r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valtype)), uintptr(unsafe.Pointer(buf)), uintptr(unsafe.Pointer(buflen)), 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) { + r0, _, _ := syscall.Syscall(procRegDeleteValueW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(name)), 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) { + r0, _, _ := syscall.Syscall9(procRegLoadMUIStringW.Addr(), 7, uintptr(key), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(unsafe.Pointer(buflenCopied)), uintptr(flags), uintptr(unsafe.Pointer(dir)), 0, 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) { + r0, _, _ := syscall.Syscall(procRegConnectRegistryW.Addr(), 3, uintptr(unsafe.Pointer(machinename)), uintptr(key), uintptr(unsafe.Pointer(result))) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) { + r0, _, e1 := syscall.Syscall(procExpandEnvironmentStringsW.Addr(), 3, uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(size)) + n = uint32(r0) + if n == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} diff --git a/vendor/golang.org/x/sys/windows/svc/debug/log.go b/vendor/golang.org/x/sys/windows/svc/debug/log.go new file mode 100644 index 00000000000..e51ab42a1a2 --- /dev/null +++ 
b/vendor/golang.org/x/sys/windows/svc/debug/log.go @@ -0,0 +1,56 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +package debug + +import ( + "os" + "strconv" +) + +// Log interface allows different log implementations to be used. +type Log interface { + Close() error + Info(eid uint32, msg string) error + Warning(eid uint32, msg string) error + Error(eid uint32, msg string) error +} + +// ConsoleLog provides access to the console. +type ConsoleLog struct { + Name string +} + +// New creates new ConsoleLog. +func New(source string) *ConsoleLog { + return &ConsoleLog{Name: source} +} + +// Close closes console log l. +func (l *ConsoleLog) Close() error { + return nil +} + +func (l *ConsoleLog) report(kind string, eid uint32, msg string) error { + s := l.Name + "." + kind + "(" + strconv.Itoa(int(eid)) + "): " + msg + "\n" + _, err := os.Stdout.Write([]byte(s)) + return err +} + +// Info writes an information event msg with event id eid to the console l. +func (l *ConsoleLog) Info(eid uint32, msg string) error { + return l.report("info", eid, msg) +} + +// Warning writes an warning event msg with event id eid to the console l. +func (l *ConsoleLog) Warning(eid uint32, msg string) error { + return l.report("warn", eid, msg) +} + +// Error writes an error event msg with event id eid to the console l. +func (l *ConsoleLog) Error(eid uint32, msg string) error { + return l.report("error", eid, msg) +} diff --git a/vendor/golang.org/x/sys/windows/svc/debug/service.go b/vendor/golang.org/x/sys/windows/svc/debug/service.go new file mode 100644 index 00000000000..e621b87adc6 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/debug/service.go @@ -0,0 +1,45 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build windows + +// Package debug provides facilities to execute svc.Handler on console. +// +package debug + +import ( + "os" + "os/signal" + "syscall" + + "golang.org/x/sys/windows/svc" +) + +// Run executes service name by calling appropriate handler function. +// The process is running on console, unlike real service. Use Ctrl+C to +// send "Stop" command to your service. +func Run(name string, handler svc.Handler) error { + cmds := make(chan svc.ChangeRequest) + changes := make(chan svc.Status) + + sig := make(chan os.Signal) + signal.Notify(sig) + + go func() { + status := svc.Status{State: svc.Stopped} + for { + select { + case <-sig: + cmds <- svc.ChangeRequest{Cmd: svc.Stop, CurrentStatus: status} + case status = <-changes: + } + } + }() + + _, errno := handler.Execute([]string{name}, cmds, changes) + if errno != 0 { + return syscall.Errno(errno) + } + return nil +} diff --git a/vendor/golang.org/x/sys/windows/svc/event.go b/vendor/golang.org/x/sys/windows/svc/event.go new file mode 100644 index 00000000000..0508e228818 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/event.go @@ -0,0 +1,48 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +package svc + +import ( + "errors" + + "golang.org/x/sys/windows" +) + +// event represents auto-reset, initially non-signaled Windows event. +// It is used to communicate between go and asm parts of this package. 
+type event struct { + h windows.Handle +} + +func newEvent() (*event, error) { + h, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + return &event{h: h}, nil +} + +func (e *event) Close() error { + return windows.CloseHandle(e.h) +} + +func (e *event) Set() error { + return windows.SetEvent(e.h) +} + +func (e *event) Wait() error { + s, err := windows.WaitForSingleObject(e.h, windows.INFINITE) + switch s { + case windows.WAIT_OBJECT_0: + break + case windows.WAIT_FAILED: + return err + default: + return errors.New("unexpected result from WaitForSingleObject") + } + return nil +} diff --git a/vendor/golang.org/x/sys/windows/svc/eventlog/install.go b/vendor/golang.org/x/sys/windows/svc/eventlog/install.go new file mode 100644 index 00000000000..c76a3760a42 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/eventlog/install.go @@ -0,0 +1,80 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +package eventlog + +import ( + "errors" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +const ( + // Log levels. + Info = windows.EVENTLOG_INFORMATION_TYPE + Warning = windows.EVENTLOG_WARNING_TYPE + Error = windows.EVENTLOG_ERROR_TYPE +) + +const addKeyName = `SYSTEM\CurrentControlSet\Services\EventLog\Application` + +// Install modifies PC registry to allow logging with an event source src. +// It adds all required keys and values to the event log registry key. +// Install uses msgFile as the event message file. If useExpandKey is true, +// the event message file is installed as REG_EXPAND_SZ value, +// otherwise as REG_SZ. Use bitwise of log.Error, log.Warning and +// log.Info to specify events supported by the new event source. 
+func Install(src, msgFile string, useExpandKey bool, eventsSupported uint32) error { + appkey, err := registry.OpenKey(registry.LOCAL_MACHINE, addKeyName, registry.CREATE_SUB_KEY) + if err != nil { + return err + } + defer appkey.Close() + + sk, alreadyExist, err := registry.CreateKey(appkey, src, registry.SET_VALUE) + if err != nil { + return err + } + defer sk.Close() + if alreadyExist { + return errors.New(addKeyName + `\` + src + " registry key already exists") + } + + err = sk.SetDWordValue("CustomSource", 1) + if err != nil { + return err + } + if useExpandKey { + err = sk.SetExpandStringValue("EventMessageFile", msgFile) + } else { + err = sk.SetStringValue("EventMessageFile", msgFile) + } + if err != nil { + return err + } + err = sk.SetDWordValue("TypesSupported", eventsSupported) + if err != nil { + return err + } + return nil +} + +// InstallAsEventCreate is the same as Install, but uses +// %SystemRoot%\System32\EventCreate.exe as the event message file. +func InstallAsEventCreate(src string, eventsSupported uint32) error { + return Install(src, "%SystemRoot%\\System32\\EventCreate.exe", true, eventsSupported) +} + +// Remove deletes all registry elements installed by the correspondent Install. +func Remove(src string) error { + appkey, err := registry.OpenKey(registry.LOCAL_MACHINE, addKeyName, registry.SET_VALUE) + if err != nil { + return err + } + defer appkey.Close() + return registry.DeleteKey(appkey, src) +} diff --git a/vendor/golang.org/x/sys/windows/svc/eventlog/log.go b/vendor/golang.org/x/sys/windows/svc/eventlog/log.go new file mode 100644 index 00000000000..46e5153d024 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/eventlog/log.go @@ -0,0 +1,70 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package eventlog implements access to Windows event log. 
+// +package eventlog + +import ( + "errors" + "syscall" + + "golang.org/x/sys/windows" +) + +// Log provides access to the system log. +type Log struct { + Handle windows.Handle +} + +// Open retrieves a handle to the specified event log. +func Open(source string) (*Log, error) { + return OpenRemote("", source) +} + +// OpenRemote does the same as Open, but on different computer host. +func OpenRemote(host, source string) (*Log, error) { + if source == "" { + return nil, errors.New("Specify event log source") + } + var s *uint16 + if host != "" { + s = syscall.StringToUTF16Ptr(host) + } + h, err := windows.RegisterEventSource(s, syscall.StringToUTF16Ptr(source)) + if err != nil { + return nil, err + } + return &Log{Handle: h}, nil +} + +// Close closes event log l. +func (l *Log) Close() error { + return windows.DeregisterEventSource(l.Handle) +} + +func (l *Log) report(etype uint16, eid uint32, msg string) error { + ss := []*uint16{syscall.StringToUTF16Ptr(msg)} + return windows.ReportEvent(l.Handle, etype, 0, eid, 0, 1, 0, &ss[0], nil) +} + +// Info writes an information event msg with event id eid to the end of event log l. +// When EventCreate.exe is used, eid must be between 1 and 1000. +func (l *Log) Info(eid uint32, msg string) error { + return l.report(windows.EVENTLOG_INFORMATION_TYPE, eid, msg) +} + +// Warning writes an warning event msg with event id eid to the end of event log l. +// When EventCreate.exe is used, eid must be between 1 and 1000. +func (l *Log) Warning(eid uint32, msg string) error { + return l.report(windows.EVENTLOG_WARNING_TYPE, eid, msg) +} + +// Error writes an error event msg with event id eid to the end of event log l. +// When EventCreate.exe is used, eid must be between 1 and 1000. 
+func (l *Log) Error(eid uint32, msg string) error { + return l.report(windows.EVENTLOG_ERROR_TYPE, eid, msg) +} diff --git a/vendor/golang.org/x/sys/windows/svc/go12.c b/vendor/golang.org/x/sys/windows/svc/go12.c new file mode 100644 index 00000000000..6f1be1fa3bc --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/go12.c @@ -0,0 +1,24 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows +// +build !go1.3 + +// copied from pkg/runtime +typedef unsigned int uint32; +typedef unsigned long long int uint64; +#ifdef _64BIT +typedef uint64 uintptr; +#else +typedef uint32 uintptr; +#endif + +// from sys_386.s or sys_amd64.s +void ·servicemain(void); + +void +·getServiceMain(uintptr *r) +{ + *r = (uintptr)·servicemain; +} diff --git a/vendor/golang.org/x/sys/windows/svc/go12.go b/vendor/golang.org/x/sys/windows/svc/go12.go new file mode 100644 index 00000000000..cd8b913c99d --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/go12.go @@ -0,0 +1,11 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows +// +build !go1.3 + +package svc + +// from go12.c +func getServiceMain(r *uintptr) diff --git a/vendor/golang.org/x/sys/windows/svc/go13.go b/vendor/golang.org/x/sys/windows/svc/go13.go new file mode 100644 index 00000000000..9d7f3cec54c --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/go13.go @@ -0,0 +1,31 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows +// +build go1.3 + +package svc + +import "unsafe" + +const ptrSize = 4 << (^uintptr(0) >> 63) // unsafe.Sizeof(uintptr(0)) but an ideal const + +// Should be a built-in for unsafe.Pointer? 
+func add(p unsafe.Pointer, x uintptr) unsafe.Pointer { + return unsafe.Pointer(uintptr(p) + x) +} + +// funcPC returns the entry PC of the function f. +// It assumes that f is a func value. Otherwise the behavior is undefined. +func funcPC(f interface{}) uintptr { + return **(**uintptr)(add(unsafe.Pointer(&f), ptrSize)) +} + +// from sys_386.s and sys_amd64.s +func servicectlhandler(ctl uint32) uintptr +func servicemain(argc uint32, argv **uint16) + +func getServiceMain(r *uintptr) { + *r = funcPC(servicemain) +} diff --git a/vendor/golang.org/x/sys/windows/svc/mgr/config.go b/vendor/golang.org/x/sys/windows/svc/mgr/config.go new file mode 100644 index 00000000000..03bf41f5162 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/mgr/config.go @@ -0,0 +1,139 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +package mgr + +import ( + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + // Service start types. + StartManual = windows.SERVICE_DEMAND_START // the service must be started manually + StartAutomatic = windows.SERVICE_AUTO_START // the service will start by itself whenever the computer reboots + StartDisabled = windows.SERVICE_DISABLED // the service cannot be started + + // The severity of the error, and action taken, + // if this service fails to start. + ErrorCritical = windows.SERVICE_ERROR_CRITICAL + ErrorIgnore = windows.SERVICE_ERROR_IGNORE + ErrorNormal = windows.SERVICE_ERROR_NORMAL + ErrorSevere = windows.SERVICE_ERROR_SEVERE +) + +// TODO(brainman): Password is not returned by windows.QueryServiceConfig, not sure how to get it. 
+ +type Config struct { + ServiceType uint32 + StartType uint32 + ErrorControl uint32 + BinaryPathName string // fully qualified path to the service binary file, can also include arguments for an auto-start service + LoadOrderGroup string + TagId uint32 + Dependencies []string + ServiceStartName string // name of the account under which the service should run + DisplayName string + Password string + Description string +} + +func toString(p *uint16) string { + if p == nil { + return "" + } + return syscall.UTF16ToString((*[4096]uint16)(unsafe.Pointer(p))[:]) +} + +func toStringSlice(ps *uint16) []string { + if ps == nil { + return nil + } + r := make([]string, 0) + for from, i, p := 0, 0, (*[1 << 24]uint16)(unsafe.Pointer(ps)); true; i++ { + if p[i] == 0 { + // empty string marks the end + if i <= from { + break + } + r = append(r, string(utf16.Decode(p[from:i]))) + from = i + 1 + } + } + return r +} + +// Config retrieves service s configuration paramteres. +func (s *Service) Config() (Config, error) { + var p *windows.QUERY_SERVICE_CONFIG + n := uint32(1024) + for { + b := make([]byte, n) + p = (*windows.QUERY_SERVICE_CONFIG)(unsafe.Pointer(&b[0])) + err := windows.QueryServiceConfig(s.Handle, p, n, &n) + if err == nil { + break + } + if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER { + return Config{}, err + } + if n <= uint32(len(b)) { + return Config{}, err + } + } + + var p2 *windows.SERVICE_DESCRIPTION + n = uint32(1024) + for { + b := make([]byte, n) + p2 = (*windows.SERVICE_DESCRIPTION)(unsafe.Pointer(&b[0])) + err := windows.QueryServiceConfig2(s.Handle, + windows.SERVICE_CONFIG_DESCRIPTION, &b[0], n, &n) + if err == nil { + break + } + if err.(syscall.Errno) != syscall.ERROR_INSUFFICIENT_BUFFER { + return Config{}, err + } + if n <= uint32(len(b)) { + return Config{}, err + } + } + + return Config{ + ServiceType: p.ServiceType, + StartType: p.StartType, + ErrorControl: p.ErrorControl, + BinaryPathName: toString(p.BinaryPathName), + 
LoadOrderGroup: toString(p.LoadOrderGroup), + TagId: p.TagId, + Dependencies: toStringSlice(p.Dependencies), + ServiceStartName: toString(p.ServiceStartName), + DisplayName: toString(p.DisplayName), + Description: toString(p2.Description), + }, nil +} + +func updateDescription(handle windows.Handle, desc string) error { + d := windows.SERVICE_DESCRIPTION{Description: toPtr(desc)} + return windows.ChangeServiceConfig2(handle, + windows.SERVICE_CONFIG_DESCRIPTION, (*byte)(unsafe.Pointer(&d))) +} + +// UpdateConfig updates service s configuration parameters. +func (s *Service) UpdateConfig(c Config) error { + err := windows.ChangeServiceConfig(s.Handle, c.ServiceType, c.StartType, + c.ErrorControl, toPtr(c.BinaryPathName), toPtr(c.LoadOrderGroup), + nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName), + toPtr(c.Password), toPtr(c.DisplayName)) + if err != nil { + return err + } + return updateDescription(s.Handle, c.Description) +} diff --git a/vendor/golang.org/x/sys/windows/svc/mgr/mgr.go b/vendor/golang.org/x/sys/windows/svc/mgr/mgr.go new file mode 100644 index 00000000000..76965b56018 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/mgr/mgr.go @@ -0,0 +1,162 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package mgr can be used to manage Windows service programs. +// It can be used to install and remove them. It can also start, +// stop and pause them. The package can query / change current +// service state and config parameters. +// +package mgr + +import ( + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Mgr is used to manage Windows service. +type Mgr struct { + Handle windows.Handle +} + +// Connect establishes a connection to the service control manager. 
+func Connect() (*Mgr, error) { + return ConnectRemote("") +} + +// ConnectRemote establishes a connection to the +// service control manager on computer named host. +func ConnectRemote(host string) (*Mgr, error) { + var s *uint16 + if host != "" { + s = syscall.StringToUTF16Ptr(host) + } + h, err := windows.OpenSCManager(s, nil, windows.SC_MANAGER_ALL_ACCESS) + if err != nil { + return nil, err + } + return &Mgr{Handle: h}, nil +} + +// Disconnect closes connection to the service control manager m. +func (m *Mgr) Disconnect() error { + return windows.CloseServiceHandle(m.Handle) +} + +func toPtr(s string) *uint16 { + if len(s) == 0 { + return nil + } + return syscall.StringToUTF16Ptr(s) +} + +// toStringBlock terminates strings in ss with 0, and then +// concatenates them together. It also adds extra 0 at the end. +func toStringBlock(ss []string) *uint16 { + if len(ss) == 0 { + return nil + } + t := "" + for _, s := range ss { + if s != "" { + t += s + "\x00" + } + } + if t == "" { + return nil + } + t += "\x00" + return &utf16.Encode([]rune(t))[0] +} + +// CreateService installs new service name on the system. +// The service will be executed by running exepath binary. +// Use config c to specify service parameters. +// Any args will be passed as command-line arguments when +// the service is started; these arguments are distinct from +// the arguments passed to Service.Start or via the "Start +// parameters" field in the service's Properties dialog box. 
+func (m *Mgr) CreateService(name, exepath string, c Config, args ...string) (*Service, error) { + if c.StartType == 0 { + c.StartType = StartManual + } + if c.ErrorControl == 0 { + c.ErrorControl = ErrorNormal + } + if c.ServiceType == 0 { + c.ServiceType = windows.SERVICE_WIN32_OWN_PROCESS + } + s := syscall.EscapeArg(exepath) + for _, v := range args { + s += " " + syscall.EscapeArg(v) + } + h, err := windows.CreateService(m.Handle, toPtr(name), toPtr(c.DisplayName), + windows.SERVICE_ALL_ACCESS, c.ServiceType, + c.StartType, c.ErrorControl, toPtr(s), toPtr(c.LoadOrderGroup), + nil, toStringBlock(c.Dependencies), toPtr(c.ServiceStartName), toPtr(c.Password)) + if err != nil { + return nil, err + } + if c.Description != "" { + err = updateDescription(h, c.Description) + if err != nil { + return nil, err + } + } + return &Service{Name: name, Handle: h}, nil +} + +// OpenService retrieves access to service name, so it can +// be interrogated and controlled. +func (m *Mgr) OpenService(name string) (*Service, error) { + h, err := windows.OpenService(m.Handle, syscall.StringToUTF16Ptr(name), windows.SERVICE_ALL_ACCESS) + if err != nil { + return nil, err + } + return &Service{Name: name, Handle: h}, nil +} + +// ListServices enumerates services in the specified +// service control manager database m. +// If the caller does not have the SERVICE_QUERY_STATUS +// access right to a service, the service is silently +// omitted from the list of services returned. 
+func (m *Mgr) ListServices() ([]string, error) { + var err error + var bytesNeeded, servicesReturned uint32 + var buf []byte + for { + var p *byte + if len(buf) > 0 { + p = &buf[0] + } + err = windows.EnumServicesStatusEx(m.Handle, windows.SC_ENUM_PROCESS_INFO, + windows.SERVICE_WIN32, windows.SERVICE_STATE_ALL, + p, uint32(len(buf)), &bytesNeeded, &servicesReturned, nil, nil) + if err == nil { + break + } + if err != syscall.ERROR_MORE_DATA { + return nil, err + } + if bytesNeeded <= uint32(len(buf)) { + return nil, err + } + buf = make([]byte, bytesNeeded) + } + if servicesReturned == 0 { + return nil, nil + } + services := (*[1 << 20]windows.ENUM_SERVICE_STATUS_PROCESS)(unsafe.Pointer(&buf[0]))[:servicesReturned] + var names []string + for _, s := range services { + name := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(s.ServiceName))[:]) + names = append(names, name) + } + return names, nil +} diff --git a/vendor/golang.org/x/sys/windows/svc/mgr/service.go b/vendor/golang.org/x/sys/windows/svc/mgr/service.go new file mode 100644 index 00000000000..fdc46af5fcb --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/mgr/service.go @@ -0,0 +1,72 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +package mgr + +import ( + "syscall" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" +) + +// TODO(brainman): Use EnumDependentServices to enumerate dependent services. + +// Service is used to access Windows service. +type Service struct { + Name string + Handle windows.Handle +} + +// Delete marks service s for deletion from the service control manager database. +func (s *Service) Delete() error { + return windows.DeleteService(s.Handle) +} + +// Close relinquish access to the service s. +func (s *Service) Close() error { + return windows.CloseServiceHandle(s.Handle) +} + +// Start starts service s. 
+// args will be passed to svc.Handler.Execute. +func (s *Service) Start(args ...string) error { + var p **uint16 + if len(args) > 0 { + vs := make([]*uint16, len(args)) + for i := range vs { + vs[i] = syscall.StringToUTF16Ptr(args[i]) + } + p = &vs[0] + } + return windows.StartService(s.Handle, uint32(len(args)), p) +} + +// Control sends state change request c to the servce s. +func (s *Service) Control(c svc.Cmd) (svc.Status, error) { + var t windows.SERVICE_STATUS + err := windows.ControlService(s.Handle, uint32(c), &t) + if err != nil { + return svc.Status{}, err + } + return svc.Status{ + State: svc.State(t.CurrentState), + Accepts: svc.Accepted(t.ControlsAccepted), + }, nil +} + +// Query returns current status of service s. +func (s *Service) Query() (svc.Status, error) { + var t windows.SERVICE_STATUS + err := windows.QueryServiceStatus(s.Handle, &t) + if err != nil { + return svc.Status{}, err + } + return svc.Status{ + State: svc.State(t.CurrentState), + Accepts: svc.Accepted(t.ControlsAccepted), + }, nil +} diff --git a/vendor/golang.org/x/sys/windows/svc/security.go b/vendor/golang.org/x/sys/windows/svc/security.go new file mode 100644 index 00000000000..6fbc9236ed5 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/security.go @@ -0,0 +1,62 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +package svc + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +func allocSid(subAuth0 uint32) (*windows.SID, error) { + var sid *windows.SID + err := windows.AllocateAndInitializeSid(&windows.SECURITY_NT_AUTHORITY, + 1, subAuth0, 0, 0, 0, 0, 0, 0, 0, &sid) + if err != nil { + return nil, err + } + return sid, nil +} + +// IsAnInteractiveSession determines if calling process is running interactively. +// It queries the process token for membership in the Interactive group. 
+// http://stackoverflow.com/questions/2668851/how-do-i-detect-that-my-application-is-running-as-service-or-in-an-interactive-s +func IsAnInteractiveSession() (bool, error) { + interSid, err := allocSid(windows.SECURITY_INTERACTIVE_RID) + if err != nil { + return false, err + } + defer windows.FreeSid(interSid) + + serviceSid, err := allocSid(windows.SECURITY_SERVICE_RID) + if err != nil { + return false, err + } + defer windows.FreeSid(serviceSid) + + t, err := windows.OpenCurrentProcessToken() + if err != nil { + return false, err + } + defer t.Close() + + gs, err := t.GetTokenGroups() + if err != nil { + return false, err + } + p := unsafe.Pointer(&gs.Groups[0]) + groups := (*[2 << 20]windows.SIDAndAttributes)(p)[:gs.GroupCount] + for _, g := range groups { + if windows.EqualSid(g.Sid, interSid) { + return true, nil + } + if windows.EqualSid(g.Sid, serviceSid) { + return false, nil + } + } + return false, nil +} diff --git a/vendor/golang.org/x/sys/windows/svc/service.go b/vendor/golang.org/x/sys/windows/svc/service.go new file mode 100644 index 00000000000..cda26b54b39 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/service.go @@ -0,0 +1,363 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package svc provides everything required to build Windows service. +// +package svc + +import ( + "errors" + "runtime" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// State describes service execution state (Stopped, Running and so on). 
+type State uint32 + +const ( + Stopped = State(windows.SERVICE_STOPPED) + StartPending = State(windows.SERVICE_START_PENDING) + StopPending = State(windows.SERVICE_STOP_PENDING) + Running = State(windows.SERVICE_RUNNING) + ContinuePending = State(windows.SERVICE_CONTINUE_PENDING) + PausePending = State(windows.SERVICE_PAUSE_PENDING) + Paused = State(windows.SERVICE_PAUSED) +) + +// Cmd represents service state change request. It is sent to a service +// by the service manager, and should be actioned upon by the service. +type Cmd uint32 + +const ( + Stop = Cmd(windows.SERVICE_CONTROL_STOP) + Pause = Cmd(windows.SERVICE_CONTROL_PAUSE) + Continue = Cmd(windows.SERVICE_CONTROL_CONTINUE) + Interrogate = Cmd(windows.SERVICE_CONTROL_INTERROGATE) + Shutdown = Cmd(windows.SERVICE_CONTROL_SHUTDOWN) + ParamChange = Cmd(windows.SERVICE_CONTROL_PARAMCHANGE) + NetBindAdd = Cmd(windows.SERVICE_CONTROL_NETBINDADD) + NetBindRemove = Cmd(windows.SERVICE_CONTROL_NETBINDREMOVE) + NetBindEnable = Cmd(windows.SERVICE_CONTROL_NETBINDENABLE) + NetBindDisable = Cmd(windows.SERVICE_CONTROL_NETBINDDISABLE) + DeviceEvent = Cmd(windows.SERVICE_CONTROL_DEVICEEVENT) + HardwareProfileChange = Cmd(windows.SERVICE_CONTROL_HARDWAREPROFILECHANGE) + PowerEvent = Cmd(windows.SERVICE_CONTROL_POWEREVENT) + SessionChange = Cmd(windows.SERVICE_CONTROL_SESSIONCHANGE) +) + +// Accepted is used to describe commands accepted by the service. +// Note that Interrogate is always accepted. 
+type Accepted uint32 + +const ( + AcceptStop = Accepted(windows.SERVICE_ACCEPT_STOP) + AcceptShutdown = Accepted(windows.SERVICE_ACCEPT_SHUTDOWN) + AcceptPauseAndContinue = Accepted(windows.SERVICE_ACCEPT_PAUSE_CONTINUE) + AcceptParamChange = Accepted(windows.SERVICE_ACCEPT_PARAMCHANGE) + AcceptNetBindChange = Accepted(windows.SERVICE_ACCEPT_NETBINDCHANGE) + AcceptHardwareProfileChange = Accepted(windows.SERVICE_ACCEPT_HARDWAREPROFILECHANGE) + AcceptPowerEvent = Accepted(windows.SERVICE_ACCEPT_POWEREVENT) + AcceptSessionChange = Accepted(windows.SERVICE_ACCEPT_SESSIONCHANGE) +) + +// Status combines State and Accepted commands to fully describe running service. +type Status struct { + State State + Accepts Accepted + CheckPoint uint32 // used to report progress during a lengthy operation + WaitHint uint32 // estimated time required for a pending operation, in milliseconds +} + +// ChangeRequest is sent to the service Handler to request service status change. +type ChangeRequest struct { + Cmd Cmd + EventType uint32 + EventData uintptr + CurrentStatus Status +} + +// Handler is the interface that must be implemented to build Windows service. +type Handler interface { + + // Execute will be called by the package code at the start of + // the service, and the service will exit once Execute completes. + // Inside Execute you must read service change requests from r and + // act accordingly. You must keep service control manager up to date + // about state of your service by writing into s as required. + // args contains service name followed by argument strings passed + // to the service. + // You can provide service exit code in exitCode return parameter, + // with 0 being "no error". You can also indicate if exit code, + // if any, is service specific or not by using svcSpecificEC + // parameter. + Execute(args []string, r <-chan ChangeRequest, s chan<- Status) (svcSpecificEC bool, exitCode uint32) +} + +var ( + // These are used by asm code. 
+ goWaitsH uintptr + cWaitsH uintptr + ssHandle uintptr + sName *uint16 + sArgc uintptr + sArgv **uint16 + ctlHandlerExProc uintptr + cSetEvent uintptr + cWaitForSingleObject uintptr + cRegisterServiceCtrlHandlerExW uintptr +) + +func init() { + k := syscall.MustLoadDLL("kernel32.dll") + cSetEvent = k.MustFindProc("SetEvent").Addr() + cWaitForSingleObject = k.MustFindProc("WaitForSingleObject").Addr() + a := syscall.MustLoadDLL("advapi32.dll") + cRegisterServiceCtrlHandlerExW = a.MustFindProc("RegisterServiceCtrlHandlerExW").Addr() +} + +// The HandlerEx prototype also has a context pointer but since we don't use +// it at start-up time we don't have to pass it over either. +type ctlEvent struct { + cmd Cmd + eventType uint32 + eventData uintptr + errno uint32 +} + +// service provides access to windows service api. +type service struct { + name string + h windows.Handle + cWaits *event + goWaits *event + c chan ctlEvent + handler Handler +} + +func newService(name string, handler Handler) (*service, error) { + var s service + var err error + s.name = name + s.c = make(chan ctlEvent) + s.handler = handler + s.cWaits, err = newEvent() + if err != nil { + return nil, err + } + s.goWaits, err = newEvent() + if err != nil { + s.cWaits.Close() + return nil, err + } + return &s, nil +} + +func (s *service) close() error { + s.cWaits.Close() + s.goWaits.Close() + return nil +} + +type exitCode struct { + isSvcSpecific bool + errno uint32 +} + +func (s *service) updateStatus(status *Status, ec *exitCode) error { + if s.h == 0 { + return errors.New("updateStatus with no service status handle") + } + var t windows.SERVICE_STATUS + t.ServiceType = windows.SERVICE_WIN32_OWN_PROCESS + t.CurrentState = uint32(status.State) + if status.Accepts&AcceptStop != 0 { + t.ControlsAccepted |= windows.SERVICE_ACCEPT_STOP + } + if status.Accepts&AcceptShutdown != 0 { + t.ControlsAccepted |= windows.SERVICE_ACCEPT_SHUTDOWN + } + if status.Accepts&AcceptPauseAndContinue != 0 { + 
t.ControlsAccepted |= windows.SERVICE_ACCEPT_PAUSE_CONTINUE + } + if status.Accepts&AcceptParamChange != 0 { + t.ControlsAccepted |= windows.SERVICE_ACCEPT_PARAMCHANGE + } + if status.Accepts&AcceptNetBindChange != 0 { + t.ControlsAccepted |= windows.SERVICE_ACCEPT_NETBINDCHANGE + } + if status.Accepts&AcceptHardwareProfileChange != 0 { + t.ControlsAccepted |= windows.SERVICE_ACCEPT_HARDWAREPROFILECHANGE + } + if status.Accepts&AcceptPowerEvent != 0 { + t.ControlsAccepted |= windows.SERVICE_ACCEPT_POWEREVENT + } + if status.Accepts&AcceptSessionChange != 0 { + t.ControlsAccepted |= windows.SERVICE_ACCEPT_SESSIONCHANGE + } + if ec.errno == 0 { + t.Win32ExitCode = windows.NO_ERROR + t.ServiceSpecificExitCode = windows.NO_ERROR + } else if ec.isSvcSpecific { + t.Win32ExitCode = uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) + t.ServiceSpecificExitCode = ec.errno + } else { + t.Win32ExitCode = ec.errno + t.ServiceSpecificExitCode = windows.NO_ERROR + } + t.CheckPoint = status.CheckPoint + t.WaitHint = status.WaitHint + return windows.SetServiceStatus(s.h, &t) +} + +const ( + sysErrSetServiceStatusFailed = uint32(syscall.APPLICATION_ERROR) + iota + sysErrNewThreadInCallback +) + +func (s *service) run() { + s.goWaits.Wait() + s.h = windows.Handle(ssHandle) + argv := (*[100]*int16)(unsafe.Pointer(sArgv))[:sArgc] + args := make([]string, len(argv)) + for i, a := range argv { + args[i] = syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(a))[:]) + } + + cmdsToHandler := make(chan ChangeRequest) + changesFromHandler := make(chan Status) + exitFromHandler := make(chan exitCode) + + go func() { + ss, errno := s.handler.Execute(args, cmdsToHandler, changesFromHandler) + exitFromHandler <- exitCode{ss, errno} + }() + + status := Status{State: Stopped} + ec := exitCode{isSvcSpecific: true, errno: 0} + var outch chan ChangeRequest + inch := s.c + var cmd Cmd + var evtype uint32 + var evdata uintptr +loop: + for { + select { + case r := <-inch: + if r.errno != 0 { + ec.errno = 
r.errno + break loop + } + inch = nil + outch = cmdsToHandler + cmd = r.cmd + evtype = r.eventType + evdata = r.eventData + case outch <- ChangeRequest{cmd, evtype, evdata, status}: + inch = s.c + outch = nil + case c := <-changesFromHandler: + err := s.updateStatus(&c, &ec) + if err != nil { + // best suitable error number + ec.errno = sysErrSetServiceStatusFailed + if err2, ok := err.(syscall.Errno); ok { + ec.errno = uint32(err2) + } + break loop + } + status = c + case ec = <-exitFromHandler: + break loop + } + } + + s.updateStatus(&Status{State: Stopped}, &ec) + s.cWaits.Set() +} + +func newCallback(fn interface{}) (cb uintptr, err error) { + defer func() { + r := recover() + if r == nil { + return + } + cb = 0 + switch v := r.(type) { + case string: + err = errors.New(v) + case error: + err = v + default: + err = errors.New("unexpected panic in syscall.NewCallback") + } + }() + return syscall.NewCallback(fn), nil +} + +// BUG(brainman): There is no mechanism to run multiple services +// inside one single executable. Perhaps, it can be overcome by +// using RegisterServiceCtrlHandlerEx Windows api. + +// Run executes service name by calling appropriate handler function. +func Run(name string, handler Handler) error { + runtime.LockOSThread() + + tid := windows.GetCurrentThreadId() + + s, err := newService(name, handler) + if err != nil { + return err + } + + ctlHandler := func(ctl uint32, evtype uint32, evdata uintptr, context uintptr) uintptr { + e := ctlEvent{cmd: Cmd(ctl), eventType: evtype, eventData: evdata} + // We assume that this callback function is running on + // the same thread as Run. Nowhere in MS documentation + // I could find statement to guarantee that. So putting + // check here to verify, otherwise things will go bad + // quickly, if ignored. + i := windows.GetCurrentThreadId() + if i != tid { + e.errno = sysErrNewThreadInCallback + } + s.c <- e + // Always return NO_ERROR (0) for now. 
+ return 0 + } + + var svcmain uintptr + getServiceMain(&svcmain) + t := []windows.SERVICE_TABLE_ENTRY{ + {ServiceName: syscall.StringToUTF16Ptr(s.name), ServiceProc: svcmain}, + {ServiceName: nil, ServiceProc: 0}, + } + + goWaitsH = uintptr(s.goWaits.h) + cWaitsH = uintptr(s.cWaits.h) + sName = t[0].ServiceName + ctlHandlerExProc, err = newCallback(ctlHandler) + if err != nil { + return err + } + + go s.run() + + err = windows.StartServiceCtrlDispatcher(&t[0]) + if err != nil { + return err + } + return nil +} + +// StatusHandle returns service status handle. It is safe to call this function +// from inside the Handler.Execute because then it is guaranteed to be set. +// This code will have to change once multiple services are possible per process. +func StatusHandle() windows.Handle { + return windows.Handle(ssHandle) +} diff --git a/vendor/golang.org/x/sys/windows/svc/sys_386.s b/vendor/golang.org/x/sys/windows/svc/sys_386.s new file mode 100644 index 00000000000..2c82a9d91d7 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/sys_386.s @@ -0,0 +1,68 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build windows + +// func servicemain(argc uint32, argv **uint16) +TEXT ·servicemain(SB),7,$0 + MOVL argc+0(FP), AX + MOVL AX, ·sArgc(SB) + MOVL argv+4(FP), AX + MOVL AX, ·sArgv(SB) + + PUSHL BP + PUSHL BX + PUSHL SI + PUSHL DI + + SUBL $12, SP + + MOVL ·sName(SB), AX + MOVL AX, (SP) + MOVL $·servicectlhandler(SB), AX + MOVL AX, 4(SP) + MOVL $0, 8(SP) + MOVL ·cRegisterServiceCtrlHandlerExW(SB), AX + MOVL SP, BP + CALL AX + MOVL BP, SP + CMPL AX, $0 + JE exit + MOVL AX, ·ssHandle(SB) + + MOVL ·goWaitsH(SB), AX + MOVL AX, (SP) + MOVL ·cSetEvent(SB), AX + MOVL SP, BP + CALL AX + MOVL BP, SP + + MOVL ·cWaitsH(SB), AX + MOVL AX, (SP) + MOVL $-1, AX + MOVL AX, 4(SP) + MOVL ·cWaitForSingleObject(SB), AX + MOVL SP, BP + CALL AX + MOVL BP, SP + +exit: + ADDL $12, SP + + POPL DI + POPL SI + POPL BX + POPL BP + + MOVL 0(SP), CX + ADDL $12, SP + JMP CX + +// I do not know why, but this seems to be the only way to call +// ctlHandlerProc on Windows 7. + +// func servicectlhandler(ctl uint32, evtype uint32, evdata uintptr, context uintptr) uintptr { +TEXT ·servicectlhandler(SB),7,$0 + MOVL ·ctlHandlerExProc(SB), CX + JMP CX diff --git a/vendor/golang.org/x/sys/windows/svc/sys_amd64.s b/vendor/golang.org/x/sys/windows/svc/sys_amd64.s new file mode 100644 index 00000000000..bde25e9c485 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/svc/sys_amd64.s @@ -0,0 +1,42 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// func servicemain(argc uint32, argv **uint16) +TEXT ·servicemain(SB),7,$0 + MOVL CX, ·sArgc(SB) + MOVQ DX, ·sArgv(SB) + + SUBQ $32, SP // stack for the first 4 syscall params + + MOVQ ·sName(SB), CX + MOVQ $·servicectlhandler(SB), DX + // BUG(pastarmovj): Figure out a way to pass in context in R8. 
+ MOVQ ·cRegisterServiceCtrlHandlerExW(SB), AX + CALL AX + CMPQ AX, $0 + JE exit + MOVQ AX, ·ssHandle(SB) + + MOVQ ·goWaitsH(SB), CX + MOVQ ·cSetEvent(SB), AX + CALL AX + + MOVQ ·cWaitsH(SB), CX + MOVQ $4294967295, DX + MOVQ ·cWaitForSingleObject(SB), AX + CALL AX + +exit: + ADDQ $32, SP + RET + +// I do not know why, but this seems to be the only way to call +// ctlHandlerProc on Windows 7. + +// func ·servicectlhandler(ctl uint32, evtype uint32, evdata uintptr, context uintptr) uintptr { +TEXT ·servicectlhandler(SB),7,$0 + MOVQ ·ctlHandlerExProc(SB), AX + JMP AX