Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ func (s *Server) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRe
if len(ids) == 0 && len(identifiers) == 0 && !hasInitiator {
return nil, status.Error(codes.InvalidArgument, "participant_ids or participants must be provided")
}
seen := make(map[uuid.UUID]struct{}, len(ids)+len(identifiers))
seen := make(map[uuid.UUID]struct{}, len(ids)+len(identifiers)+1)
if hasInitiator {
seen[initiator.ID] = struct{}{}
}
resolved := make([]store.ParticipantInput, 0, len(ids)+len(identifiers))
if len(ids) > 0 {
for i, raw := range ids {
Expand Down Expand Up @@ -205,10 +208,10 @@ func (s *Server) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRe
}

func addResolvedParticipant(id uuid.UUID, initiator initiatorInfo, hasInitiator bool, seen map[uuid.UUID]struct{}, resolved *[]store.ParticipantInput, fieldName string, index int) error {
if hasInitiator && id == initiator.ID {
return status.Errorf(codes.InvalidArgument, "%s must not include initiator", fieldName)
}
if _, ok := seen[id]; ok {
if hasInitiator && id == initiator.ID {
return nil
}
return status.Errorf(codes.InvalidArgument, "%s[%d]: duplicate participant", fieldName, index)
}
seen[id] = struct{}{}
Expand Down
78 changes: 66 additions & 12 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/agynio/threads/internal/store"
Expand Down Expand Up @@ -923,25 +924,78 @@ func TestCreateThreadNicknameRequiresOrganizationIDForUser(t *testing.T) {
}
}

func TestCreateThreadRejectsInitiatorInParticipants(t *testing.T) {
func TestCreateThreadDedupesInitiatorInParticipantIDs(t *testing.T) {
initiatorID := uuid.New()
participantID := uuid.New()

srv := New(&stubThreadStore{t: t}, nil, nil, nil, nil, nil)
assertCreateThreadDedupesInitiator(t, initiatorID, participantID, &threadsv1.CreateThreadRequest{ParticipantIds: []string{initiatorID.String(), participantID.String()}})
}

func TestCreateThreadDedupesInitiatorInParticipants(t *testing.T) {
initiatorID := uuid.New()
participantID := uuid.New()

assertCreateThreadDedupesInitiator(t, initiatorID, participantID, &threadsv1.CreateThreadRequest{
Participants: []*threadsv1.ParticipantIdentifier{
{Identifier: &threadsv1.ParticipantIdentifier_ParticipantId{ParticipantId: initiatorID.String()}},
{Identifier: &threadsv1.ParticipantIdentifier_ParticipantId{ParticipantId: participantID.String()}},
},
})
}

func assertCreateThreadDedupesInitiator(t *testing.T, initiatorID, participantID uuid.UUID, req *threadsv1.CreateThreadRequest) {
t.Helper()
threadID := uuid.New()
organizationID := uuid.New()
now := time.Now().UTC()
storeCalled := false
req.OrganizationId = proto.String(organizationID.String())

storeStub := &stubThreadStore{
t: t,
createThreadFn: func(ctx context.Context, orgID uuid.UUID, participants []store.ParticipantInput) (store.Thread, error) {
storeCalled = true
if orgID != organizationID {
t.Fatalf("expected organization %s, got %s", organizationID, orgID)
}
if len(participants) != 2 {
t.Fatalf("expected 2 participants, got %d", len(participants))
}
if participants[0].ID != initiatorID {
t.Fatalf("expected initiator %s first, got %s", initiatorID, participants[0].ID)
}
if !participants[0].Passive {
t.Fatal("expected agent initiator to be passive")
}
if participants[1].ID != participantID {
t.Fatalf("expected participant %s second, got %s", participantID, participants[1].ID)
}
return store.Thread{
ID: threadID,
OrganizationID: &organizationID,
MessageCount: 0,
Status: store.ThreadStatusActive,
CreatedAt: now,
UpdatedAt: now,
Participants: []store.Participant{
{ID: initiatorID, JoinedAt: now, Passive: true},
{ID: participantID, JoinedAt: now, Passive: false},
},
}, nil
},
}

srv := New(storeStub, nil, allowAuthStub(t), &stubIdentityResolver{t: t}, nil, nil)
ctx := metadata.NewIncomingContext(
context.Background(),
metadata.Pairs("x-identity-id", initiatorID.String(), "x-identity-type", "agent"),
metadata.Pairs("x-identity-id", initiatorID.String(), "x-identity-type", "agent", "x-organization-id", organizationID.String()),
)
_, err := srv.CreateThread(ctx, &threadsv1.CreateThreadRequest{ParticipantIds: []string{initiatorID.String(), participantID.String()}})
if err == nil {
t.Fatal("expected error")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
_, err := srv.CreateThread(ctx, req)
if err != nil {
t.Fatalf("CreateThread returned error: %v", err)
}
if st.Code() != codes.InvalidArgument {
t.Fatalf("expected InvalidArgument, got %s: %s", st.Code(), st.Message())
if !storeCalled {
t.Fatal("expected CreateThread to be called")
}
}

Expand Down
17 changes: 17 additions & 0 deletions migrations/0004_enforce_thread_delete_cascades.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
ALTER TABLE thread_participants
DROP CONSTRAINT IF EXISTS thread_participants_thread_id_fkey,
ADD CONSTRAINT thread_participants_thread_id_fkey
FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE;

ALTER TABLE messages
DROP CONSTRAINT IF EXISTS messages_thread_id_fkey,
ADD CONSTRAINT messages_thread_id_fkey
FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE;

ALTER TABLE message_recipients
DROP CONSTRAINT IF EXISTS message_recipients_message_id_fkey,
DROP CONSTRAINT IF EXISTS message_recipients_thread_id_fkey,
ADD CONSTRAINT message_recipients_message_id_fkey
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE,
ADD CONSTRAINT message_recipients_thread_id_fkey
FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE;
Loading