From 4c969d3075ea9b76133878b077b1b67910d703c6 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Mon, 25 May 2026 02:55:46 +0000 Subject: [PATCH] fix: dedupe thread initiator participants --- internal/server/server.go | 11 ++- internal/server/server_test.go | 78 ++++++++++++++++--- .../0004_enforce_thread_delete_cascades.sql | 17 ++++ 3 files changed, 90 insertions(+), 16 deletions(-) create mode 100644 migrations/0004_enforce_thread_delete_cascades.sql diff --git a/internal/server/server.go b/internal/server/server.go index 7a9c271..fd1b67a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -99,7 +99,10 @@ func (s *Server) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRe if len(ids) == 0 && len(identifiers) == 0 && !hasInitiator { return nil, status.Error(codes.InvalidArgument, "participant_ids or participants must be provided") } - seen := make(map[uuid.UUID]struct{}, len(ids)+len(identifiers)) + seen := make(map[uuid.UUID]struct{}, len(ids)+len(identifiers)+1) + if hasInitiator { + seen[initiator.ID] = struct{}{} + } resolved := make([]store.ParticipantInput, 0, len(ids)+len(identifiers)) if len(ids) > 0 { for i, raw := range ids { @@ -205,10 +208,10 @@ func (s *Server) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRe } func addResolvedParticipant(id uuid.UUID, initiator initiatorInfo, hasInitiator bool, seen map[uuid.UUID]struct{}, resolved *[]store.ParticipantInput, fieldName string, index int) error { - if hasInitiator && id == initiator.ID { - return status.Errorf(codes.InvalidArgument, "%s must not include initiator", fieldName) - } if _, ok := seen[id]; ok { + if hasInitiator && id == initiator.ID { + return nil + } return status.Errorf(codes.InvalidArgument, "%s[%d]: duplicate participant", fieldName, index) } seen[id] = struct{}{} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 3c03ef9..d836fab 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -16,6 +16,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/agynio/threads/internal/store" @@ -923,25 +924,78 @@ func TestCreateThreadNicknameRequiresOrganizationIDForUser(t *testing.T) { } } -func TestCreateThreadRejectsInitiatorInParticipants(t *testing.T) { +func TestCreateThreadDedupesInitiatorInParticipantIDs(t *testing.T) { initiatorID := uuid.New() participantID := uuid.New() - srv := New(&stubThreadStore{t: t}, nil, nil, nil, nil, nil) + assertCreateThreadDedupesInitiator(t, initiatorID, participantID, &threadsv1.CreateThreadRequest{ParticipantIds: []string{initiatorID.String(), participantID.String()}}) +} + +func TestCreateThreadDedupesInitiatorInParticipants(t *testing.T) { + initiatorID := uuid.New() + participantID := uuid.New() + + assertCreateThreadDedupesInitiator(t, initiatorID, participantID, &threadsv1.CreateThreadRequest{ + Participants: []*threadsv1.ParticipantIdentifier{ + {Identifier: &threadsv1.ParticipantIdentifier_ParticipantId{ParticipantId: initiatorID.String()}}, + {Identifier: &threadsv1.ParticipantIdentifier_ParticipantId{ParticipantId: participantID.String()}}, + }, + }) +} + +func assertCreateThreadDedupesInitiator(t *testing.T, initiatorID, participantID uuid.UUID, req *threadsv1.CreateThreadRequest) { + t.Helper() + threadID := uuid.New() + organizationID := uuid.New() + now := time.Now().UTC() + storeCalled := false + req.OrganizationId = proto.String(organizationID.String()) + + storeStub := &stubThreadStore{ + t: t, + createThreadFn: func(ctx context.Context, orgID uuid.UUID, participants []store.ParticipantInput) (store.Thread, error) { + storeCalled = true + if orgID != organizationID { + t.Fatalf("expected organization %s, got %s", organizationID, orgID) + } + if len(participants) != 2 { + t.Fatalf("expected 2 participants, got %d", len(participants)) + } + if participants[0].ID != initiatorID { + t.Fatalf("expected initiator %s first, got %s", initiatorID, participants[0].ID) + } + if !participants[0].Passive { + t.Fatal("expected agent initiator to be passive") + } + if participants[1].ID != participantID { + t.Fatalf("expected participant %s second, got %s", participantID, participants[1].ID) + } + return store.Thread{ + ID: threadID, + OrganizationID: &organizationID, + MessageCount: 0, + Status: store.ThreadStatusActive, + CreatedAt: now, + UpdatedAt: now, + Participants: []store.Participant{ + {ID: initiatorID, JoinedAt: now, Passive: true}, + {ID: participantID, JoinedAt: now, Passive: false}, + }, + }, nil + }, + } + + srv := New(storeStub, nil, allowAuthStub(t), &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext( context.Background(), - metadata.Pairs("x-identity-id", initiatorID.String(), "x-identity-type", "agent"), + metadata.Pairs("x-identity-id", initiatorID.String(), "x-identity-type", "agent", "x-organization-id", organizationID.String()), ) - _, err := srv.CreateThread(ctx, &threadsv1.CreateThreadRequest{ParticipantIds: []string{initiatorID.String(), participantID.String()}}) - if err == nil { - t.Fatal("expected error") - } - st, ok := status.FromError(err) - if !ok { - t.Fatalf("expected gRPC status error, got %v", err) + _, err := srv.CreateThread(ctx, req) + if err != nil { + t.Fatalf("CreateThread returned error: %v", err) } - if st.Code() != codes.InvalidArgument { - t.Fatalf("expected InvalidArgument, got %s: %s", st.Code(), st.Message()) + if !storeCalled { + t.Fatal("expected CreateThread to be called") } } diff --git a/migrations/0004_enforce_thread_delete_cascades.sql b/migrations/0004_enforce_thread_delete_cascades.sql new file mode 100644 index 0000000..84fdbfa --- /dev/null +++ b/migrations/0004_enforce_thread_delete_cascades.sql @@ -0,0 +1,17 @@ +ALTER TABLE thread_participants + DROP CONSTRAINT IF EXISTS thread_participants_thread_id_fkey, + ADD CONSTRAINT thread_participants_thread_id_fkey + FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE; + +ALTER TABLE messages + DROP CONSTRAINT IF EXISTS messages_thread_id_fkey, + ADD CONSTRAINT messages_thread_id_fkey + FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE; + +ALTER TABLE message_recipients + DROP CONSTRAINT IF EXISTS message_recipients_message_id_fkey, + DROP CONSTRAINT IF EXISTS message_recipients_thread_id_fkey, + ADD CONSTRAINT message_recipients_message_id_fkey + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, + ADD CONSTRAINT message_recipients_thread_id_fkey + FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE;