diff --git a/internal/server/server.go b/internal/server/server.go index 230b84a..4d382b5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -61,6 +61,7 @@ type threadStore interface { type identityResolver interface { ResolveNickname(ctx context.Context, in *identityv1.ResolveNicknameRequest, opts ...grpc.CallOption) (*identityv1.ResolveNicknameResponse, error) BatchGetNicknames(ctx context.Context, in *identityv1.BatchGetNicknamesRequest, opts ...grpc.CallOption) (*identityv1.BatchGetNicknamesResponse, error) + BatchGetIdentityTypes(ctx context.Context, in *identityv1.BatchGetIdentityTypesRequest, opts ...grpc.CallOption) (*identityv1.BatchGetIdentityTypesResponse, error) } type agentsService interface { @@ -171,6 +172,9 @@ func (s *Server) CreateThread(ctx context.Context, req *threadsv1.CreateThreadRe participants = append(participants, store.ParticipantInput{ID: initiator.ID, Passive: initiator.Passive}) } participants = append(participants, resolved...) + if err := s.requireCanInitiateAgentParticipants(ctx, identityID, participants); err != nil { + return nil, err + } thread, err := s.store.CreateThread(ctx, organizationID, participants) if err != nil { @@ -212,6 +216,60 @@ func addResolvedParticipant(id uuid.UUID, initiator initiatorInfo, hasInitiator return nil } +func (s *Server) requireCanInitiateAgentParticipants(ctx context.Context, callerID uuid.UUID, participants []store.ParticipantInput) error { + agentIDs, err := s.agentParticipantIDs(ctx, participants) + if err != nil { + return err + } + for _, agentID := range agentIDs { + if err := s.requireAllowed(ctx, callerID, "can_initiate", fmt.Sprintf("agent:%s", agentID.String())); err != nil { + return err + } + } + return nil +} + +func (s *Server) agentParticipantIDs(ctx context.Context, participants []store.ParticipantInput) ([]uuid.UUID, error) { + if len(participants) == 0 { + return nil, nil + } + if s.identity == nil { + return nil, status.Error(codes.Internal, "identity service not configured") + } + identityIDs := make([]string, len(participants)) + participantIDs := make(map[uuid.UUID]struct{}, len(participants)) + for i, participant := range participants { + identityIDs[i] = participant.ID.String() + participantIDs[participant.ID] = struct{}{} + } + identityCtx, err := identityClientContext(ctx) + if err != nil { + return nil, err + } + response, err := s.identity.BatchGetIdentityTypes(identityCtx, &identityv1.BatchGetIdentityTypesRequest{IdentityIds: identityIDs}) + if err != nil { + return nil, status.Errorf(codes.Internal, "batch get identity types: %v", err) + } + agentIDs := make([]uuid.UUID, 0) + for i, entry := range response.GetEntries() { + if entry == nil { + return nil, status.Errorf(codes.Internal, "identity type entry[%d]: missing", i) + } + if entry.GetIdentityType() != identityv1.IdentityType_IDENTITY_TYPE_AGENT { + continue + } + identityID, err := parseUUID(entry.GetIdentityId()) + if err != nil { + return nil, status.Errorf(codes.Internal, "identity type entry[%d].identity_id: %v", i, err) + } + if _, ok := participantIDs[identityID]; !ok { + return nil, status.Errorf(codes.Internal, "identity type entry[%d].identity_id: unexpected identity", i) + } + agentIDs = append(agentIDs, identityID) + } + return agentIDs, nil +} + func (s *Server) ArchiveThread(ctx context.Context, req *threadsv1.ArchiveThreadRequest) (*threadsv1.ArchiveThreadResponse, error) { threadID, err := parseUUID(req.GetThreadId()) if err != nil { @@ -283,6 +341,9 @@ func (s *Server) AddParticipant(ctx context.Context, req *threadsv1.AddParticipa if err != nil { return nil, err } + if err := s.requireCanInitiateAgentParticipants(ctx, identityID, []store.ParticipantInput{{ID: participantID}}); err != nil { + return nil, err + } thread, err := s.store.AddParticipant(ctx, threadID, participantID, req.GetPassive()) if err != nil { return nil, toStatusError(err) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 086bf24..b723857 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -135,9 +135,10 @@ func (s *stubThreadStore) AckMessages(ctx context.Context, participantID uuid.UU } type stubIdentityResolver struct { - t *testing.T - resolveFn func(ctx context.Context, req *identityv1.ResolveNicknameRequest, opts ...grpc.CallOption) (*identityv1.ResolveNicknameResponse, error) - batchFn func(ctx context.Context, req *identityv1.BatchGetNicknamesRequest, opts ...grpc.CallOption) (*identityv1.BatchGetNicknamesResponse, error) + t *testing.T + resolveFn func(ctx context.Context, req *identityv1.ResolveNicknameRequest, opts ...grpc.CallOption) (*identityv1.ResolveNicknameResponse, error) + batchFn func(ctx context.Context, req *identityv1.BatchGetNicknamesRequest, opts ...grpc.CallOption) (*identityv1.BatchGetNicknamesResponse, error) + typeBatchFn func(ctx context.Context, req *identityv1.BatchGetIdentityTypesRequest, opts ...grpc.CallOption) (*identityv1.BatchGetIdentityTypesResponse, error) } func (s *stubIdentityResolver) ResolveNickname(ctx context.Context, req *identityv1.ResolveNicknameRequest, opts ...grpc.CallOption) (*identityv1.ResolveNicknameResponse, error) { @@ -156,6 +157,18 @@ func (s *stubIdentityResolver) BatchGetNicknames(ctx context.Context, req *ident return s.batchFn(ctx, req, opts...) } +func (s *stubIdentityResolver) BatchGetIdentityTypes(ctx context.Context, req *identityv1.BatchGetIdentityTypesRequest, opts ...grpc.CallOption) (*identityv1.BatchGetIdentityTypesResponse, error) { + s.t.Helper() + if s.typeBatchFn == nil { + entries := make([]*identityv1.IdentityTypeEntry, len(req.GetIdentityIds())) + for i, identityID := range req.GetIdentityIds() { + entries[i] = &identityv1.IdentityTypeEntry{IdentityId: identityID, IdentityType: identityv1.IdentityType_IDENTITY_TYPE_USER} + } + return &identityv1.BatchGetIdentityTypesResponse{Entries: entries}, nil + } + return s.typeBatchFn(ctx, req, opts...) +} + type stubAgentsService struct { t *testing.T getAgentFn func(ctx context.Context, req *agentsv1.GetAgentRequest, opts ...grpc.CallOption) (*agentsv1.GetAgentResponse, error) @@ -201,6 +214,30 @@ func allowAuthStub(t *testing.T) *stubAuthorizationService { } } +func TestCreateThreadParticipantIDsRequireIdentityResolver(t *testing.T) { + organizationID := uuid.New() + identityID := uuid.New() + participantID := uuid.New() + storeStub := &stubThreadStore{t: t} + authStub := allowAuthStub(t) + srv := New(storeStub, nil, authStub, nil, nil, nil) + ctx := metadata.NewIncomingContext( + context.Background(), + metadata.Pairs("x-identity-id", identityID.String(), "x-identity-type", "user", "x-organization-id", organizationID.String()), + ) + _, err := srv.CreateThread(ctx, &threadsv1.CreateThreadRequest{ParticipantIds: []string{participantID.String()}}) + if err == nil { + t.Fatal("expected error") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.Internal { + t.Fatalf("expected Internal, got %s: %s", st.Code(), st.Message()) + } +} + func TestCreateThreadAgentInitiatorPassive(t *testing.T) { threadID := uuid.New() organizationID := uuid.New() @@ -247,7 +284,7 @@ func TestCreateThreadAgentInitiatorPassive(t *testing.T) { } authStub := allowAuthStub(t) - srv := New(storeStub, nil, authStub, nil, nil, nil) + srv := New(storeStub, nil, authStub, &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext( context.Background(), metadata.Pairs("x-identity-id", agentID.String(), "x-identity-type", "agent", "x-organization-id", organizationID.String()), @@ -310,7 +347,7 @@ func TestCreateThreadEmptyParticipantsWithAgentInitiator(t *testing.T) { } authStub := allowAuthStub(t) - srv := New(storeStub, nil, authStub, nil, nil, nil) + srv := New(storeStub, nil, authStub, &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext( context.Background(), metadata.Pairs("x-identity-id", agentID.String(), "x-identity-type", "agent", "x-organization-id", organizationID.String()), @@ -377,7 +414,7 @@ func TestCreateThreadUserInitiatorActive(t *testing.T) { } authStub := allowAuthStub(t) - srv := New(storeStub, nil, authStub, nil, nil, nil) + srv := New(storeStub, nil, authStub, &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext( context.Background(), metadata.Pairs("x-identity-id", userID.String(), "x-identity-type", "user", "x-organization-id", organizationID.String()), @@ -932,7 +969,7 @@ func TestCreateThreadWritesAuthorizationTuples(t *testing.T) { }, } - srv := New(storeStub, nil, authStub, nil, nil, nil) + srv := New(storeStub, nil, authStub, &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext( context.Background(), metadata.Pairs("x-identity-id", identityID.String(), "x-identity-type", "user", "x-organization-id", organizationID.String()), @@ -1058,7 +1095,7 @@ func TestAddParticipantWithParticipantIDOneof(t *testing.T) { identityID := uuid.New() authStub := allowAuthStub(t) - srv := New(storeStub, nil, authStub, nil, nil, nil) + srv := New(storeStub, nil, authStub, &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String())) _, err := srv.AddParticipant(ctx, &threadsv1.AddParticipantRequest{ ThreadId: threadID.String(), @@ -1101,7 +1138,7 @@ func TestAddParticipantWithLegacyParticipantID(t *testing.T) { identityID := uuid.New() authStub := allowAuthStub(t) - srv := New(storeStub, nil, authStub, nil, nil, nil) + srv := New(storeStub, nil, authStub, &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String())) _, err := srv.AddParticipant(ctx, &threadsv1.AddParticipantRequest{ ThreadId: threadID.String(), @@ -1227,7 +1264,7 @@ func TestAddParticipantWritesAuthorizationTuple(t *testing.T) { }, } - srv := New(storeStub, nil, authStub, nil, nil, nil) + srv := New(storeStub, nil, authStub, &stubIdentityResolver{t: t}, nil, nil) ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String())) _, err := srv.AddParticipant(ctx, &threadsv1.AddParticipantRequest{ThreadId: threadID.String(), ParticipantId: participantID.String()}) if err != nil {