diff --git a/apps/websocket-server/internal/domain/logs.go b/apps/websocket-server/internal/domain/logs.go index 8f171691b..db4e4c7fe 100644 --- a/apps/websocket-server/internal/domain/logs.go +++ b/apps/websocket-server/internal/domain/logs.go @@ -21,6 +21,13 @@ import ( "github.com/nats-io/nats.go/jetstream" ) +type Event string + +const ( + EventSubscribe Event = "subscribe" + EventUnsubscribe Event = "unsubscribe" +) + func parseTime(since string) (time.Time, error) { now := time.Now() @@ -84,13 +91,18 @@ func (d *domain) newJetstreamConsumerForLog(ctx context.Context, subject string, return nil, errors.NewE(err) } + id := uuid.New().String() + cid := fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%s", consumerId, id)))) + + fmt.Println("consumerId: ", cid) + if t != nil { return msg_nats.NewJetstreamConsumer(ctx, d.jetStreamClient, msg_nats.JetstreamConsumerArgs{ Stream: d.env.LogsStreamName, ConsumerConfig: msg_nats.ConsumerConfig{ DeliverPolicy: jetstream.DeliverByStartTimePolicy, OptStartTime: t, - Name: consumerId, + Name: cid, Description: "this is an ephemeral consumer which dispatches logs to a websocket client", FilterSubjects: []string{ subject, @@ -111,9 +123,8 @@ func (d *domain) newJetstreamConsumerForLog(ctx context.Context, subject string, }) } -func getLogHash(ld LogsReqData, userId repos.ID) string { - uuid := uuid.New().String() - return fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%s-%s-%s-%s", ld.AccountName, ld.ClusterName, ld.TrackingId, userId, uuid)))) +func getLogHash(ld LogsReqData, userId repos.ID, sid string) string { + return fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%s-%s-%s", ld.AccountName, ld.ClusterName, ld.TrackingId, userId)))) } func (d *domain) getLogSubsId(ld LogsReqData) string { @@ -138,13 +149,15 @@ func (d *domain) HandleWebSocketForLogs(ctx context.Context, c *websocket.Conn) resource LogsReqData jc *msg_nats.JetstreamConsumer open bool + Id string } resources := make(map[string]*Subscription) type Message struct { - Event string `json:"event"` + Event Event `json:"event"` Data LogsReqData `json:"data"` + Id string } type MessageType string @@ -163,6 +176,7 @@ func (d *domain) HandleWebSocketForLogs(ctx context.Context, c *websocket.Conn) type MessageResponse struct { Timestamp time.Time `json:"timestamp"` Message string `json:"message"` + Id string `json:"id"` Spec *MsgSpec `json:"spec,omitempty"` Type MessageType `json:"type"` } @@ -211,27 +225,39 @@ func (d *domain) HandleWebSocketForLogs(ctx context.Context, c *websocket.Conn) continue } - if err := d.checkAccountAccess(ctx, msg.Data.AccountName, sess.UserId, iamT.GetAccount); err != nil { + if err := d.checkAccountAccess(ctx, msg.Data.AccountName, sess.UserId, iamT.ReadLogs); err != nil { if err := writeError(c, err); err != nil { log.Warnf("websocket write: %w", err) } continue } - hash := getLogHash(msg.Data, sess.UserId) + if msg.Id == "" { + msg.Id = "default" + } + + hash := getLogHash(msg.Data, sess.UserId, msg.Id) switch msg.Event { - case "subscribe": + case EventSubscribe: if _, ok := resources[hash]; ok { - if err := writeError( - c, errors.Newf("already subscribed to logs for account: %s, cluster: %s, trackingId: %s", - msg.Data.AccountName, msg.Data.ClusterName, msg.Data.TrackingId, - ), - ); err != nil { - log.Warnf("websocket write: %w", err) + + if resources[hash].jc != nil { + err := resources[hash].jc.Stop(ctx) + if err != nil { + if err := writeError( + c, errors.Newf("already subscribed to logs for account: %s, cluster: %s, trackingId: %s", + msg.Data.AccountName, msg.Data.ClusterName, msg.Data.TrackingId, + ), + ); err != nil { + log.Warnf("websocket write: %w", err) + } + // todo: reverify + continue + } } - continue + } jc, err := d.newJetstreamConsumerForLog(ctx, d.getLogSubsId(msg.Data), hash, msg.Data.Since) @@ -246,25 +272,26 @@ func (d *domain) HandleWebSocketForLogs(ctx context.Context, c *websocket.Conn) resource: msg.Data, jc: jc, open: true, + Id: msg.Id, } go func() { - if err := writeInfo(c, "subscribed to logs"); err != nil { log.Warnf("websocket write: %w", err) } if err := jc.Consume( - func(msg *types.ConsumeMsg) error { + func(m *types.ConsumeMsg) error { if c != nil { var resp MessageResponse - if err := json.Unmarshal(msg.Payload, &resp); err != nil { + if err := json.Unmarshal(m.Payload, &resp); err != nil { if err := writeError(c, err); err != nil { log.Warnf("websocket write: %w", err) } } resp.Type = MessageTypeLog - sp := strings.Split(msg.Subject, ".") + resp.Id = msg.Id + sp := strings.Split(m.Subject, ".") resp.Spec = &MsgSpec{ PodName: sp[len(sp)-2], ContainerName: sp[len(sp)-1], @@ -311,6 +338,12 @@ func (d *domain) HandleWebSocketForLogs(ctx context.Context, c *websocket.Conn) if err := writeError(c, err); err != nil { log.Warnf("websocket write: %w", err) } + continue + } + + delete(resources, hash) + if err := writeInfo(c, "unsubscribed from logs"); err != nil { + log.Warnf("websocket write: %w", err) } }