diff --git a/apps/websocket-server/internal/app/app.go b/apps/websocket-server/internal/app/app.go index e05eb5d6f..e0986aabe 100644 --- a/apps/websocket-server/internal/app/app.go +++ b/apps/websocket-server/internal/app/app.go @@ -59,23 +59,11 @@ var Module = fx.Module("app", // Web socket route a.Use("/ws", setUpgradable) - a.Use("/logs", setUpgradable) - - a.Use("/logs", func(c *fiber.Ctx) error { - ctx := c.Context() - - return websocket.New(func(sockConn *websocket.Conn) { - if err := d.HandleWebSocketForLogs(ctx, sockConn); err != nil { - logr.Errorf(err, "while handling websocket for logs") - } - })(c) - }) - a.Use("/ws", func(c *fiber.Ctx) error { ctx := c.Context() return websocket.New(func(sockConn *websocket.Conn) { - if err := d.HandleWebSocketForRUpdate(ctx, sockConn); err != nil { + if err := d.HandleWebSocket(ctx, sockConn); err != nil { logr.Errorf(err, "while handling websocket for resource update") } })(c) diff --git a/apps/websocket-server/internal/domain/commons.go b/apps/websocket-server/internal/domain/commons.go index ecf1de012..d8293a962 100644 --- a/apps/websocket-server/internal/domain/commons.go +++ b/apps/websocket-server/internal/domain/commons.go @@ -23,10 +23,6 @@ func (d *domain) checkAccountAccess(ctx context.Context, accountName string, use Action: string(action), }) - // if err != nil { - // return err - // } - if err != nil { d.logger.Errorf(err, "iam.can check for action: ", action) return errors.Newf("unauthorized to perform action: %s", action) diff --git a/apps/websocket-server/internal/domain/domain.go b/apps/websocket-server/internal/domain/domain.go index 46cf781e5..f6978bc45 100644 --- a/apps/websocket-server/internal/domain/domain.go +++ b/apps/websocket-server/internal/domain/domain.go @@ -13,8 +13,7 @@ import ( ) type SocketService interface { - HandleWebSocketForRUpdate(ctx context.Context, c *websocket.Conn) error - HandleWebSocketForLogs(ctx context.Context, c *websocket.Conn) error + HandleWebSocket(ctx context.Context, c *websocket.Conn) error } type Domain interface { diff --git a/apps/websocket-server/internal/domain/logs.go b/apps/websocket-server/internal/domain/logs.go index c3114c2ac..0cd0708a2 100644 --- a/apps/websocket-server/internal/domain/logs.go +++ b/apps/websocket-server/internal/domain/logs.go @@ -5,88 +5,22 @@ import ( "crypto/md5" "encoding/json" "fmt" - "strconv" "strings" - "time" - "github.com/gofiber/websocket/v2" "github.com/google/uuid" iamT "github.com/kloudlite/api/apps/iam/types" - "github.com/kloudlite/api/common" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/logs" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/types" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/utils" "github.com/kloudlite/api/pkg/errors" - httpServer "github.com/kloudlite/api/pkg/http-server" msg_nats "github.com/kloudlite/api/pkg/messaging/nats" - "github.com/kloudlite/api/pkg/messaging/types" - "github.com/kloudlite/api/pkg/repos" - "github.com/nats-io/nats.go/jetstream" -) - -type Event string + msg_types "github.com/kloudlite/api/pkg/messaging/types" -const ( - EventSubscribe Event = "subscribe" - EventUnsubscribe Event = "unsubscribe" + "github.com/nats-io/nats.go/jetstream" ) -func parseTime(since string) (time.Time, error) { - now := time.Now() - - // Split the string into the numeric and duration type parts - length := len(since) - if length < 2 { - return now, fmt.Errorf("invalid expiration format") - } - - durationValStr := since[:length-1] - durationVal, err := strconv.Atoi(durationValStr) - if err != nil { - return now, fmt.Errorf("invalid duration value: %v", err) - } - - durationType := since[length-1] - - switch durationType { - case 'm': - return now.Add(-time.Duration(durationVal) * time.Minute), nil - case 'h': - return now.Add(-time.Duration(durationVal) * time.Hour), nil - case 'd': - return now.AddDate(0, 0, -durationVal), nil - case 'w': - return now.AddDate(0, 0, -durationVal*7), nil - case 'M': - return now.AddDate(0, -durationVal, 0), nil - default: - return now, fmt.Errorf("invalid duration type: %v, available types: m, h, d, w, M", durationType) - } -} - -func parseSince(since *string) (*time.Time, error) { - if since == nil { - return nil, nil - } - - if *since == "" { - return nil, nil - } - - t, err := parseTime(*since) - if err != nil { - return nil, errors.NewE(err) - } - - return &t, nil -} - -type LogsReqData struct { - AccountName string `json:"account"` - ClusterName string `json:"cluster"` - TrackingId string `json:"trackingId"` - Since *string `json:"since,omitempty"` -} - func (d *domain) newJetstreamConsumerForLog(ctx context.Context, subject string, consumerId string, since *string) (*msg_nats.JetstreamConsumer, error) { - t, err := parseSince(since) + t, err := logs.ParseSince(since) if err != nil { return nil, errors.NewE(err) } @@ -94,8 +28,6 @@ func (d *domain) newJetstreamConsumerForLog(ctx context.Context, subject string, id := uuid.New().String() cid := fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%s", consumerId, id)))) - fmt.Println("consumerId: ", cid) - if t != nil { return msg_nats.NewJetstreamConsumer(ctx, d.jetStreamClient, msg_nats.JetstreamConsumerArgs{ Stream: d.env.LogsStreamName, @@ -123,248 +55,127 @@ func (d *domain) newJetstreamConsumerForLog(ctx context.Context, subject string, }) } -func getLogHash(ld LogsReqData, userId repos.ID, sid string) string { - return fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%s-%s-%s", ld.AccountName, ld.ClusterName, ld.TrackingId, userId)))) -} - -func (d *domain) getLogSubsId(ld LogsReqData) string { - return fmt.Sprintf("%s.%s.%s.%s.>", d.env.LogsStreamName, ld.AccountName, ld.ClusterName, ld.TrackingId) -} - -func (d *domain) HandleWebSocketForLogs(ctx context.Context, c *websocket.Conn) error { - sess := httpServer.GetSession[*common.AuthSession](ctx) - if sess == nil { - return errors.NewE(fmt.Errorf("session not found")) - } - - defer func() { - if err := c.Close(); err != nil { - d.logger.Warnf("websocket close: %w", err) - } - }() - +func (d *domain) handleLogsMsg(ctx types.Context, logsSubs *logs.LogsSubsMap, msgAny map[string]any) error { log := d.logger - type Subscription struct { - resource LogsReqData - jc *msg_nats.JetstreamConsumer - open bool - Id string - } - - resources := make(map[string]*Subscription) - - defer func() { - for _, v := range resources { - if v.jc != nil { - if err := v.jc.Stop(ctx); err != nil { - log.Warnf("stop jetstream consumer: %w", err) - } - } - } - }() - - type Message struct { - Event Event `json:"event"` - Data LogsReqData `json:"data"` - Id string - } - - type MessageType string - - const ( - MessageTypeError MessageType = "error" - MessageTypeUpdate MessageType = "update" - MessageTypeInfo MessageType = "info" - MessageTypeLog MessageType = "log" - ) - type MsgSpec struct { - PodName string `json:"podName"` - ContainerName string `json:"containerName"` - } - - type MessageResponse struct { - Timestamp time.Time `json:"timestamp"` - Message string `json:"message"` - Id string `json:"id"` - Spec *MsgSpec `json:"spec,omitempty"` - Type MessageType `json:"type"` + var msg logs.Message + b, err := json.Marshal(msgAny) + if err != nil { + return err } - closed := false - - writeError := func(c *websocket.Conn, err error) error { - if c != nil { - return c.WriteJSON(MessageResponse{ - Type: MessageTypeError, - Message: err.Error(), - }) - } - return nil + if err := json.Unmarshal(b, &msg); err != nil { + return err } - writeInfo := func(c *websocket.Conn, msg string) error { - if c != nil { - return c.WriteJSON(MessageResponse{ - Type: MessageTypeInfo, - Message: msg, - }) - } - return nil + if msg.Id == "" { + msg.Id = "default" } - for { - if closed { - break - } - - var msg Message - if err := c.ReadJSON(&msg); err != nil { + hash := logs.LogHash(msg.Spec, ctx.Session.UserId, msg.Id) - if websocket.IsCloseError(err, websocket.CloseGoingAway) { - break - } - if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { - break - } + switch msg.Event { + case logs.EventSubscribe: + { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) + if err := d.checkAccountAccess(ctx.Context, msg.Spec.Account, ctx.Session.UserId, iamT.ReadLogs); err != nil { + return err } - continue - } - - if err := d.checkAccountAccess(ctx, msg.Data.AccountName, sess.UserId, iamT.ReadLogs); err != nil { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - continue - } - - if msg.Id == "" { - msg.Id = "default" - } - - hash := getLogHash(msg.Data, sess.UserId, msg.Id) - - switch msg.Event { - case EventSubscribe: - - if _, ok := resources[hash]; ok { - - if resources[hash].jc != nil { - err := resources[hash].jc.Stop(ctx) - if err != nil { - if err := writeError( - c, errors.Newf("already subscribed to logs for account: %s, cluster: %s, trackingId: %s", - msg.Data.AccountName, msg.Data.ClusterName, msg.Data.TrackingId, - ), - ); err != nil { - log.Warnf("websocket write: %w", err) - } - // todo: reverify - continue + if _, ok := (*logsSubs)[hash]; ok { + if (*logsSubs)[hash].Jc != nil { + if err := (*logsSubs)[hash].Jc.Stop(ctx.Context); err != nil { + return err } } - } - jc, err := d.newJetstreamConsumerForLog(ctx, d.getLogSubsId(msg.Data), hash, msg.Data.Since) + jc, err := d.newJetstreamConsumerForLog(ctx.Context, logs.LogSubsId(msg.Spec, d.env.LogsStreamName), hash, msg.Spec.Since) if err != nil { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - continue + return err + } + + if (*logsSubs) == nil { + *logsSubs = make(logs.LogsSubsMap) } - resources[hash] = &Subscription{ - resource: msg.Data, - jc: jc, - open: true, + (*logsSubs)[hash] = logs.LogsSubs{ + Jc: jc, Id: msg.Id, + Resource: msg.Spec, } go func() { - if err := writeInfo(c, "subscribed to logs"); err != nil { - log.Warnf("websocket write: %w", err) - } + + utils.WriteInfo(ctx, "subscribed to logs", msg.Id, types.ForLogs) if err := jc.Consume( - func(m *types.ConsumeMsg) error { - if c != nil { - var resp MessageResponse - if err := json.Unmarshal(m.Payload, &resp); err != nil { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } + func(m *msg_types.ConsumeMsg) error { + if ctx.Connection != nil { + var data logs.Response + var resp types.Response[logs.Response] + if err := json.Unmarshal(m.Payload, &data); err != nil { + return err } - resp.Type = MessageTypeLog + + resp.Type = types.MessageTypeResponse resp.Id = msg.Id sp := strings.Split(m.Subject, ".") - resp.Spec = &MsgSpec{ - PodName: sp[len(sp)-2], - ContainerName: sp[len(sp)-1], - } - if err := c.WriteJSON(resp); err != nil { - log.Warnf("websocket write: %w", err) + + data.PodName = sp[len(sp)-2] + data.ContainerName = sp[len(sp)-1] + + resp.Data = data + resp.For = types.ForLogs + + ctx.Mutex.Lock() + if ctx.Connection != nil { + if err := ctx.Connection.WriteJSON(resp); err != nil { + log.Warnf("websocket write: %w", err) + } } + ctx.Mutex.Unlock() } return nil }, - types.ConsumeOpts{ + msg_types.ConsumeOpts{ OnError: func(err error) error { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - + utils.WriteError(ctx, err, msg.Id, types.ForLogs) return err }, }, ); err != nil { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } + utils.WriteError(ctx, err, msg.Id, types.ForLogs) } }() - case "unsubscribe": - if _, ok := resources[hash]; !ok { - if err := writeError( - c, errors.Newf("not subscribed to logs for account: %s, cluster: %s, trackingId: %s", - msg.Data.AccountName, msg.Data.ClusterName, msg.Data.TrackingId, - ), - ); err != nil { - log.Warnf("websocket write: %w", err) - } + } - continue - } + case logs.EventUnsubscribe: + { - if resources[hash].jc != nil { - if err := resources[hash].jc.Stop(ctx); err != nil { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) + ctx.Mutex.Lock() + if res, ok := (*logsSubs)[hash]; ok { + if res.Jc != nil { + if err := res.Jc.Stop(ctx.Context); err != nil { + return err } - continue - } - delete(resources, hash) - if err := writeInfo(c, "unsubscribed from logs"); err != nil { - log.Warnf("websocket write: %w", err) + delete(*logsSubs, hash) } + ctx.Mutex.Unlock() + utils.WriteInfo(ctx, "[logs] subscription cancelled for ", msg.Id, types.ForLogs) + } else { + ctx.Mutex.Unlock() + utils.WriteError(ctx, fmt.Errorf("[logs] no subscription found for account: %s, cluster: %s, trackingId: %s", + msg.Spec.Account, msg.Spec.Cluster, msg.Spec.TrackingId), msg.Id, types.ForLogs) } - default: - if err := writeError( - c, errors.Newf("invalid event: %s, available events: subscribe, unsubscribe", msg.Event), - ); err != nil { - log.Warnf("websocket write: %w", err) - } } - + default: + return fmt.Errorf("invalid event: %s", msg.Event) } return nil diff --git a/apps/websocket-server/internal/domain/logs/main.go b/apps/websocket-server/internal/domain/logs/main.go new file mode 100644 index 000000000..9cd6d91ed --- /dev/null +++ b/apps/websocket-server/internal/domain/logs/main.go @@ -0,0 +1,104 @@ +package logs + +import ( + "crypto/md5" + "fmt" + "strconv" + "time" + + "github.com/kloudlite/api/pkg/errors" + msg_nats "github.com/kloudlite/api/pkg/messaging/nats" + "github.com/kloudlite/api/pkg/repos" +) + +type Event string + +const ( + EventSubscribe Event = "subscribe" + EventUnsubscribe Event = "unsubscribe" +) + +type MsgData struct { + Account string `json:"account"` + Cluster string `json:"cluster"` + TrackingId string `json:"trackingId"` + Since *string `json:"since,omitempty"` +} + +type Message struct { + Event Event `json:"event"` + Spec MsgData `json:"spec"` + Id string `json:"id"` +} + +type Response struct { + Message string `json:"message"` + Timestamp time.Time `json:"timestamp"` + PodName string `json:"podName"` + ContainerName string `json:"containerName"` +} + +type LogsSubsMap map[string]LogsSubs +type LogsSubs struct { + Jc *msg_nats.JetstreamConsumer + Id string + Resource MsgData +} + +func LogHash(md MsgData, userId repos.ID, sid string) string { + return fmt.Sprintf("%x", md5.Sum([]byte(fmt.Sprintf("%s-%s-%s-%s", md.Account, md.Cluster, md.TrackingId, userId)))) +} + +func parseTime(since string) (time.Time, error) { + now := time.Now() + + // Split the string into the numeric and duration type parts + length := len(since) + if length < 2 { + return now, fmt.Errorf("invalid expiration format") + } + + durationValStr := since[:length-1] + durationVal, err := strconv.Atoi(durationValStr) + if err != nil { + return now, fmt.Errorf("invalid duration value: %v", err) + } + + durationType := since[length-1] + + switch durationType { + case 'm': + return now.Add(-time.Duration(durationVal) * time.Minute), nil + case 'h': + return now.Add(-time.Duration(durationVal) * time.Hour), nil + case 'd': + return now.AddDate(0, 0, -durationVal), nil + case 'w': + return now.AddDate(0, 0, -durationVal*7), nil + case 'M': + return now.AddDate(0, -durationVal, 0), nil + default: + return now, fmt.Errorf("invalid duration type: %v, available types: m, h, d, w, M", durationType) + } +} + +func ParseSince(since *string) (*time.Time, error) { + if since == nil { + return nil, nil + } + + if *since == "" { + return nil, nil + } + + t, err := parseTime(*since) + if err != nil { + return nil, errors.NewE(err) + } + + return &t, nil +} + +func LogSubsId(md MsgData, logStreamName string) string { + return fmt.Sprintf("%s.%s.%s.%s.>", logStreamName, md.Account, md.Cluster, md.TrackingId) +} diff --git a/apps/websocket-server/internal/domain/main.go b/apps/websocket-server/internal/domain/main.go new file mode 100644 index 000000000..e3ab586b5 --- /dev/null +++ b/apps/websocket-server/internal/domain/main.go @@ -0,0 +1,115 @@ +package domain + +import ( + "context" + "fmt" + "sync" + + "github.com/gofiber/websocket/v2" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/logs" + res_watch "github.com/kloudlite/api/apps/websocket-server/internal/domain/resource_watch" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/types" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/utils" + "github.com/kloudlite/api/common" + "github.com/kloudlite/api/pkg/errors" + httpServer "github.com/kloudlite/api/pkg/http-server" +) + +func (d *domain) HandleWebSocket(ctx context.Context, c *websocket.Conn) error { + sess := httpServer.GetSession[*common.AuthSession](ctx) + if sess == nil { + return errors.NewE(fmt.Errorf("session not found")) + } + + mu := sync.Mutex{} + + var logsSubs = &logs.LogsSubsMap{} + var rWatchSubs = &res_watch.ResWatchSubsMap{} + + defer func() { + if err := c.Close(); err != nil { + d.logger.Warnf("websocket close: %w", err) + } + + if logsSubs != nil { + for _, v := range *logsSubs { + if v.Jc != nil { + if err := v.Jc.Stop(ctx); err != nil { + d.logger.Warnf("stop jetstream consumer: %w", err) + } + } + } + } + + if rWatchSubs != nil { + for _, v := range *rWatchSubs { + if v.Sub != nil { + if err := v.Sub.Unsubscribe(); err != nil { + d.logger.Warnf("unsubscribe: %w", err) + } + } + } + } + + }() + + closed := false + c.SetCloseHandler(func(_ int, _ string) error { + closed = true + return nil + }) + + sc := types.Context{ + Context: ctx, + Session: sess, + Connection: c, + Mutex: &mu, + } + + for { + + if closed { + break + } + + var msg types.Message + if err := c.ReadJSON(&msg); err != nil { + if websocket.IsCloseError(err, websocket.CloseGoingAway) { + break + } + if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { + break + } + + utils.WriteError(sc, err, "", "") + continue + } + + switch msg.For { + case types.ForLogs: + if err := d.handleLogsMsg(types.Context{ + Context: ctx, + Session: sess, + Connection: c, + Mutex: &mu, + }, logsSubs, msg.Data); err != nil { + utils.WriteError(sc, err, "", types.ForLogs) + } + + case types.ForResourceUpdate: + if err := d.handleResWatchMsg(types.Context{ + Context: ctx, + Session: sess, + Mutex: &mu, + }, rWatchSubs, msg.Data); err != nil { + utils.WriteError(sc, err, "", types.ForResourceUpdate) + } + + default: + utils.WriteError(sc, fmt.Errorf("invalid for: %s", msg.For), "", "") + } + + } + + return nil +} diff --git a/apps/websocket-server/internal/domain/resource-update.go b/apps/websocket-server/internal/domain/resource-update.go index a1cbcac20..77d10b635 100644 --- a/apps/websocket-server/internal/domain/resource-update.go +++ b/apps/websocket-server/internal/domain/resource-update.go @@ -2,30 +2,19 @@ package domain import ( "context" + "encoding/json" "fmt" - "strings" - "github.com/gofiber/websocket/v2" iamT "github.com/kloudlite/api/apps/iam/types" - "github.com/kloudlite/api/common" + res_watch "github.com/kloudlite/api/apps/websocket-server/internal/domain/resource_watch" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/types" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/utils" "github.com/kloudlite/api/grpc-interfaces/kloudlite.io/rpc/iam" - "github.com/kloudlite/api/pkg/errors" - httpServer "github.com/kloudlite/api/pkg/http-server" "github.com/kloudlite/api/pkg/repos" mnats "github.com/nats-io/nats.go" ) -type RUpdateReqData struct { - AccountName string `json:"account"` - ProjectName string `json:"project"` - - // ResourceName string `json:"resource"` - // ResourceType string `json:"resource_type"` - Topic string `json:"topic"` - ReqTopic string `json:"req_topic"` -} - -func (d *domain) checkAccess(ctx context.Context, rdata *RUpdateReqData, userId repos.ID) error { +func (d *domain) checkAccess(ctx context.Context, rdata *res_watch.ReqData, userId repos.ID) error { co, err := d.iamClient.Can(ctx, &iam.CanIn{ UserId: string(userId), ResourceRefs: func() []string { @@ -60,230 +49,65 @@ func (d *domain) checkAccess(ctx context.Context, rdata *RUpdateReqData, userId return nil } -func (d *domain) parseRUpdateReq(rt string) (*RUpdateReqData, error) { - - entriesStrs := strings.Split(rt, ".") - - rdata := &RUpdateReqData{} - - nTopics := "res-updates" - - for _, entryStr := range entriesStrs { - entry := strings.Split(entryStr, ":") - - if len(entry) != 2 { - nTopics += fmt.Sprintf(".%s.*", entry[0]) - } else { - nTopics += fmt.Sprintf(".%s.%s", entry[0], entry[1]) - } - - if (entry[0] == "account" || entry[0] == "project") && len(entry) == 2 { - if entry[0] == "account" { - rdata.AccountName = entry[1] - } - if entry[0] == "project" { - rdata.ProjectName = entry[1] - } - } - - } - - rdata.Topic = nTopics - rdata.ReqTopic = rt - if rdata.AccountName == "" { - return nil, fmt.Errorf("invalid topic %s", rt) - } - - return rdata, nil -} - -func (d *domain) HandleWebSocketForRUpdate(ctx context.Context, c *websocket.Conn) error { - - sess := httpServer.GetSession[*common.AuthSession](ctx) - if sess == nil { - return errors.NewE(fmt.Errorf("session not found")) - } - - defer func() { - if err := c.Close(); err != nil { - d.logger.Warnf("websocket close: %w", err) - } - }() - log := d.logger - - type Subscription struct { - resource RUpdateReqData - sub *mnats.Subscription - open bool +func (d *domain) handleResWatchMsg(ctx types.Context, resources *res_watch.ResWatchSubsMap, msgAny map[string]any) error { + var msg res_watch.Message + b, err := json.Marshal(msgAny) + if err != nil { + return err } - resources := make(map[string]*Subscription) - - type Message struct { - Event string `json:"event"` - Data string `json:"data"` + if err := json.Unmarshal(b, &msg); err != nil { + return err } - // "account:accid.cluster:clusterid.nodepool:nodepoolid" - - type MessageType string - - const ( - MessageTypeError MessageType = "error" - MessageTypeUpdate MessageType = "update" - MessageTypeInfo MessageType = "info" - ) - - type MessageResponse struct { - Topic string `json:"topic"` - Message string `json:"message"` - Type MessageType `json:"type"` + if msg.Id == "" { + msg.Id = "default" } - closed := false - - c.SetCloseHandler(func(_ int, _ string) error { - closed = true - return nil - }) - - defer func() { - for _, r := range resources { - if err := r.sub.Unsubscribe(); err != nil { - log.Warnf("websocket unsubscribe: %w", err) - } - } - }() + rd, err := res_watch.ParseReq(msg.ResPath) - writeError := func(c *websocket.Conn, err error) error { - if c != nil { - return c.WriteJSON(MessageResponse{ - Type: MessageTypeError, - Message: err.Error(), - }) - } - return nil + if err != nil { + return err } - writeInfo := func(c *websocket.Conn, msg string) error { - if c != nil { - return c.WriteJSON(MessageResponse{ - Type: MessageTypeInfo, - Message: msg, - }) - } - return nil + if err := d.checkAccess(ctx.Context, rd, ctx.Session.UserId); err != nil { + return err } - // Keep the connection open - for { - - if closed { - break - } - - var message Message - if err := c.ReadJSON(&message); err != nil { - - if websocket.IsCloseError(err, websocket.CloseGoingAway) { - break - } - if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { - break - } - - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - - continue - } - - rd, err := d.parseRUpdateReq(message.Data) - if err != nil { - - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - - continue - } - - if err := d.checkAccess(ctx, rd, sess.UserId); err != nil { - - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - - continue - } - - switch message.Event { - case "subscribe": - if _, ok := resources[message.Data]; ok { - if err := writeError(c, fmt.Errorf("resource already subscribed")); err != nil { - log.Warnf("websocket write: %w", err) - } + switch msg.Event { + case res_watch.EventSubscribe: + { + if _, ok := (*resources)[rd.Topic]; ok { + return fmt.Errorf("resource already subscribed") } - sub, err := d.natsClient.Conn.Subscribe(rd.Topic, func(_ *mnats.Msg) { - - rmessage := MessageResponse{ - Topic: rd.ReqTopic, - Message: resources[rd.Topic].resource.ReqTopic, - Type: MessageTypeUpdate, - } - - if c != nil && resources[rd.Topic] != nil && resources[rd.Topic].open { - if err := c.WriteJSON(rmessage); err != nil { - log.Warnf("websocket write: %w", err) - } - } - - }) - + s, err := d.natsClient.Conn.Subscribe(rd.Topic, func(m *mnats.Msg) {}) if err != nil { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - - continue - } - - if err := writeInfo(c, fmt.Sprintf("subscribed to %s", rd.Topic)); err != nil { - log.Warnf("websocket write: %w", err) - } - - resources[rd.Topic] = &Subscription{ - resource: *rd, - sub: sub, - open: true, + return err } - case "unsubscribe": - if _, ok := resources[message.Data]; !ok { - if err := writeError(c, fmt.Errorf("resource not subscribed")); err != nil { - log.Warnf("websocket write: %w", err) - } - - continue + (*resources)[rd.Topic] = res_watch.ResWatchSubs{ + Sub: s, + Resource: *rd, } - if resources[rd.Topic].sub != nil { - if err := resources[rd.Topic].sub.Unsubscribe(); err != nil { - if err := writeError(c, err); err != nil { - log.Warnf("websocket write: %w", err) - } - break + utils.WriteInfo(ctx, fmt.Sprintf("subscribed to %s", rd.Topic), msg.Id, types.ForResourceUpdate) + } + case res_watch.EventUnsubscribe: + { + if s, ok := (*resources)[rd.Topic]; ok { + if err := s.Sub.Unsubscribe(); err != nil { + return err } - delete(resources, message.Data) + delete(*resources, rd.Topic) + utils.WriteInfo(ctx, fmt.Sprintf("unsubscribed from %s", rd.Topic), msg.Id, types.ForResourceUpdate) } - default: - log.Errorf(fmt.Errorf("websocket read: invalid event %s", message.Event)) + utils.WriteError(ctx, fmt.Errorf("resource not found"), msg.Id, types.ForResourceUpdate) } - + default: + return fmt.Errorf("invalid event: %s", msg.Event) } return nil diff --git a/apps/websocket-server/internal/domain/resource_watch/main.go b/apps/websocket-server/internal/domain/resource_watch/main.go new file mode 100644 index 000000000..a8ebb202f --- /dev/null +++ b/apps/websocket-server/internal/domain/resource_watch/main.go @@ -0,0 +1,73 @@ +package res_watch + +import ( + "fmt" + mnats "github.com/nats-io/nats.go" + "strings" +) + +type ResWatchSubsMap map[string]ResWatchSubs +type ResWatchSubs struct { + Sub *mnats.Subscription + Resource ReqData +} + +type Event string + +const ( + EventSubscribe Event = "subscribe" + EventUnsubscribe Event = "unsubscribe" +) + +type Message struct { + ResPath string + Id string + Event Event +} + +type ReqData struct { + AccountName string `json:"account"` + ProjectName string `json:"project"` + + // ResourceName string `json:"resource"` + // ResourceType string `json:"resource_type"` + Topic string `json:"topic"` + ReqTopic string `json:"req_topic"` +} + +func ParseReq(rt string) (*ReqData, error) { + + entriesStrs := strings.Split(rt, ".") + + rdata := &ReqData{} + + nTopics := "res-updates" + + for _, entryStr := range entriesStrs { + entry := strings.Split(entryStr, ":") + + if len(entry) != 2 { + nTopics += fmt.Sprintf(".%s.*", entry[0]) + } else { + nTopics += fmt.Sprintf(".%s.%s", entry[0], entry[1]) + } + + if (entry[0] == "account" || entry[0] == "project") && len(entry) == 2 { + if entry[0] == "account" { + rdata.AccountName = entry[1] + } + if entry[0] == "project" { + rdata.ProjectName = entry[1] + } + } + + } + + rdata.Topic = nTopics + rdata.ReqTopic = rt + if rdata.AccountName == "" { + return nil, fmt.Errorf("invalid topic %s", rt) + } + + return rdata, nil +} diff --git a/apps/websocket-server/internal/domain/samp/main.go b/apps/websocket-server/internal/domain/samp/main.go new file mode 100644 index 000000000..59db586b4 --- /dev/null +++ b/apps/websocket-server/internal/domain/samp/main.go @@ -0,0 +1,39 @@ +package main + +import ( + "fmt" + "sync" + "time" +) + +// SafeCounter is safe to use concurrently. +type SafeCounter struct { + mu sync.Mutex + v map[string]int +} + +// Inc increments the counter for the given key. +func (c *SafeCounter) Inc(key string) { + c.mu.Lock() + // Lock so only one goroutine at a time can access the map c.v. + c.v[key]++ + c.mu.Unlock() +} + +// Value returns the current value of the counter for the given key. +func (c *SafeCounter) Value(key string) int { + c.mu.Lock() + // Lock so only one goroutine at a time can access the map c.v. + defer c.mu.Unlock() + return c.v[key] +} + +func main() { + c := SafeCounter{v: make(map[string]int)} + for i := 0; i < 1000; i++ { + go c.Inc("somekey") + } + + time.Sleep(time.Second) + fmt.Println(c.Value("somekey")) +} diff --git a/apps/websocket-server/internal/domain/types/main.go b/apps/websocket-server/internal/domain/types/main.go new file mode 100644 index 000000000..ee17080d2 --- /dev/null +++ b/apps/websocket-server/internal/domain/types/main.go @@ -0,0 +1,45 @@ +package types + +import ( + "context" + "sync" + + "github.com/gofiber/websocket/v2" + "github.com/kloudlite/api/common" +) + +type For string + +const ( + ForLogs For = "logs" + ForResourceUpdate For = "resource-update" +) + +type MessageType string + +const ( + MessageTypeError MessageType = "error" + MessageTypeResponse MessageType = "response" + MessageTypeInfo MessageType = "info" +) + +type Response[T any] struct { + Type MessageType `json:"type"` + For For `json:"for"` + + Data T `json:"data"` + Message string `json:"message"` + Id string `json:"id"` +} + +type Message struct { + For For `json:"for"` + Data map[string]any `json:"data"` +} + +type Context struct { + Context context.Context + Session *common.AuthSession + Connection *websocket.Conn + Mutex *sync.Mutex +} diff --git a/apps/websocket-server/internal/domain/utils/main.go b/apps/websocket-server/internal/domain/utils/main.go new file mode 100644 index 000000000..b324174b1 --- /dev/null +++ b/apps/websocket-server/internal/domain/utils/main.go @@ -0,0 +1,38 @@ +package utils + +import ( + "github.com/gofiber/fiber/v2/log" + "github.com/kloudlite/api/apps/websocket-server/internal/domain/types" +) + +func WriteError(ctx types.Context, err error, id string, For types.For) { + if ctx.Context != nil { + ctx.Mutex.Lock() + if err := ctx.Connection.WriteJSON(types.Response[any]{ + Type: types.MessageTypeError, + Message: err.Error(), + For: For, + Id: id, + }); err != nil { + log.Warnf("websocket write: %w", err) + } + ctx.Mutex.Unlock() + } +} + +func WriteInfo(ctx types.Context, msg string, id string, For types.For) { + if ctx.Context != nil { + ctx.Mutex.Lock() + if err := ctx.Connection.WriteJSON(types.Response[any]{ + Type: types.MessageTypeInfo, + Message: msg, + Id: id, + For: For, + }); err != nil { + log.Warnf("websocket write: %w", err) + } + ctx.Mutex.Unlock() + } else { + log.Warnf("websocket connection is nil") + } +} diff --git a/pkg/messaging/nats/jetstream-consumer.go b/pkg/messaging/nats/jetstream-consumer.go index 81f1f79aa..5ddab7a32 100644 --- a/pkg/messaging/nats/jetstream-consumer.go +++ b/pkg/messaging/nats/jetstream-consumer.go @@ -86,7 +86,9 @@ func (jc *JetstreamConsumer) Consume(consumeFn func(msg *types.ConsumeMsg) error // Stop implements Consumer. func (nc *JetstreamConsumer) Stop(context.Context) error { - nc.consumeCtx.Stop() + if nc.consumeCtx != nil { + nc.consumeCtx.Stop() + } return nil }