diff --git a/README.md b/README.md index 833bd8dc..f6e7fcf9 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ A terminal UI for AWS resource management - **Cross-resource navigation** - Jump from VPC to subnets, Lambda to CloudWatch - **Filtering & sorting** - Fuzzy search, tag filtering, column sorting - **Resource comparison** - Side-by-side diff view +- **AI Chat** - AI assistant with AWS context (via Bedrock) - **6 color themes** - dark, light, nord, dracula, gruvbox, catppuccin ## Screenshots @@ -31,6 +32,12 @@ A terminal UI for AWS resource management ![multi-region](docs/images/multi-account-region.png) +### AI Chat (Bedrock) + +![ai-chat](docs/images/ai-chat.png) + +Press `A` in list/detail/diff views to open AI chat. The assistant analyzes resources, compares configurations, and identifies risks using AWS Bedrock. + ## Installation ### Homebrew (macOS/Linux) @@ -89,6 +96,7 @@ claws --read-only | `:` | Command mode (e.g., `:ec2/instances`) | | `/` | Filter mode (fuzzy search) | | `a` | Open actions menu | +| `A` | AI Chat (in list/detail/diff views) | | `R` | Select region(s) | | `P` | Select profile(s) | | `?` | Show help | @@ -104,6 +112,7 @@ See [docs/keybindings.md](docs/keybindings.md) for complete reference. | [Supported Services](docs/services.md) | All 69 services and 163 resources | | [Configuration](docs/configuration.md) | Config file, themes, and options | | [IAM Permissions](docs/iam-permissions.md) | Required AWS permissions | +| [AI Chat](docs/ai-chat.md) | AI assistant usage and features | | [Architecture](docs/architecture.md) | Internal design and structure | | [Adding Resources](docs/adding-resources.md) | Guide for contributors | diff --git a/cmd/claws/imports_custom.go b/cmd/claws/imports_custom.go index 48191f0f..c716c173 100644 --- a/cmd/claws/imports_custom.go +++ b/cmd/claws/imports_custom.go @@ -150,6 +150,7 @@ import ( // ECS _ "github.com/clawscli/claws/custom/ecs/clusters" _ "github.com/clawscli/claws/custom/ecs/services" + _ "github.com/clawscli/claws/custom/ecs/task-definitions" _ "github.com/clawscli/claws/custom/ecs/tasks" // ElastiCache diff --git a/custom/accessanalyzer/analyzers/dao.go b/custom/accessanalyzer/analyzers/dao.go index bf99861f..00e2257d 100644 --- a/custom/accessanalyzer/analyzers/dao.go +++ b/custom/accessanalyzer/analyzers/dao.go @@ -85,8 +85,9 @@ type AnalyzerResource struct { func NewAnalyzerResource(analyzer types.AnalyzerSummary) *AnalyzerResource { return &AnalyzerResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(analyzer.Name), - ARN: appaws.Str(analyzer.Arn), + ID: appaws.Str(analyzer.Name), + ARN: appaws.Str(analyzer.Arn), + Data: analyzer, }, Summary: &analyzer, } @@ -96,8 +97,9 @@ func NewAnalyzerResource(analyzer types.AnalyzerSummary) *AnalyzerResource { func NewAnalyzerResourceFromDetail(analyzer types.AnalyzerSummary) *AnalyzerResource { return &AnalyzerResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(analyzer.Name), - ARN: appaws.Str(analyzer.Arn), + ID: appaws.Str(analyzer.Name), + ARN: appaws.Str(analyzer.Arn), + Data: analyzer, }, Detail: &analyzer, } diff --git a/custom/accessanalyzer/findings/dao.go b/custom/accessanalyzer/findings/dao.go index 21c46a47..f14ef213 100644 --- a/custom/accessanalyzer/findings/dao.go +++ b/custom/accessanalyzer/findings/dao.go @@ -93,7 +93,8 @@ type FindingResource struct { func NewFindingResource(finding types.FindingSummary, analyzerArn string) *FindingResource { return &FindingResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(finding.Id), + ID: appaws.Str(finding.Id), + Data: finding, }, Summary: &finding, AnalyzerArn: analyzerArn, @@ -104,7 +105,8 @@ func NewFindingResource(finding types.FindingSummary, analyzerArn string) *Findi func NewFindingResourceFromDetail(finding types.Finding, analyzerArn string) *FindingResource { return &FindingResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(finding.Id), + ID: appaws.Str(finding.Id), + Data: finding, }, Detail: &finding, AnalyzerArn: analyzerArn, diff --git a/custom/apprunner/operations/dao.go b/custom/apprunner/operations/dao.go index 977bee5c..624d1c32 100644 --- a/custom/apprunner/operations/dao.go +++ b/custom/apprunner/operations/dao.go @@ -113,8 +113,9 @@ type OperationResource struct { func NewOperationResource(op types.OperationSummary) *OperationResource { return &OperationResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(op.Id), - ARN: appaws.Str(op.TargetArn), + ID: appaws.Str(op.Id), + ARN: appaws.Str(op.TargetArn), + Data: op, }, Item: op, } diff --git a/custom/apprunner/services/dao.go b/custom/apprunner/services/dao.go index 07580134..cd674af2 100644 --- a/custom/apprunner/services/dao.go +++ b/custom/apprunner/services/dao.go @@ -85,8 +85,9 @@ type ServiceResource struct { func NewServiceResource(svc types.ServiceSummary) *ServiceResource { return &ServiceResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(svc.ServiceName), - ARN: appaws.Str(svc.ServiceArn), + ID: appaws.Str(svc.ServiceName), + ARN: appaws.Str(svc.ServiceArn), + Data: svc, }, Summary: &svc, } @@ -96,8 +97,9 @@ func NewServiceResource(svc types.ServiceSummary) *ServiceResource { func NewServiceResourceFromDetail(svc types.Service) *ServiceResource { return &ServiceResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(svc.ServiceName), - ARN: appaws.Str(svc.ServiceArn), + ID: appaws.Str(svc.ServiceName), + ARN: appaws.Str(svc.ServiceArn), + Data: svc, }, Detail: &svc, } diff --git a/custom/appsync/data-sources/dao.go b/custom/appsync/data-sources/dao.go index 066e0f0f..bd0517eb 100644 --- a/custom/appsync/data-sources/dao.go +++ b/custom/appsync/data-sources/dao.go @@ -103,8 +103,9 @@ type DataSourceResource struct { func NewDataSourceResource(ds types.DataSource, apiId string) *DataSourceResource { return &DataSourceResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(ds.Name), - ARN: appaws.Str(ds.DataSourceArn), + ID: appaws.Str(ds.Name), + ARN: appaws.Str(ds.DataSourceArn), + Data: ds, }, DataSource: &ds, apiId: apiId, diff --git a/custom/appsync/graphql-apis/dao.go b/custom/appsync/graphql-apis/dao.go index ce1ae147..5b2b78e7 100644 --- a/custom/appsync/graphql-apis/dao.go +++ b/custom/appsync/graphql-apis/dao.go @@ -83,8 +83,9 @@ type GraphQLApiResource struct { func NewGraphQLApiResource(api types.GraphqlApi) *GraphQLApiResource { return &GraphQLApiResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(api.ApiId), - ARN: appaws.Str(api.Arn), + ID: appaws.Str(api.ApiId), + ARN: appaws.Str(api.Arn), + Data: api, }, Api: &api, } diff --git a/custom/appsync/graphql-apis/render.go b/custom/appsync/graphql-apis/render.go index c9dc5e6f..4d33815b 100644 --- a/custom/appsync/graphql-apis/render.go +++ b/custom/appsync/graphql-apis/render.go @@ -172,7 +172,7 @@ func (r *GraphQLApiRenderer) Navigations(resource dao.Resource) []render.Navigat } return []render.Navigation{ { - Key: "d", + Key: "D", Label: "Data Sources", Service: "appsync", Resource: "data-sources", diff --git a/custom/athena/query-executions/dao.go b/custom/athena/query-executions/dao.go index 553c98ec..4b6852a6 100644 --- a/custom/athena/query-executions/dao.go +++ b/custom/athena/query-executions/dao.go @@ -122,8 +122,9 @@ type QueryExecutionResource struct { func NewQueryExecutionResource(qe types.QueryExecution) *QueryExecutionResource { return &QueryExecutionResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(qe.QueryExecutionId), - ARN: "", + ID: appaws.Str(qe.QueryExecutionId), + ARN: "", + Data: qe, }, Item: qe, } diff --git a/custom/athena/workgroups/dao.go b/custom/athena/workgroups/dao.go index e56fea61..4dd7dd90 100644 --- a/custom/athena/workgroups/dao.go +++ b/custom/athena/workgroups/dao.go @@ -85,8 +85,9 @@ type WorkgroupResource struct { func NewWorkgroupResource(wg types.WorkGroupSummary) *WorkgroupResource { return &WorkgroupResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(wg.Name), - ARN: "", + ID: appaws.Str(wg.Name), + ARN: "", + Data: wg, }, Summary: &wg, } @@ -96,8 +97,9 @@ func NewWorkgroupResource(wg types.WorkGroupSummary) *WorkgroupResource { func NewWorkgroupResourceFromDetail(wg types.WorkGroup) *WorkgroupResource { return &WorkgroupResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(wg.Name), - ARN: "", + ID: appaws.Str(wg.Name), + ARN: "", + Data: wg, }, Detail: &wg, } diff --git a/custom/batch/compute-environments/dao.go b/custom/batch/compute-environments/dao.go index 6fa58dc5..4ebf2fa3 100644 --- a/custom/batch/compute-environments/dao.go +++ b/custom/batch/compute-environments/dao.go @@ -97,8 +97,9 @@ type ComputeEnvironmentResource struct { func NewComputeEnvironmentResource(env types.ComputeEnvironmentDetail) *ComputeEnvironmentResource { return &ComputeEnvironmentResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(env.ComputeEnvironmentName), - ARN: appaws.Str(env.ComputeEnvironmentArn), + ID: appaws.Str(env.ComputeEnvironmentName), + ARN: appaws.Str(env.ComputeEnvironmentArn), + Data: env, }, Env: &env, } diff --git a/custom/batch/job-definitions/dao.go b/custom/batch/job-definitions/dao.go index 4428e803..9bb14534 100644 --- a/custom/batch/job-definitions/dao.go +++ b/custom/batch/job-definitions/dao.go @@ -96,8 +96,9 @@ func NewJobDefinitionResource(def types.JobDefinition) *JobDefinitionResource { } return &JobDefinitionResource{ BaseResource: dao.BaseResource{ - ID: name, - ARN: appaws.Str(def.JobDefinitionArn), + ID: name, + ARN: appaws.Str(def.JobDefinitionArn), + Data: def, }, Def: &def, } diff --git a/custom/batch/job-queues/dao.go b/custom/batch/job-queues/dao.go index ee68fd73..5af35155 100644 --- a/custom/batch/job-queues/dao.go +++ b/custom/batch/job-queues/dao.go @@ -97,8 +97,9 @@ type JobQueueResource struct { func NewJobQueueResource(queue types.JobQueueDetail) *JobQueueResource { return &JobQueueResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(queue.JobQueueName), - ARN: appaws.Str(queue.JobQueueArn), + ID: appaws.Str(queue.JobQueueName), + ARN: appaws.Str(queue.JobQueueArn), + Data: queue, }, Queue: &queue, } diff --git a/custom/batch/jobs/dao.go b/custom/batch/jobs/dao.go index 0f196249..ec7a5a4a 100644 --- a/custom/batch/jobs/dao.go +++ b/custom/batch/jobs/dao.go @@ -95,8 +95,9 @@ func (d *JobDAO) Get(ctx context.Context, id string) (dao.Resource, error) { return &JobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(job.JobId), - ARN: appaws.Str(job.JobArn), + ID: appaws.Str(job.JobId), + ARN: appaws.Str(job.JobArn), + Data: job, }, Job: &types.JobSummary{ JobId: job.JobId, @@ -154,8 +155,9 @@ type JobResource struct { func NewJobResource(job types.JobSummary) *JobResource { return &JobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(job.JobId), - ARN: appaws.Str(job.JobArn), + ID: appaws.Str(job.JobId), + ARN: appaws.Str(job.JobArn), + Data: job, }, Job: &job, } diff --git a/custom/budgets/budgets/dao.go b/custom/budgets/budgets/dao.go index 6b1e60eb..2b73a65c 100644 --- a/custom/budgets/budgets/dao.go +++ b/custom/budgets/budgets/dao.go @@ -112,8 +112,9 @@ type BudgetResource struct { func NewBudgetResource(budget types.Budget, accountID string) *BudgetResource { return &BudgetResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(budget.BudgetName), - ARN: fmt.Sprintf("arn:aws:budgets::%s:budget/%s", accountID, appaws.Str(budget.BudgetName)), + ID: appaws.Str(budget.BudgetName), + ARN: fmt.Sprintf("arn:aws:budgets::%s:budget/%s", accountID, appaws.Str(budget.BudgetName)), + Data: budget, }, Item: budget, AccountID: accountID, diff --git a/custom/budgets/notifications/dao.go b/custom/budgets/notifications/dao.go index 4799f97a..792fd15d 100644 --- a/custom/budgets/notifications/dao.go +++ b/custom/budgets/notifications/dao.go @@ -105,8 +105,9 @@ func NewNotificationResource(notif types.Notification, budgetName string, index ) return &NotificationResource{ BaseResource: dao.BaseResource{ - ID: id, - ARN: "", + ID: id, + ARN: "", + Data: notif, }, Item: notif, BudgetName: budgetName, diff --git a/custom/ce/costs/dao.go b/custom/ce/costs/dao.go index 420d7731..c07f8613 100644 --- a/custom/ce/costs/dao.go +++ b/custom/ce/costs/dao.go @@ -163,7 +163,8 @@ func NewCostResource(group types.Group, start, end string) *CostResource { ID: serviceName, // Pseudo-ARN: Cost Explorer aggregates don't have real ARNs. // Format "ce::" enables internal resource identification. - ARN: fmt.Sprintf("ce::%s", serviceName), + ARN: fmt.Sprintf("ce::%s", serviceName), + Data: serviceName, }, ServiceName: serviceName, Cost: cost, diff --git a/custom/cloudtrail/events/dao.go b/custom/cloudtrail/events/dao.go index b3c0b3bb..5e36685a 100644 --- a/custom/cloudtrail/events/dao.go +++ b/custom/cloudtrail/events/dao.go @@ -126,8 +126,9 @@ type EventResource struct { func NewEventResource(event types.Event) *EventResource { return &EventResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(event.EventId), - ARN: appaws.Str(event.EventId), + ID: appaws.Str(event.EventId), + ARN: appaws.Str(event.EventId), + Data: event, }, Item: event, } diff --git a/custom/cloudtrail/trails/dao.go b/custom/cloudtrail/trails/dao.go index c30b778b..7abcc8b8 100644 --- a/custom/cloudtrail/trails/dao.go +++ b/custom/cloudtrail/trails/dao.go @@ -79,8 +79,9 @@ type TrailResource struct { func NewTrailResource(trail types.Trail) *TrailResource { return &TrailResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(trail.Name), - ARN: appaws.Str(trail.TrailARN), + ID: appaws.Str(trail.Name), + ARN: appaws.Str(trail.TrailARN), + Data: trail, }, Item: trail, } diff --git a/custom/configservice/rules/dao.go b/custom/configservice/rules/dao.go index ff018d25..59667c4b 100644 --- a/custom/configservice/rules/dao.go +++ b/custom/configservice/rules/dao.go @@ -100,8 +100,9 @@ type RuleResource struct { func NewRuleResource(rule types.ConfigRule) *RuleResource { return &RuleResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(rule.ConfigRuleName), - ARN: appaws.Str(rule.ConfigRuleArn), + ID: appaws.Str(rule.ConfigRuleName), + ARN: appaws.Str(rule.ConfigRuleArn), + Data: rule, }, Item: rule, } diff --git a/custom/datasync/locations/dao.go b/custom/datasync/locations/dao.go index 02ed8cea..c136784b 100644 --- a/custom/datasync/locations/dao.go +++ b/custom/datasync/locations/dao.go @@ -96,8 +96,9 @@ func NewLocationResource(loc types.LocationListEntry) *LocationResource { arn := appaws.Str(loc.LocationArn) return &LocationResource{ BaseResource: dao.BaseResource{ - ID: extractLocationID(arn), - ARN: arn, + ID: extractLocationID(arn), + ARN: arn, + Data: loc, }, Location: &loc, } diff --git a/custom/datasync/task-executions/dao.go b/custom/datasync/task-executions/dao.go index 2d19627a..38709797 100644 --- a/custom/datasync/task-executions/dao.go +++ b/custom/datasync/task-executions/dao.go @@ -126,8 +126,9 @@ func NewTaskExecutionResource(exec types.TaskExecutionListEntry) *TaskExecutionR arn := appaws.Str(exec.TaskExecutionArn) return &TaskExecutionResource{ BaseResource: dao.BaseResource{ - ID: extractExecutionID(arn), - ARN: arn, + ID: extractExecutionID(arn), + ARN: arn, + Data: exec, }, Execution: &exec, } diff --git a/custom/datasync/tasks/dao.go b/custom/datasync/tasks/dao.go index 8ead4f40..27b74991 100644 --- a/custom/datasync/tasks/dao.go +++ b/custom/datasync/tasks/dao.go @@ -90,8 +90,9 @@ func (d *TaskDAO) Get(ctx context.Context, id string) (dao.Resource, error) { return &TaskResource{ BaseResource: dao.BaseResource{ - ID: extractTaskID(appaws.Str(output.TaskArn)), - ARN: appaws.Str(output.TaskArn), + ID: extractTaskID(appaws.Str(output.TaskArn)), + ARN: appaws.Str(output.TaskArn), + Data: output, }, Task: &types.TaskListEntry{ TaskArn: output.TaskArn, @@ -150,8 +151,9 @@ func NewTaskResource(task types.TaskListEntry) *TaskResource { arn := appaws.Str(task.TaskArn) return &TaskResource{ BaseResource: dao.BaseResource{ - ID: extractTaskID(arn), - ARN: arn, + ID: extractTaskID(arn), + ARN: arn, + Data: task, }, Task: &task, } diff --git a/custom/detective/graphs/dao.go b/custom/detective/graphs/dao.go index 2b42f9df..78205dcd 100644 --- a/custom/detective/graphs/dao.go +++ b/custom/detective/graphs/dao.go @@ -101,8 +101,9 @@ func NewGraphResource(graph types.Graph) *GraphResource { return &GraphResource{ BaseResource: dao.BaseResource{ - ID: id, - ARN: arn, + ID: id, + ARN: arn, + Data: graph, }, Graph: &graph, } diff --git a/custom/detective/investigations/dao.go b/custom/detective/investigations/dao.go index 14f69851..779f1121 100644 --- a/custom/detective/investigations/dao.go +++ b/custom/detective/investigations/dao.go @@ -109,8 +109,9 @@ func NewInvestigationResource(inv types.InvestigationDetail, graphArn string) *I id := appaws.Str(inv.InvestigationId) return &InvestigationResource{ BaseResource: dao.BaseResource{ - ID: id, - ARN: graphArn + "/investigation/" + id, + ID: id, + ARN: graphArn + "/investigation/" + id, + Data: inv, }, Investigation: &inv, graphArn: graphArn, diff --git a/custom/directconnect/connections/dao.go b/custom/directconnect/connections/dao.go index 3c5dc110..1191c1f9 100644 --- a/custom/directconnect/connections/dao.go +++ b/custom/directconnect/connections/dao.go @@ -79,8 +79,9 @@ type ConnectionResource struct { func NewConnectionResource(conn types.Connection) *ConnectionResource { return &ConnectionResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(conn.ConnectionId), - ARN: "", + ID: appaws.Str(conn.ConnectionId), + ARN: "", + Data: conn, }, Item: conn, } diff --git a/custom/directconnect/virtual-interfaces/dao.go b/custom/directconnect/virtual-interfaces/dao.go index 37416ec7..b1155fff 100644 --- a/custom/directconnect/virtual-interfaces/dao.go +++ b/custom/directconnect/virtual-interfaces/dao.go @@ -86,8 +86,9 @@ type VirtualInterfaceResource struct { func NewVirtualInterfaceResource(vi types.VirtualInterface) *VirtualInterfaceResource { return &VirtualInterfaceResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(vi.VirtualInterfaceId), - ARN: "", + ID: appaws.Str(vi.VirtualInterfaceId), + ARN: "", + Data: vi, }, Item: vi, } diff --git a/custom/ecs/clusters/render.go b/custom/ecs/clusters/render.go index d660ee49..47ed2576 100644 --- a/custom/ecs/clusters/render.go +++ b/custom/ecs/clusters/render.go @@ -203,5 +203,11 @@ func (r *ClusterRenderer) Navigations(resource dao.Resource) []render.Navigation FilterField: "ClusterName", FilterValue: clusterName, }, + { + Key: "D", + Label: "Task Definitions", + Service: "ecs", + Resource: "task-definitions", + }, } } diff --git a/custom/ecs/services/render.go b/custom/ecs/services/render.go index 607d2d9f..70e90835 100644 --- a/custom/ecs/services/render.go +++ b/custom/ecs/services/render.go @@ -363,7 +363,7 @@ func (r *ServiceRenderer) Navigations(resource dao.Resource) []render.Navigation // Extract cluster name from ARN for filtering clusterName := appaws.ExtractResourceName(svc.ClusterArn()) - return []render.Navigation{ + navs := []render.Navigation{ { Key: "t", Label: "Tasks", @@ -389,4 +389,17 @@ func (r *ServiceRenderer) Navigations(resource dao.Resource) []render.Navigation FilterValue: "/ecs/" + svc.GetName(), }, } + + if td := svc.TaskDefinition(); td != "" { + navs = append(navs, render.Navigation{ + Key: "D", + Label: "Task Definition", + Service: "ecs", + Resource: "task-definitions", + FilterField: "TaskDefinition", + FilterValue: appaws.ExtractResourceName(td), + }) + } + + return navs } diff --git a/custom/ecs/task-definitions/constants.go b/custom/ecs/task-definitions/constants.go new file mode 100644 index 00000000..24bc4520 --- /dev/null +++ b/custom/ecs/task-definitions/constants.go @@ -0,0 +1,7 @@ +// Code generated by go generate; DO NOT EDIT. +// To regenerate: task gen-imports + +package taskdefinitions + +// ServiceResourcePath is the canonical path for this resource type. +const ServiceResourcePath = "ecs/task-definitions" diff --git a/custom/ecs/task-definitions/dao.go b/custom/ecs/task-definitions/dao.go new file mode 100644 index 00000000..1d1ae452 --- /dev/null +++ b/custom/ecs/task-definitions/dao.go @@ -0,0 +1,239 @@ +package taskdefinitions + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" + + appaws "github.com/clawscli/claws/internal/aws" + "github.com/clawscli/claws/internal/dao" + apperrors "github.com/clawscli/claws/internal/errors" + "github.com/clawscli/claws/internal/log" +) + +type TaskDefinitionDAO struct { + dao.BaseDAO + client *ecs.Client +} + +func NewTaskDefinitionDAO(ctx context.Context) (dao.DAO, error) { + cfg, err := appaws.NewConfig(ctx) + if err != nil { + return nil, apperrors.Wrap(err, "new ecs/task-definitions dao") + } + return &TaskDefinitionDAO{ + BaseDAO: dao.NewBaseDAO("ecs", "task-definitions"), + client: ecs.NewFromConfig(cfg), + }, nil +} + +func (d *TaskDefinitionDAO) List(ctx context.Context) ([]dao.Resource, error) { + taskDefArns, err := appaws.Paginate(ctx, func(token *string) ([]string, *string, error) { + output, err := d.client.ListTaskDefinitions(ctx, &ecs.ListTaskDefinitionsInput{ + Status: types.TaskDefinitionStatusActive, + Sort: types.SortOrderDesc, + NextToken: token, + }) + if err != nil { + return nil, nil, apperrors.Wrap(err, "list task definitions") + } + return output.TaskDefinitionArns, output.NextToken, nil + }) + if err != nil { + return nil, err + } + + seenFamilies := make(map[string]bool) + var latestArns []string + for _, arn := range taskDefArns { + family := extractFamilyFromArn(arn) + if !seenFamilies[family] { + seenFamilies[family] = true + latestArns = append(latestArns, arn) + } + } + + resources := make([]dao.Resource, 0, len(latestArns)) + for _, arn := range latestArns { + output, err := d.client.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{ + TaskDefinition: &arn, + }) + if err != nil { + log.Warn("failed to describe task definition", "arn", arn, "error", err) + continue + } + if output.TaskDefinition != nil { + resources = append(resources, NewTaskDefinitionResource(*output.TaskDefinition)) + } + } + + return resources, nil +} + +func (d *TaskDefinitionDAO) Get(ctx context.Context, id string) (dao.Resource, error) { + output, err := d.client.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{ + TaskDefinition: &id, + }) + if err != nil { + return nil, apperrors.Wrapf(err, "describe task definition %s", id) + } + + if output.TaskDefinition == nil { + return nil, fmt.Errorf("task definition not found: %s", id) + } + + return NewTaskDefinitionResource(*output.TaskDefinition), nil +} + +func (d *TaskDefinitionDAO) Delete(ctx context.Context, id string) error { + _, err := d.client.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + TaskDefinition: &id, + }) + if err != nil { + return apperrors.Wrapf(err, "deregister task definition %s", id) + } + return nil +} + +func extractFamilyFromArn(arn string) string { + parts := strings.Split(arn, "/") + if len(parts) < 2 { + return arn + } + familyRevision := parts[len(parts)-1] + colonIdx := strings.LastIndex(familyRevision, ":") + if colonIdx == -1 { + return familyRevision + } + return familyRevision[:colonIdx] +} + +type TaskDefinitionResource struct { + dao.BaseResource + Item types.TaskDefinition +} + +func NewTaskDefinitionResource(td types.TaskDefinition) *TaskDefinitionResource { + family := appaws.Str(td.Family) + revision := td.Revision + id := fmt.Sprintf("%s:%d", family, revision) + + return &TaskDefinitionResource{ + BaseResource: dao.BaseResource{ + ID: id, + Name: family, + ARN: appaws.Str(td.TaskDefinitionArn), + Data: td, + }, + Item: td, + } +} + +func (r *TaskDefinitionResource) Family() string { + return appaws.Str(r.Item.Family) +} + +func (r *TaskDefinitionResource) Revision() int32 { + return r.Item.Revision +} + +func (r *TaskDefinitionResource) Status() string { + return string(r.Item.Status) +} + +func (r *TaskDefinitionResource) CPU() string { + return appaws.Str(r.Item.Cpu) +} + +func (r *TaskDefinitionResource) Memory() string { + return appaws.Str(r.Item.Memory) +} + +func (r *TaskDefinitionResource) NetworkMode() string { + return string(r.Item.NetworkMode) +} + +func (r *TaskDefinitionResource) RequiresCompatibilities() []types.Compatibility { + return r.Item.RequiresCompatibilities +} + +func (r *TaskDefinitionResource) ContainerDefinitions() []types.ContainerDefinition { + return r.Item.ContainerDefinitions +} + +func (r *TaskDefinitionResource) TaskRoleArn() string { + return appaws.Str(r.Item.TaskRoleArn) +} + +func (r *TaskDefinitionResource) ExecutionRoleArn() string { + return appaws.Str(r.Item.ExecutionRoleArn) +} + +func (r *TaskDefinitionResource) Volumes() []types.Volume { + return r.Item.Volumes +} + +func (r *TaskDefinitionResource) RuntimePlatform() *types.RuntimePlatform { + return r.Item.RuntimePlatform +} + +func (r *TaskDefinitionResource) GetLogConfiguration(containerName string) *types.LogConfiguration { + containers := r.Item.ContainerDefinitions + if len(containers) == 0 { + return nil + } + + if containerName != "" { + for _, c := range containers { + if appaws.Str(c.Name) == containerName { + return c.LogConfiguration + } + } + return nil + } + + return containers[0].LogConfiguration +} + +func (r *TaskDefinitionResource) GetCloudWatchLogGroup(containerName string) string { + logConfig := r.GetLogConfiguration(containerName) + if logConfig == nil { + return "" + } + + if logConfig.LogDriver != types.LogDriverAwslogs { + return "" + } + + if logConfig.Options == nil { + return "" + } + + return logConfig.Options["awslogs-group"] +} + +func (r *TaskDefinitionResource) GetAllCloudWatchLogGroups() []string { + var groups []string + seen := make(map[string]bool) + + for _, c := range r.Item.ContainerDefinitions { + if c.LogConfiguration == nil { + continue + } + if c.LogConfiguration.LogDriver != types.LogDriverAwslogs { + continue + } + if c.LogConfiguration.Options == nil { + continue + } + if group := c.LogConfiguration.Options["awslogs-group"]; group != "" && !seen[group] { + seen[group] = true + groups = append(groups, group) + } + } + + return groups +} diff --git a/custom/ecs/task-definitions/register.go b/custom/ecs/task-definitions/register.go new file mode 100644 index 00000000..664a0846 --- /dev/null +++ b/custom/ecs/task-definitions/register.go @@ -0,0 +1,20 @@ +package taskdefinitions + +import ( + "context" + + "github.com/clawscli/claws/internal/dao" + "github.com/clawscli/claws/internal/registry" + "github.com/clawscli/claws/internal/render" +) + +func init() { + registry.Global.RegisterCustom("ecs", "task-definitions", registry.Entry{ + DAOFactory: func(ctx context.Context) (dao.DAO, error) { + return NewTaskDefinitionDAO(ctx) + }, + RendererFactory: func() render.Renderer { + return NewTaskDefinitionRenderer() + }, + }) +} diff --git a/custom/ecs/task-definitions/render.go b/custom/ecs/task-definitions/render.go new file mode 100644 index 00000000..115625f1 --- /dev/null +++ b/custom/ecs/task-definitions/render.go @@ -0,0 +1,289 @@ +package taskdefinitions + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/ecs/types" + + appaws "github.com/clawscli/claws/internal/aws" + "github.com/clawscli/claws/internal/dao" + "github.com/clawscli/claws/internal/render" + "github.com/clawscli/claws/internal/ui" +) + +var _ render.Navigator = (*TaskDefinitionRenderer)(nil) + +type TaskDefinitionRenderer struct { + render.BaseRenderer +} + +func NewTaskDefinitionRenderer() render.Renderer { + return &TaskDefinitionRenderer{ + BaseRenderer: render.BaseRenderer{ + Service: "ecs", + Resource: "task-definitions", + Cols: []render.Column{ + {Name: "FAMILY", Width: 35, Getter: func(r dao.Resource) string { return r.GetName() }}, + {Name: "REV", Width: 5, Getter: getRevision}, + {Name: "STATUS", Width: 10, Getter: getStatus}, + {Name: "CPU", Width: 8, Getter: getCPU}, + {Name: "MEMORY", Width: 8, Getter: getMemory}, + {Name: "NETWORK", Width: 10, Getter: getNetworkMode}, + {Name: "CONTAINERS", Width: 10, Getter: getContainerCount}, + }, + }, + } +} + +func getRevision(r dao.Resource) string { + if td, ok := r.(*TaskDefinitionResource); ok { + return fmt.Sprintf("%d", td.Revision()) + } + return "" +} + +func getStatus(r dao.Resource) string { + if td, ok := r.(*TaskDefinitionResource); ok { + status := td.Status() + switch status { + case "ACTIVE": + return "active" + case "INACTIVE": + return "stopped" + case "DELETE_IN_PROGRESS": + return "deleting" + default: + return strings.ToLower(status) + } + } + return "" +} + +func getCPU(r dao.Resource) string { + if td, ok := r.(*TaskDefinitionResource); ok { + if cpu := td.CPU(); cpu != "" { + return cpu + } + } + return "-" +} + +func getMemory(r dao.Resource) string { + if td, ok := r.(*TaskDefinitionResource); ok { + if mem := td.Memory(); mem != "" { + return mem + } + } + return "-" +} + +func getNetworkMode(r dao.Resource) string { + if td, ok := r.(*TaskDefinitionResource); ok { + mode := td.NetworkMode() + if mode == "" { + return "bridge" + } + return mode + } + return "" +} + +func getContainerCount(r dao.Resource) string { + if td, ok := r.(*TaskDefinitionResource); ok { + return fmt.Sprintf("%d", len(td.ContainerDefinitions())) + } + return "" +} + +func (r *TaskDefinitionRenderer) RenderDetail(resource dao.Resource) string { + td, ok := resource.(*TaskDefinitionResource) + if !ok { + return "" + } + + d := render.NewDetailBuilder() + + d.Title("ECS Task Definition", td.GetID()) + + d.Section("Basic Information") + d.Field("Family", td.Family()) + d.Field("Revision", fmt.Sprintf("%d", td.Revision())) + d.Field("ARN", td.GetARN()) + d.FieldStyled("Status", td.Status(), render.StateColorer()(strings.ToLower(td.Status()))) + + d.Section("Task Configuration") + if cpu := td.CPU(); cpu != "" { + d.Field("CPU", cpu+" units") + } + if mem := td.Memory(); mem != "" { + d.Field("Memory", mem+" MiB") + } + d.Field("Network Mode", td.NetworkMode()) + + if compat := td.RequiresCompatibilities(); len(compat) > 0 { + var compatStr []string + for _, c := range compat { + compatStr = append(compatStr, string(c)) + } + d.Field("Compatibilities", strings.Join(compatStr, ", ")) + } + + if rp := td.RuntimePlatform(); rp != nil { + if rp.OperatingSystemFamily != "" { + d.Field("OS Family", string(rp.OperatingSystemFamily)) + } + if rp.CpuArchitecture != "" { + d.Field("CPU Architecture", string(rp.CpuArchitecture)) + } + } + + if role := td.TaskRoleArn(); role != "" { + d.Section("IAM Roles") + d.Field("Task Role", appaws.ExtractResourceName(role)) + } + if execRole := td.ExecutionRoleArn(); execRole != "" { + if td.TaskRoleArn() == "" { + d.Section("IAM Roles") + } + d.Field("Execution Role", appaws.ExtractResourceName(execRole)) + } + + containers := td.ContainerDefinitions() + if len(containers) > 0 { + d.Section(fmt.Sprintf("Containers (%d)", len(containers))) + for _, c := range containers { + containerName := appaws.Str(c.Name) + d.Line("") + d.FieldStyled(containerName, "", ui.TitleStyle()) + + if c.Image != nil { + d.Field(" Image", *c.Image) + } + if c.Essential != nil && *c.Essential { + d.Field(" Essential", "Yes") + } + if c.Cpu != 0 { + d.Field(" CPU", fmt.Sprintf("%d", c.Cpu)) + } + if c.Memory != nil { + d.Field(" Memory", fmt.Sprintf("%d MiB", *c.Memory)) + } + if c.MemoryReservation != nil { + d.Field(" Memory Reservation", fmt.Sprintf("%d MiB", *c.MemoryReservation)) + } + + if len(c.PortMappings) > 0 { + var ports []string + for _, pm := range c.PortMappings { + if pm.ContainerPort != nil { + port := fmt.Sprintf("%d", *pm.ContainerPort) + if pm.HostPort != nil && *pm.HostPort != *pm.ContainerPort { + port = fmt.Sprintf("%d:%d", *pm.HostPort, *pm.ContainerPort) + } + if pm.Protocol != "" { + port += "/" + strings.ToLower(string(pm.Protocol)) + } + ports = append(ports, port) + } + } + d.Field(" Ports", strings.Join(ports, ", ")) + } + + if c.LogConfiguration != nil { + d.Field(" Log Driver", string(c.LogConfiguration.LogDriver)) + if c.LogConfiguration.LogDriver == types.LogDriverAwslogs && c.LogConfiguration.Options != nil { + if group := c.LogConfiguration.Options["awslogs-group"]; group != "" { + d.FieldStyled(" Log Group", group, ui.SuccessStyle()) + } + if prefix := c.LogConfiguration.Options["awslogs-stream-prefix"]; prefix != "" { + d.Field(" Stream Prefix", prefix) + } + } + } + } + } + + if volumes := td.Volumes(); len(volumes) > 0 { + d.Section("Volumes") + for _, v := range volumes { + if v.Name != nil { + d.Field(*v.Name, "") + if v.Host != nil && v.Host.SourcePath != nil { + d.Field(" Source", *v.Host.SourcePath) + } + if v.EfsVolumeConfiguration != nil { + d.Field(" Type", "EFS") + d.Field(" File System ID", appaws.Str(v.EfsVolumeConfiguration.FileSystemId)) + } + } + } + } + + d.Tags(td.GetTags()) + + return d.String() +} + +func (r *TaskDefinitionRenderer) RenderSummary(resource dao.Resource) []render.SummaryField { + td, ok := resource.(*TaskDefinitionResource) + if !ok { + return r.BaseRenderer.RenderSummary(resource) + } + + fields := []render.SummaryField{ + {Label: "Family", Value: td.Family()}, + {Label: "Revision", Value: fmt.Sprintf("%d", td.Revision())}, + {Label: "ARN", Value: td.GetARN()}, + {Label: "Status", Value: td.Status()}, + } + + if cpu := td.CPU(); cpu != "" { + fields = append(fields, render.SummaryField{Label: "CPU", Value: cpu}) + } + if mem := td.Memory(); mem != "" { + fields = append(fields, render.SummaryField{Label: "Memory", Value: mem}) + } + + fields = append(fields, render.SummaryField{Label: "Network Mode", Value: td.NetworkMode()}) + fields = append(fields, render.SummaryField{Label: "Containers", Value: fmt.Sprintf("%d", len(td.ContainerDefinitions()))}) + + if groups := td.GetAllCloudWatchLogGroups(); len(groups) > 0 { + fields = append(fields, render.SummaryField{Label: "Log Groups", Value: strings.Join(groups, ", ")}) + } + + return fields +} + +func (r *TaskDefinitionRenderer) Navigations(resource dao.Resource) []render.Navigation { + td, ok := resource.(*TaskDefinitionResource) + if !ok { + return nil + } + + var navs []render.Navigation + + if groups := td.GetAllCloudWatchLogGroups(); len(groups) > 0 { + navs = append(navs, render.Navigation{ + Key: "l", + Label: "Logs", + Service: "cloudwatch", + Resource: "log-groups", + FilterField: "LogGroupPrefix", + FilterValue: groups[0], + }) + } + + if role := td.TaskRoleArn(); role != "" { + navs = append(navs, render.Navigation{ + Key: "r", + Label: "Task Role", + Service: "iam", + Resource: "roles", + FilterField: "RoleName", + FilterValue: appaws.ExtractResourceName(role), + }) + } + + return navs +} diff --git a/custom/ecs/tasks/render.go b/custom/ecs/tasks/render.go index 1bc11a4a..96c78cf3 100644 --- a/custom/ecs/tasks/render.go +++ b/custom/ecs/tasks/render.go @@ -329,13 +329,21 @@ func (r *TaskRenderer) Navigations(resource dao.Resource) []render.Navigation { }) } - // Add logs navigation - use task definition family as log group prefix if taskDef := task.TaskDefinitionArn(); taskDef != "" { - // Task definition ARN: arn:aws:ecs:region:account:task-definition/family:revision taskDefName := appaws.ExtractResourceName(taskDef) - // Remove revision number (e.g., "my-task:5" -> "my-task") - if idx := strings.LastIndex(taskDefName, ":"); idx > 0 { - taskDefName = taskDefName[:idx] + + navs = append(navs, render.Navigation{ + Key: "D", + Label: "Task Definition", + Service: "ecs", + Resource: "task-definitions", + FilterField: "TaskDefinition", + FilterValue: taskDefName, + }) + + family := taskDefName + if idx := strings.LastIndex(family, ":"); idx > 0 { + family = family[:idx] } navs = append(navs, render.Navigation{ Key: "l", @@ -343,7 +351,7 @@ func (r *TaskRenderer) Navigations(resource dao.Resource) []render.Navigation { Service: "cloudwatch", Resource: "log-groups", FilterField: "LogGroupPrefix", - FilterValue: "/ecs/" + taskDefName, + FilterValue: "/ecs/" + family, }) } diff --git a/custom/emr/clusters/dao.go b/custom/emr/clusters/dao.go index 9cf06f1c..88d1ee00 100644 --- a/custom/emr/clusters/dao.go +++ b/custom/emr/clusters/dao.go @@ -78,8 +78,9 @@ func (d *ClusterDAO) Get(ctx context.Context, id string) (dao.Resource, error) { return &ClusterResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(cluster.Id), - ARN: appaws.Str(cluster.ClusterArn), + ID: appaws.Str(cluster.Id), + ARN: appaws.Str(cluster.ClusterArn), + Data: cluster, }, Cluster: &types.ClusterSummary{ Id: cluster.Id, @@ -135,8 +136,9 @@ type ClusterResource struct { func NewClusterResource(cluster types.ClusterSummary) *ClusterResource { return &ClusterResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(cluster.Id), - ARN: appaws.Str(cluster.ClusterArn), + ID: appaws.Str(cluster.Id), + ARN: appaws.Str(cluster.ClusterArn), + Data: cluster, }, Cluster: &cluster, } diff --git a/custom/emr/steps/dao.go b/custom/emr/steps/dao.go index 366d0f78..e692c7bf 100644 --- a/custom/emr/steps/dao.go +++ b/custom/emr/steps/dao.go @@ -76,8 +76,9 @@ func (d *StepDAO) Get(ctx context.Context, id string) (dao.Resource, error) { step := output.Step return &StepResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(step.Id), - ARN: "", + ID: appaws.Str(step.Id), + ARN: "", + Data: step, }, Step: &types.StepSummary{ Id: step.Id, @@ -118,8 +119,9 @@ type StepResource struct { func NewStepResource(step types.StepSummary, clusterId string) *StepResource { return &StepResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(step.Id), - ARN: "", + ID: appaws.Str(step.Id), + ARN: "", + Data: step, }, Step: &step, clusterId: clusterId, diff --git a/custom/fms/policies/dao.go b/custom/fms/policies/dao.go index acca7526..e4d41e93 100644 --- a/custom/fms/policies/dao.go +++ b/custom/fms/policies/dao.go @@ -84,8 +84,9 @@ type PolicyResource struct { func NewPolicyResource(policy types.PolicySummary) *PolicyResource { return &PolicyResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(policy.PolicyId), - ARN: appaws.Str(policy.PolicyArn), + ID: appaws.Str(policy.PolicyId), + ARN: appaws.Str(policy.PolicyArn), + Data: policy, }, Summary: &policy, } @@ -95,8 +96,9 @@ func NewPolicyResource(policy types.PolicySummary) *PolicyResource { func NewPolicyResourceFromDetail(policy types.Policy, arn *string) *PolicyResource { return &PolicyResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(policy.PolicyId), - ARN: appaws.Str(arn), + ID: appaws.Str(policy.PolicyId), + ARN: appaws.Str(arn), + Data: policy, }, Detail: &policy, } diff --git a/custom/glue/crawlers/dao.go b/custom/glue/crawlers/dao.go index 70cd37e9..432e9779 100644 --- a/custom/glue/crawlers/dao.go +++ b/custom/glue/crawlers/dao.go @@ -84,8 +84,9 @@ type CrawlerResource struct { func NewCrawlerResource(crawler types.Crawler) *CrawlerResource { return &CrawlerResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(crawler.Name), - ARN: "", + ID: appaws.Str(crawler.Name), + ARN: "", + Data: crawler, }, Item: crawler, } diff --git a/custom/glue/databases/dao.go b/custom/glue/databases/dao.go index 6917393f..41d13428 100644 --- a/custom/glue/databases/dao.go +++ b/custom/glue/databases/dao.go @@ -84,8 +84,9 @@ type DatabaseResource struct { func NewDatabaseResource(db types.Database) *DatabaseResource { return &DatabaseResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(db.Name), - ARN: appaws.Str(db.CatalogId), + ID: appaws.Str(db.Name), + ARN: appaws.Str(db.CatalogId), + Data: db, }, Item: db, } diff --git a/custom/glue/job-runs/dao.go b/custom/glue/job-runs/dao.go index 1f5d7973..985a013b 100644 --- a/custom/glue/job-runs/dao.go +++ b/custom/glue/job-runs/dao.go @@ -109,8 +109,9 @@ type JobRunResource struct { func NewJobRunResource(run types.JobRun) *JobRunResource { return &JobRunResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(run.Id), - ARN: "", + ID: appaws.Str(run.Id), + ARN: "", + Data: run, }, Item: run, } diff --git a/custom/glue/jobs/dao.go b/custom/glue/jobs/dao.go index 0491f121..38a16788 100644 --- a/custom/glue/jobs/dao.go +++ b/custom/glue/jobs/dao.go @@ -84,8 +84,9 @@ type JobResource struct { func NewJobResource(job types.Job) *JobResource { return &JobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(job.Name), - ARN: "", // Glue jobs don't have ARN in the response + ID: appaws.Str(job.Name), + ARN: "", // Glue jobs don't have ARN in the response + Data: job, }, Item: job, } diff --git a/custom/glue/tables/dao.go b/custom/glue/tables/dao.go index 55927805..a25921d7 100644 --- a/custom/glue/tables/dao.go +++ b/custom/glue/tables/dao.go @@ -105,8 +105,9 @@ type TableResource struct { func NewTableResource(table types.Table, databaseName string) *TableResource { return &TableResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(table.Name), - ARN: "", + ID: appaws.Str(table.Name), + ARN: "", + Data: table, }, Item: table, DatabaseName: databaseName, diff --git a/custom/health/events/dao.go b/custom/health/events/dao.go index 3ff44a05..b3f84418 100644 --- a/custom/health/events/dao.go +++ b/custom/health/events/dao.go @@ -121,8 +121,9 @@ type EventResource struct { func NewEventResource(event types.Event) *EventResource { return &EventResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(event.Arn), - ARN: appaws.Str(event.Arn), + ID: appaws.Str(event.Arn), + ARN: appaws.Str(event.Arn), + Data: event, }, Item: event, } diff --git a/custom/license-manager/configurations/dao.go b/custom/license-manager/configurations/dao.go index 3e19fc42..182cfa34 100644 --- a/custom/license-manager/configurations/dao.go +++ b/custom/license-manager/configurations/dao.go @@ -62,8 +62,9 @@ func (d *ConfigurationDAO) Get(ctx context.Context, arn string) (dao.Resource, e return &ConfigurationResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(output.Name), - ARN: appaws.Str(output.LicenseConfigurationArn), + ID: appaws.Str(output.Name), + ARN: appaws.Str(output.LicenseConfigurationArn), + Data: output, }, Config: &types.LicenseConfiguration{ LicenseConfigurationArn: output.LicenseConfigurationArn, @@ -98,8 +99,9 @@ type ConfigurationResource struct { func NewConfigurationResource(config types.LicenseConfiguration) *ConfigurationResource { return &ConfigurationResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(config.Name), - ARN: appaws.Str(config.LicenseConfigurationArn), + ID: appaws.Str(config.Name), + ARN: appaws.Str(config.LicenseConfigurationArn), + Data: config, }, Config: &config, } diff --git a/custom/license-manager/grants/dao.go b/custom/license-manager/grants/dao.go index a0e79a74..6ed8ad6b 100644 --- a/custom/license-manager/grants/dao.go +++ b/custom/license-manager/grants/dao.go @@ -115,8 +115,9 @@ func NewGrantResource(grant types.Grant) *GrantResource { } return &GrantResource{ BaseResource: dao.BaseResource{ - ID: id, - ARN: arn, + ID: id, + ARN: arn, + Data: grant, }, Grant: &grant, } diff --git a/custom/license-manager/licenses/dao.go b/custom/license-manager/licenses/dao.go index a5f729e1..85868678 100644 --- a/custom/license-manager/licenses/dao.go +++ b/custom/license-manager/licenses/dao.go @@ -99,8 +99,9 @@ func NewLicenseResource(license types.License) *LicenseResource { } return &LicenseResource{ BaseResource: dao.BaseResource{ - ID: id, - ARN: arn, + ID: id, + ARN: arn, + Data: license, }, License: &license, } diff --git a/custom/macie2/buckets/dao.go b/custom/macie2/buckets/dao.go index f8c09165..7f8bafde 100644 --- a/custom/macie2/buckets/dao.go +++ b/custom/macie2/buckets/dao.go @@ -81,8 +81,9 @@ type BucketResource struct { func NewBucketResource(bucket types.BucketMetadata) *BucketResource { return &BucketResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(bucket.BucketName), - ARN: appaws.Str(bucket.BucketArn), + ID: appaws.Str(bucket.BucketName), + ARN: appaws.Str(bucket.BucketArn), + Data: bucket, }, Bucket: &bucket, } diff --git a/custom/macie2/classification-jobs/dao.go b/custom/macie2/classification-jobs/dao.go index 6c9f0ac7..16c4748e 100644 --- a/custom/macie2/classification-jobs/dao.go +++ b/custom/macie2/classification-jobs/dao.go @@ -62,8 +62,9 @@ func (d *ClassificationJobDAO) Get(ctx context.Context, id string) (dao.Resource } return &ClassificationJobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(output.JobId), - ARN: appaws.Str(output.JobArn), + ID: appaws.Str(output.JobId), + ARN: appaws.Str(output.JobArn), + Data: output, }, Job: &types.JobSummary{ JobId: output.JobId, @@ -98,8 +99,9 @@ type ClassificationJobResource struct { func NewClassificationJobResource(job types.JobSummary) *ClassificationJobResource { return &ClassificationJobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(job.JobId), - ARN: "", + ID: appaws.Str(job.JobId), + ARN: "", + Data: job, }, Job: &job, } diff --git a/custom/macie2/findings/dao.go b/custom/macie2/findings/dao.go index b8c61feb..cbf736d7 100644 --- a/custom/macie2/findings/dao.go +++ b/custom/macie2/findings/dao.go @@ -118,8 +118,9 @@ type FindingResource struct { func NewFindingResource(finding types.Finding) *FindingResource { return &FindingResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(finding.Id), - ARN: "", + ID: appaws.Str(finding.Id), + ARN: "", + Data: finding, }, Finding: &finding, } diff --git a/custom/network-firewall/firewall-policies/dao.go b/custom/network-firewall/firewall-policies/dao.go index 90dfadcf..c1ee0463 100644 --- a/custom/network-firewall/firewall-policies/dao.go +++ b/custom/network-firewall/firewall-policies/dao.go @@ -96,8 +96,9 @@ func NewFirewallPolicyResource(p types.FirewallPolicyMetadata) *FirewallPolicyRe func NewFirewallPolicyResourceFromDetail(resp *types.FirewallPolicyResponse, p *types.FirewallPolicy) *FirewallPolicyResource { return &FirewallPolicyResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(resp.FirewallPolicyName), - ARN: appaws.Str(resp.FirewallPolicyArn), + ID: appaws.Str(resp.FirewallPolicyName), + ARN: appaws.Str(resp.FirewallPolicyArn), + Data: p, }, Response: resp, Detail: p, diff --git a/custom/network-firewall/firewalls/dao.go b/custom/network-firewall/firewalls/dao.go index 9dca7cfb..6a0c0103 100644 --- a/custom/network-firewall/firewalls/dao.go +++ b/custom/network-firewall/firewalls/dao.go @@ -96,8 +96,9 @@ func NewFirewallResource(fw types.FirewallMetadata) *FirewallResource { func NewFirewallResourceFromDetail(fw types.Firewall, status *types.FirewallStatus) *FirewallResource { return &FirewallResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(fw.FirewallName), - ARN: appaws.Str(fw.FirewallArn), + ID: appaws.Str(fw.FirewallName), + ARN: appaws.Str(fw.FirewallArn), + Data: fw, }, Detail: &fw, Status: status, diff --git a/custom/network-firewall/rule-groups/dao.go b/custom/network-firewall/rule-groups/dao.go index 9f39698b..85beec00 100644 --- a/custom/network-firewall/rule-groups/dao.go +++ b/custom/network-firewall/rule-groups/dao.go @@ -85,8 +85,9 @@ type RuleGroupResource struct { func NewRuleGroupResource(rg types.RuleGroupMetadata) *RuleGroupResource { return &RuleGroupResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(rg.Name), - ARN: appaws.Str(rg.Arn), + ID: appaws.Str(rg.Name), + ARN: appaws.Str(rg.Arn), + Data: rg, }, Metadata: &rg, } @@ -96,8 +97,9 @@ func NewRuleGroupResource(rg types.RuleGroupMetadata) *RuleGroupResource { func NewRuleGroupResourceFromDetail(resp *types.RuleGroupResponse, rg *types.RuleGroup) *RuleGroupResource { return &RuleGroupResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(resp.RuleGroupName), - ARN: appaws.Str(resp.RuleGroupArn), + ID: appaws.Str(resp.RuleGroupName), + ARN: appaws.Str(resp.RuleGroupArn), + Data: resp, }, Response: resp, Detail: rg, diff --git a/custom/organizations/accounts/dao.go b/custom/organizations/accounts/dao.go index 28c89f7c..fe9005b2 100644 --- a/custom/organizations/accounts/dao.go +++ b/custom/organizations/accounts/dao.go @@ -84,8 +84,9 @@ type AccountResource struct { func NewAccountResource(account types.Account) *AccountResource { return &AccountResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(account.Id), - ARN: appaws.Str(account.Arn), + ID: appaws.Str(account.Id), + ARN: appaws.Str(account.Arn), + Data: account, }, Account: &account, } diff --git a/custom/organizations/ous/dao.go b/custom/organizations/ous/dao.go index 9face24a..bd0383f2 100644 --- a/custom/organizations/ous/dao.go +++ b/custom/organizations/ous/dao.go @@ -90,8 +90,9 @@ type OUResource struct { func NewOUResource(ou types.OrganizationalUnit) *OUResource { return &OUResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(ou.Id), - ARN: appaws.Str(ou.Arn), + ID: appaws.Str(ou.Id), + ARN: appaws.Str(ou.Arn), + Data: ou, }, OU: &ou, } diff --git a/custom/organizations/policies/dao.go b/custom/organizations/policies/dao.go index dab2dad1..00bb6153 100644 --- a/custom/organizations/policies/dao.go +++ b/custom/organizations/policies/dao.go @@ -75,8 +75,9 @@ func (d *PolicyDAO) Get(ctx context.Context, id string) (dao.Resource, error) { } return &PolicyResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(output.Policy.PolicySummary.Id), - ARN: appaws.Str(output.Policy.PolicySummary.Arn), + ID: appaws.Str(output.Policy.PolicySummary.Id), + ARN: appaws.Str(output.Policy.PolicySummary.Arn), + Data: output, }, Policy: output.Policy.PolicySummary, Content: appaws.Str(output.Policy.Content), @@ -105,8 +106,9 @@ type PolicyResource struct { func NewPolicyResource(policy types.PolicySummary) *PolicyResource { return &PolicyResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(policy.Id), - ARN: appaws.Str(policy.Arn), + ID: appaws.Str(policy.Id), + ARN: appaws.Str(policy.Arn), + Data: policy, }, Policy: &policy, } diff --git a/custom/organizations/roots/dao.go b/custom/organizations/roots/dao.go index 4fa78b76..e814f691 100644 --- a/custom/organizations/roots/dao.go +++ b/custom/organizations/roots/dao.go @@ -81,8 +81,9 @@ type RootResource struct { func NewRootResource(root types.Root) *RootResource { return &RootResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(root.Id), - ARN: appaws.Str(root.Arn), + ID: appaws.Str(root.Id), + ARN: appaws.Str(root.Arn), + Data: root, }, Root: &root, } diff --git a/custom/redshift/clusters/dao.go b/custom/redshift/clusters/dao.go index bee473b2..1ee9c4c2 100644 --- a/custom/redshift/clusters/dao.go +++ b/custom/redshift/clusters/dao.go @@ -90,8 +90,9 @@ type ClusterResource struct { func NewClusterResource(cluster types.Cluster) *ClusterResource { return &ClusterResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(cluster.ClusterIdentifier), - ARN: "", + ID: appaws.Str(cluster.ClusterIdentifier), + ARN: "", + Data: cluster, }, Cluster: &cluster, } diff --git a/custom/redshift/snapshots/dao.go b/custom/redshift/snapshots/dao.go index 357e8e4d..3ca6d431 100644 --- a/custom/redshift/snapshots/dao.go +++ b/custom/redshift/snapshots/dao.go @@ -98,8 +98,9 @@ type SnapshotResource struct { func NewSnapshotResource(snapshot types.Snapshot) *SnapshotResource { return &SnapshotResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(snapshot.SnapshotIdentifier), - ARN: "", + ID: appaws.Str(snapshot.SnapshotIdentifier), + ARN: "", + Data: snapshot, }, Snapshot: &snapshot, } diff --git a/custom/s3/buckets/dao.go b/custom/s3/buckets/dao.go index e3a62336..724cf00d 100644 --- a/custom/s3/buckets/dao.go +++ b/custom/s3/buckets/dao.go @@ -78,6 +78,7 @@ func (d *BucketDAO) Get(ctx context.Context, id string) (dao.Resource, error) { BaseResource: dao.BaseResource{ ID: id, Name: id, + Data: id, }, BucketName: id, Region: region, @@ -273,6 +274,7 @@ func NewBucketResource(bucket types.Bucket) *BucketResource { BaseResource: dao.BaseResource{ ID: name, Name: name, + Data: name, }, BucketName: name, CreationDate: appaws.Time(bucket.CreationDate), diff --git a/custom/sagemaker/endpoints/dao.go b/custom/sagemaker/endpoints/dao.go index 30996292..bf1bea9f 100644 --- a/custom/sagemaker/endpoints/dao.go +++ b/custom/sagemaker/endpoints/dao.go @@ -101,8 +101,9 @@ type EndpointResource struct { func NewEndpointResource(endpoint types.EndpointSummary) *EndpointResource { return &EndpointResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(endpoint.EndpointName), - ARN: appaws.Str(endpoint.EndpointArn), + ID: appaws.Str(endpoint.EndpointName), + ARN: appaws.Str(endpoint.EndpointArn), + Data: endpoint, }, Endpoint: endpoint, } diff --git a/custom/sagemaker/models/dao.go b/custom/sagemaker/models/dao.go index 58023783..8eeb580b 100644 --- a/custom/sagemaker/models/dao.go +++ b/custom/sagemaker/models/dao.go @@ -109,8 +109,9 @@ type ModelResource struct { func NewModelResource(model types.ModelSummary) *ModelResource { return &ModelResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(model.ModelName), - ARN: appaws.Str(model.ModelArn), + ID: appaws.Str(model.ModelName), + ARN: appaws.Str(model.ModelArn), + Data: model, }, Model: model, } diff --git a/custom/sagemaker/notebooks/dao.go b/custom/sagemaker/notebooks/dao.go index 61dc0fdd..b03789b2 100644 --- a/custom/sagemaker/notebooks/dao.go +++ b/custom/sagemaker/notebooks/dao.go @@ -124,8 +124,9 @@ type NotebookResource struct { func NewNotebookResource(notebook types.NotebookInstanceSummary) *NotebookResource { return &NotebookResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(notebook.NotebookInstanceName), - ARN: appaws.Str(notebook.NotebookInstanceArn), + ID: appaws.Str(notebook.NotebookInstanceName), + ARN: appaws.Str(notebook.NotebookInstanceArn), + Data: notebook, }, Notebook: notebook, } diff --git a/custom/sagemaker/training-jobs/dao.go b/custom/sagemaker/training-jobs/dao.go index be4c6264..ae207e24 100644 --- a/custom/sagemaker/training-jobs/dao.go +++ b/custom/sagemaker/training-jobs/dao.go @@ -159,8 +159,9 @@ type TrainingJobResource struct { func NewTrainingJobResource(job types.TrainingJobSummary) *TrainingJobResource { return &TrainingJobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(job.TrainingJobName), - ARN: appaws.Str(job.TrainingJobArn), + ID: appaws.Str(job.TrainingJobName), + ARN: appaws.Str(job.TrainingJobArn), + Data: job, }, Job: job, } diff --git a/custom/securityhub/findings/dao.go b/custom/securityhub/findings/dao.go index df4fc97a..1ded46bf 100644 --- a/custom/securityhub/findings/dao.go +++ b/custom/securityhub/findings/dao.go @@ -43,7 +43,7 @@ func (d *FindingDAO) List(ctx context.Context) ([]dao.Resource, error) { func (d *FindingDAO) ListPage(ctx context.Context, pageSize int, pageToken string) ([]dao.Resource, string, error) { maxResults := int32(pageSize) if maxResults > 100 { - maxResults = 100 // AWS API max + maxResults = 100 } input := &securityhub.GetFindingsInput{ @@ -53,6 +53,18 @@ func (d *FindingDAO) ListPage(ctx context.Context, pageSize int, pageToken strin input.NextToken = &pageToken } + showResolved := dao.GetFilterFromContext(ctx, "ShowResolved") + if showResolved != "true" { + input.Filters = &types.AwsSecurityFindingFilters{ + RecordState: []types.StringFilter{ + {Value: stringPtr("ACTIVE"), Comparison: types.StringFilterComparisonEquals}, + }, + WorkflowStatus: []types.StringFilter{ + {Value: stringPtr("RESOLVED"), Comparison: types.StringFilterComparisonNotEquals}, + }, + } + } + output, err := d.client.GetFindings(ctx, input) if err != nil { return nil, "", apperrors.Wrap(err, "get security hub findings") @@ -71,6 +83,8 @@ func (d *FindingDAO) ListPage(ctx context.Context, pageSize int, pageToken strin return resources, nextToken, nil } +func stringPtr(s string) *string { return &s } + // Get returns a specific finding by ID. func (d *FindingDAO) Get(ctx context.Context, id string) (dao.Resource, error) { output, err := d.client.GetFindings(ctx, &securityhub.GetFindingsInput{ @@ -107,8 +121,9 @@ type FindingResource struct { func NewFindingResource(finding types.AwsSecurityFinding) *FindingResource { return &FindingResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(finding.Id), - ARN: appaws.Str(finding.Id), + ID: appaws.Str(finding.Id), + ARN: appaws.Str(finding.Id), + Data: finding, }, Item: finding, } diff --git a/custom/securityhub/findings/render.go b/custom/securityhub/findings/render.go index 12d0fd98..ac45bc2b 100644 --- a/custom/securityhub/findings/render.go +++ b/custom/securityhub/findings/render.go @@ -196,3 +196,9 @@ func (r *FindingRenderer) RenderSummary(resource dao.Resource) []render.SummaryF return fields } + +func (r *FindingRenderer) ListToggles() []render.Toggle { + return []render.Toggle{ + {Key: "r", ContextKey: "ShowResolved", LabelOn: "all", LabelOff: "active"}, + } +} diff --git a/custom/transcribe/jobs/dao.go b/custom/transcribe/jobs/dao.go index 12399d29..81da54d8 100644 --- a/custom/transcribe/jobs/dao.go +++ b/custom/transcribe/jobs/dao.go @@ -85,8 +85,9 @@ type JobResource struct { func NewJobResource(job types.TranscriptionJobSummary) *JobResource { return &JobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(job.TranscriptionJobName), - ARN: "", + ID: appaws.Str(job.TranscriptionJobName), + ARN: "", + Data: job, }, Summary: &job, } @@ -96,8 +97,9 @@ func NewJobResource(job types.TranscriptionJobSummary) *JobResource { func NewJobResourceFromDetail(job types.TranscriptionJob) *JobResource { return &JobResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(job.TranscriptionJobName), - ARN: "", + ID: appaws.Str(job.TranscriptionJobName), + ARN: "", + Data: job, }, Detail: &job, } diff --git a/custom/transfer/servers/dao.go b/custom/transfer/servers/dao.go index 6f6e736b..614d6ba6 100644 --- a/custom/transfer/servers/dao.go +++ b/custom/transfer/servers/dao.go @@ -84,8 +84,9 @@ type ServerResource struct { func NewServerResource(srv types.ListedServer) *ServerResource { return &ServerResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(srv.ServerId), - ARN: appaws.Str(srv.Arn), + ID: appaws.Str(srv.ServerId), + ARN: appaws.Str(srv.Arn), + Data: srv, }, Summary: &srv, } @@ -95,8 +96,9 @@ func NewServerResource(srv types.ListedServer) *ServerResource { func NewServerResourceFromDetail(srv types.DescribedServer) *ServerResource { return &ServerResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(srv.ServerId), - ARN: appaws.Str(srv.Arn), + ID: appaws.Str(srv.ServerId), + ARN: appaws.Str(srv.Arn), + Data: srv, }, Detail: &srv, } diff --git a/custom/transfer/users/dao.go b/custom/transfer/users/dao.go index 55f7b892..1665f6f5 100644 --- a/custom/transfer/users/dao.go +++ b/custom/transfer/users/dao.go @@ -104,8 +104,9 @@ type UserResource struct { func NewUserResource(user types.ListedUser, serverId string) *UserResource { return &UserResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(user.UserName), - ARN: appaws.Str(user.Arn), + ID: appaws.Str(user.UserName), + ARN: appaws.Str(user.Arn), + Data: user, }, Summary: &user, ServerId: serverId, @@ -116,8 +117,9 @@ func NewUserResource(user types.ListedUser, serverId string) *UserResource { func NewUserResourceFromDetail(user types.DescribedUser, serverId string) *UserResource { return &UserResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(user.UserName), - ARN: appaws.Str(user.Arn), + ID: appaws.Str(user.UserName), + ARN: appaws.Str(user.Arn), + Data: user, }, Detail: &user, ServerId: serverId, diff --git a/custom/vpc/endpoints/dao.go b/custom/vpc/endpoints/dao.go index fecccd32..54cc41f0 100644 --- a/custom/vpc/endpoints/dao.go +++ b/custom/vpc/endpoints/dao.go @@ -88,8 +88,9 @@ type VpcEndpointResource struct { func NewVpcEndpointResource(endpoint types.VpcEndpoint) *VpcEndpointResource { return &VpcEndpointResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(endpoint.VpcEndpointId), - ARN: "", + ID: appaws.Str(endpoint.VpcEndpointId), + ARN: "", + Data: endpoint, }, Item: endpoint, } diff --git a/custom/vpc/tgw-attachments/dao.go b/custom/vpc/tgw-attachments/dao.go index 4f0a46e9..cc4961b3 100644 --- a/custom/vpc/tgw-attachments/dao.go +++ b/custom/vpc/tgw-attachments/dao.go @@ -116,8 +116,9 @@ type TGWAttachmentResource struct { func NewTGWAttachmentResource(att types.TransitGatewayAttachment) *TGWAttachmentResource { return &TGWAttachmentResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(att.TransitGatewayAttachmentId), - ARN: "", + ID: appaws.Str(att.TransitGatewayAttachmentId), + ARN: "", + Data: att, }, Item: att, } diff --git a/custom/vpc/transit-gateways/dao.go b/custom/vpc/transit-gateways/dao.go index 1466a002..4c6141b9 100644 --- a/custom/vpc/transit-gateways/dao.go +++ b/custom/vpc/transit-gateways/dao.go @@ -88,8 +88,9 @@ type TransitGatewayResource struct { func NewTransitGatewayResource(tgw types.TransitGateway) *TransitGatewayResource { return &TransitGatewayResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(tgw.TransitGatewayId), - ARN: appaws.Str(tgw.TransitGatewayArn), + ID: appaws.Str(tgw.TransitGatewayId), + ARN: appaws.Str(tgw.TransitGatewayArn), + Data: tgw, }, Item: tgw, } diff --git a/custom/xray/groups/dao.go b/custom/xray/groups/dao.go index 891dc0bd..1b223b19 100644 --- a/custom/xray/groups/dao.go +++ b/custom/xray/groups/dao.go @@ -89,8 +89,9 @@ func NewGroupResource(group types.GroupSummary) *GroupResource { } return &GroupResource{ BaseResource: dao.BaseResource{ - ID: appaws.Str(group.GroupName), - ARN: appaws.Str(group.GroupARN), + ID: appaws.Str(group.GroupName), + ARN: appaws.Str(group.GroupARN), + Data: group, }, Item: &group, InsightsEnabled: insightsEnabled, diff --git a/docs/ai-chat.md b/docs/ai-chat.md new file mode 100644 index 00000000..80080914 --- /dev/null +++ b/docs/ai-chat.md @@ -0,0 +1,147 @@ +# AI Chat + +AI Chat provides an intelligent assistant that helps you analyze AWS resources, compare configurations, identify security risks, and navigate documentation. + +## Overview + +Press `A` in the following views to open AI Chat: +- **Resource Browser** (list view) - Analyzes visible resources +- **Detail View** - Analyzes the selected resource +- **Diff View** - Compares two resources side-by-side + +The assistant has access to: +- Current resource context (what you're viewing) +- Active AWS profile and region +- Tools to query resources, fetch logs, and search AWS documentation + +## Setup + +### 1. IAM Permissions + +The AI Chat feature uses Amazon Bedrock. You need the following permission: + +```json +{ + "Effect": "Allow", + "Action": "bedrock:InvokeModelWithResponseStream", + "Resource": "arn:aws:bedrock:*::foundation-model/*" +} +``` + +See [IAM Permissions](iam-permissions.md#ai-chat-optional) for details. + +### 2. Configuration + +Configure AI Chat in `~/.config/claws/config.yaml`: + +```yaml +ai: + profile: "" # AWS profile for Bedrock (empty = use current profile) + region: "" # AWS region for Bedrock (empty = use current region) + model: "global.anthropic.claude-haiku-4-5-20251001-v1:0" # Bedrock model ID + max_sessions: 100 # Max stored sessions (default: 100) + max_tokens: 16000 # Max response tokens (default: 16000) + thinking_budget: 8000 # Extended thinking token budget (default: 8000) + max_tool_rounds: 15 # Max tool execution rounds per message (default: 15) + max_tool_calls_per_query: 50 # Max tool calls per user query (default: 50) + save_sessions: false # Persist chat sessions to disk (default: false) +``` + +See [Configuration](configuration.md) for all options. + +## Usage + +### Opening Chat + +Press `A` in list/detail/diff views to open the AI Chat overlay. + +### What the AI Can Do + +- List and query AWS resources across services and regions +- Get detailed information about specific resources +- Fetch CloudWatch logs for supported resources (Lambda, ECS, CodeBuild, etc.) +- Search AWS documentation + +The AI automatically uses the current profile, region, and resource context from your view. + +### Context Awareness + +The assistant automatically receives context based on your current view: + +**Resource Browser (List View)**: +``` +Currently viewing: ec2/instances (us-west-2, production profile) +Visible resources: [i-abc123, i-def456, ...] +``` + +**Detail View**: +``` +Currently viewing: ec2/instances/i-abc123 (us-west-2, production profile) +Resource details: {...} +``` + +**Diff View**: +``` +Comparing two resources: +Left: ec2/instances/i-abc123 +Right: ec2/instances/i-def456 +``` + +### Session History + +Press `Ctrl+H` to view and resume previous chat sessions. + +## Keyboard Shortcuts + +| Key | Action | +|-----|--------| +| `A` | Open AI Chat (in list/detail/diff views) | +| `Ctrl+H` | Session history | +| `Enter` | Send message | +| `Esc` | Close chat / Cancel stream | +| `Ctrl+C` | Cancel stream | + +## Extended Thinking + +The assistant supports extended thinking for complex queries. When enabled, you'll see a thinking indicator showing the assistant's reasoning process before the final response. + +Configure thinking budget in config.yaml: +```yaml +ai: + thinking_budget: 8000 # Max tokens for extended thinking (default: 8000) +``` + +## Troubleshooting + +### "Bedrock not available in this region" + +Bedrock is not available in all AWS regions. Configure a supported region in your config: + +```yaml +ai: + region: "us-west-2" # Use a region where Bedrock is available +``` + +### "Access Denied" errors + +Ensure your IAM role/user has the required Bedrock permissions. See [IAM Permissions](iam-permissions.md#ai-chat-optional). + +### Tool call limit reached + +If you see "Tool call limit reached", the assistant made too many tool calls in a single query. Increase the limit: + +```yaml +ai: + max_tool_calls_per_query: 100 # Increase from default 50 +``` + +### Session not persisting + +Enable session persistence in config: + +```yaml +ai: + save_sessions: true # Default: false +``` + +Sessions are stored in `~/.config/claws/sessions/`. diff --git a/docs/configuration.md b/docs/configuration.md index 521b8bfc..612f9241 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -40,6 +40,17 @@ startup: # Applied on launch if present navigation: max_stack_size: 100 # Max navigation history depth (default: 100) +ai: + profile: "" # AWS profile for Bedrock (empty = use current profile) + region: "" # AWS region for Bedrock (empty = use current region) + model: "global.anthropic.claude-haiku-4-5-20251001-v1:0" # Bedrock model ID + max_sessions: 100 # Max stored sessions (default: 100) + max_tokens: 16000 # Max response tokens (default: 16000) + thinking_budget: 8000 # Extended thinking token budget (default: 8000) + max_tool_rounds: 15 # Max tool execution rounds per message (default: 15) + max_tool_calls_per_query: 50 # Max tool calls per user query (default: 50) + save_sessions: false # Persist chat sessions to disk (default: false) + theme: nord # Preset: dark, light, nord, dracula, gruvbox, catppuccin # Or use preset with custom overrides: diff --git a/docs/iam-permissions.md b/docs/iam-permissions.md index 64ce2a41..e8628d02 100644 --- a/docs/iam-permissions.md +++ b/docs/iam-permissions.md @@ -6,6 +6,33 @@ claws requires appropriate IAM permissions to access AWS resources. The permissi For basic read-only browsing, claws needs `Describe*`, `List*`, and `Get*` permissions for the services you want to access. +## AI Chat (Optional) + +The AI Chat feature (`A` key) uses Amazon Bedrock. To enable this feature, you need: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "bedrock:InvokeModelWithResponseStream", + "Resource": "arn:aws:bedrock:*::foundation-model/*" + }, + { + "Effect": "Allow", + "Action": [ + "aws-marketplace:Subscribe", + "aws-marketplace:ViewSubscriptions" + ], + "Resource": "*" + } + ] +} +``` + +**Note**: AWS Marketplace permissions are required for first-time model usage in your account. If the model is already enabled, only the `bedrock:InvokeModelWithResponseStream` permission is needed. + ## Inline Metrics (Optional) To display inline CloudWatch metrics (toggle with `M` key), you need: @@ -37,7 +64,7 @@ Some resource actions require additional permissions: ## Recommended Policy -For full read-only access with metrics: +For full read-only access with metrics and AI chat: ```json { @@ -57,6 +84,19 @@ For full read-only access with metrics: "iam:Get*" ], "Resource": "*" + }, + { + "Effect": "Allow", + "Action": "bedrock:InvokeModelWithResponseStream", + "Resource": "arn:aws:bedrock:*::foundation-model/*" + }, + { + "Effect": "Allow", + "Action": [ + "aws-marketplace:Subscribe", + "aws-marketplace:ViewSubscriptions" + ], + "Resource": "*" } ] } diff --git a/docs/images/ai-chat.png b/docs/images/ai-chat.png new file mode 100644 index 00000000..5fb35052 Binary files /dev/null and b/docs/images/ai-chat.png differ diff --git a/docs/keybindings.md b/docs/keybindings.md index 406e74c5..6ee38720 100644 --- a/docs/keybindings.md +++ b/docs/keybindings.md @@ -22,6 +22,7 @@ Complete reference for all keyboard shortcuts in claws. | `:pulse` | Go to dashboard | | `:services` | Go to service browser | | `/` | Filter mode (fuzzy search) | +| `A` | AI Chat (Bedrock) | | `?` | Show help | ## Resource Browser @@ -88,6 +89,7 @@ These shortcuts navigate to related resources based on the current context: | `l` | View CloudWatch Logs | | `o` | View Outputs / Operations | | `i` | View Images / Indexes | +| `D` | View Data Sources (AppSync) / Task Definitions (ECS) | ## Region Selector (`R` key) diff --git a/go.mod b/go.mod index 1aaafb5d..317e021b 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/bedrock v1.53.0 github.com/aws/aws-sdk-go-v2/service/bedrockagent v1.52.2 github.com/aws/aws-sdk-go-v2/service/bedrockagentcorecontrol v1.15.1 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.47.1 github.com/aws/aws-sdk-go-v2/service/budgets v1.42.3 github.com/aws/aws-sdk-go-v2/service/cloudformation v1.71.4 github.com/aws/aws-sdk-go-v2/service/cloudfront v1.58.3 @@ -83,6 +84,8 @@ require ( github.com/aws/smithy-go v1.24.0 github.com/charmbracelet/x/ansi v0.11.3 github.com/creack/pty v1.1.24 + github.com/google/uuid v1.6.0 + github.com/mattn/go-runewidth v0.0.19 golang.org/x/sync v0.19.0 golang.org/x/term v0.38.0 gopkg.in/ini.v1 v1.67.0 @@ -114,7 +117,6 @@ require ( github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect - github.com/mattn/go-runewidth v0.0.19 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/stretchr/testify v1.11.1 // indirect diff --git a/go.sum b/go.sum index 8b157de5..a293a418 100644 --- a/go.sum +++ b/go.sum @@ -50,6 +50,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockagent v1.52.2 h1:jrOALh0fIx8kUfesQS4 github.com/aws/aws-sdk-go-v2/service/bedrockagent v1.52.2/go.mod h1:hRzcNxU8BOG5ijgeMDLyw0sx4fBOxrjPDB/DnDK6X1M= github.com/aws/aws-sdk-go-v2/service/bedrockagentcorecontrol v1.15.1 h1:BJmfQWd/3kjWCw3zkS3lSZ9uVwo9jsDGfW8g4EG2xbY= github.com/aws/aws-sdk-go-v2/service/bedrockagentcorecontrol v1.15.1/go.mod h1:3zWDBnJEUh72XdC7iEqdCSwPwDuveVsKTmtThuGwC2s= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.47.1 h1:xryaVPvLLcCf7Y/4beWjOcWxiftorB/KDjtiYORVSNo= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.47.1/go.mod h1:ckSglleOJ2avj81L6vBb70nK51cnhTwvVK1SkLgFtj4= github.com/aws/aws-sdk-go-v2/service/budgets v1.42.3 h1:SWmlAqhAeh9ByGn56CLqJEEFwd1tsDM1t9ojTcxpnvo= github.com/aws/aws-sdk-go-v2/service/budgets v1.42.3/go.mod h1:MBllv8Mjt8gp2rBU+iA5L6QabvS5L00LSru/ICHld7M= github.com/aws/aws-sdk-go-v2/service/cloudformation v1.71.4 h1:9dwMueqbHIp0KTw2Zt0rhVobiPMlAI8UgyxiaBzM+1E= @@ -210,6 +212,8 @@ github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= diff --git a/internal/ai/bedrock.go b/internal/ai/bedrock.go new file mode 100644 index 00000000..cca144af --- /dev/null +++ b/internal/ai/bedrock.go @@ -0,0 +1,459 @@ +package ai + +import ( + "context" + "encoding/json" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/document" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + + appaws "github.com/clawscli/claws/internal/aws" + appconfig "github.com/clawscli/claws/internal/config" + apperrors "github.com/clawscli/claws/internal/errors" + "github.com/clawscli/claws/internal/log" +) + +// Role represents the role of a message participant. +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" +) + +// StopReason indicates why the model stopped generating. +type StopReason string + +const ( + StopReasonEndTurn StopReason = "end_turn" + StopReasonToolUse StopReason = "tool_use" + StopReasonMaxTokens StopReason = "max_tokens" +) + +// Message represents a single message in a conversation. +// Each message contains one or more ContentBlocks. +type Message struct { + Role Role `json:"role"` + Content []ContentBlock `json:"content"` +} + +// ContentBlock represents a content element within a message. +// Only one field should be set at a time. +type ContentBlock struct { + // Text content + Text string `json:"text,omitempty"` + + // Tool use request from LLM + ToolUse *ToolUseContent `json:"toolUse,omitempty"` + + // Tool result from application + ToolResult *ToolResultContent `json:"toolResult,omitempty"` + + // Extended Thinking (Reasoning content from Bedrock API) + Reasoning string `json:"reasoning,omitempty"` + ReasoningSignature string `json:"reasoningSignature,omitempty"` +} + +// ToolUseContent represents a tool invocation request from the LLM. +type ToolUseContent struct { + ID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]any `json:"input"` + InputError string `json:"-"` +} + +// ToolResultContent represents the result of a tool execution. +type ToolResultContent struct { + ToolUseID string `json:"toolUseId"` + Content string `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// StreamEvent represents an event from streaming response. +type StreamEvent struct { + Type string + Text string + Thinking *ThinkingContent + ToolUse *ToolUseContent + StopReason StopReason + Error error +} + +// ThinkingContent represents thinking/reasoning content. +type ThinkingContent struct { + Text string + Signature string +} + +// Tool represents a tool definition for the LLM. +type Tool struct { + Name string + Description string + InputSchema map[string]any +} + +// Client wraps the Bedrock runtime client. +type Client struct { + client *bedrockruntime.Client + modelID string + tools []Tool + maxTokens int32 + thinkingBudget int +} + +type ClientOption func(*Client) + +func WithModel(modelID string) ClientOption { + return func(c *Client) { + c.modelID = modelID + } +} + +func WithTools(tools []Tool) ClientOption { + return func(c *Client) { + c.tools = tools + } +} + +func WithMaxTokens(maxTokens int) ClientOption { + return func(c *Client) { + c.maxTokens = int32(maxTokens) + } +} + +func WithThinkingBudget(budget int) ClientOption { + return func(c *Client) { + c.thinkingBudget = budget + } +} + +func NewClient(ctx context.Context, opts ...ClientOption) (*Client, error) { + // Use AI-specific profile/region if configured + fileCfg := appconfig.File() + if profile := fileCfg.GetAIProfile(); profile != "" { + ctx = appaws.WithSelectionOverride(ctx, appconfig.ProfileSelectionFromID(profile)) + } + if region := fileCfg.GetAIRegion(); region != "" { + ctx = appaws.WithRegionOverride(ctx, region) + } + + awsCfg, err := appaws.NewConfig(ctx) + if err != nil { + return nil, apperrors.Wrap(err, "load aws config") + } + + c := &Client{ + client: bedrockruntime.NewFromConfig(awsCfg), + } + + for _, opt := range opts { + opt(c) + } + + return c, nil +} + +// ConverseStream sends a streaming request and returns a channel of events. +func (c *Client) ConverseStream(ctx context.Context, messages []Message, systemPrompt string) (<-chan StreamEvent, error) { + input := c.buildConverseStreamInput(messages, systemPrompt) + + output, err := c.client.ConverseStream(ctx, input) + if err != nil { + return nil, apperrors.Wrap(err, "converse stream") + } + + events := make(chan StreamEvent, 10) + go c.processStream(ctx, output, events) + + return events, nil +} + +func (c *Client) buildConverseStreamInput(messages []Message, systemPrompt string) *bedrockruntime.ConverseStreamInput { + log.Debug("buildConverseStreamInput", "modelID", c.modelID, "maxTokens", c.maxTokens, "thinkingBudget", c.thinkingBudget) + input := &bedrockruntime.ConverseStreamInput{ + ModelId: aws.String(c.modelID), + Messages: convertMessages(messages), + } + + if systemPrompt != "" { + input.System = []types.SystemContentBlock{ + &types.SystemContentBlockMemberText{Value: systemPrompt}, + } + } + + if len(c.tools) > 0 { + input.ToolConfig = c.buildToolConfig() + } + + if c.maxTokens > 0 { + input.InferenceConfig = &types.InferenceConfiguration{ + MaxTokens: aws.Int32(c.maxTokens), + } + } + + if c.thinkingBudget > 0 && strings.Contains(c.modelID, "anthropic.claude") { + log.Debug("applying thinking config", "budget", c.thinkingBudget) + thinkingConfig := map[string]any{ + "thinking": map[string]any{ + "type": "enabled", + "budget_tokens": c.thinkingBudget, + }, + "anthropic_beta": []string{"interleaved-thinking-2025-05-14"}, + } + input.AdditionalModelRequestFields = document.NewLazyDocument(thinkingConfig) + if input.InferenceConfig == nil { + input.InferenceConfig = &types.InferenceConfiguration{} + } + input.InferenceConfig.Temperature = aws.Float32(1.0) + } + + return input +} + +// convertMessages converts our Message type to Bedrock API types. +func convertMessages(messages []Message) []types.Message { + result := make([]types.Message, len(messages)) + for i, msg := range messages { + result[i] = types.Message{ + Role: types.ConversationRole(msg.Role), + Content: convertContentBlocks(msg.Content), + } + } + return result +} + +// convertContentBlocks converts our ContentBlock to Bedrock API types. +// Based on dt's implementation. +func convertContentBlocks(blocks []ContentBlock) []types.ContentBlock { + result := make([]types.ContentBlock, 0, len(blocks)) + for _, block := range blocks { + if block.Text != "" { + result = append(result, &types.ContentBlockMemberText{Value: block.Text}) + } + if block.ToolUse != nil { + result = append(result, &types.ContentBlockMemberToolUse{ + Value: types.ToolUseBlock{ + ToolUseId: aws.String(block.ToolUse.ID), + Name: aws.String(block.ToolUse.Name), + Input: document.NewLazyDocument(block.ToolUse.Input), + }, + }) + } + if block.ToolResult != nil { + status := types.ToolResultStatusSuccess + if block.ToolResult.IsError { + status = types.ToolResultStatusError + } + result = append(result, &types.ContentBlockMemberToolResult{ + Value: types.ToolResultBlock{ + ToolUseId: aws.String(block.ToolResult.ToolUseID), + Status: status, + Content: []types.ToolResultContentBlock{ + &types.ToolResultContentBlockMemberText{Value: block.ToolResult.Content}, + }, + }, + }) + } + if block.Reasoning != "" { + reasoningBlock := types.ReasoningTextBlock{ + Text: aws.String(block.Reasoning), + } + if block.ReasoningSignature != "" { + reasoningBlock.Signature = aws.String(block.ReasoningSignature) + } + result = append(result, &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: reasoningBlock, + }, + }) + } + } + return result +} + +func (c *Client) buildToolConfig() *types.ToolConfiguration { + toolDefs := make([]types.Tool, 0, len(c.tools)) + + for _, t := range c.tools { + toolDefs = append(toolDefs, &types.ToolMemberToolSpec{ + Value: types.ToolSpecification{ + Name: aws.String(t.Name), + Description: aws.String(t.Description), + InputSchema: &types.ToolInputSchemaMemberJson{ + Value: document.NewLazyDocument(t.InputSchema), + }, + }, + }) + } + + return &types.ToolConfiguration{ + Tools: toolDefs, + } +} + +// processStream processes the streaming response from Bedrock. +// Based on dt's implementation. +func (c *Client) processStream(ctx context.Context, output *bedrockruntime.ConverseStreamOutput, events chan<- StreamEvent) { + defer close(events) + + stream := output.GetStream() + defer func() { + if err := stream.Close(); err != nil { + log.Debug("stream close error", "error", err) + } + }() + + // Track current content block state + var currentToolUse *ToolUseContent + var toolInputBuffer string + + var thinkingText string + var thinkingSignature string + var isThinkingBlock bool + + for event := range stream.Events() { + select { + case <-ctx.Done(): + events <- StreamEvent{Type: "error", Error: ctx.Err()} + return + default: + } + + switch e := event.(type) { + case *types.ConverseStreamOutputMemberContentBlockStart: + // Start of a new content block + start := e.Value.Start + switch s := start.(type) { + case *types.ContentBlockStartMemberToolUse: + // Initialize tool use tracking + currentToolUse = &ToolUseContent{ + ID: aws.ToString(s.Value.ToolUseId), + Name: aws.ToString(s.Value.Name), + } + toolInputBuffer = "" + } + + case *types.ConverseStreamOutputMemberContentBlockDelta: + switch delta := e.Value.Delta.(type) { + case *types.ContentBlockDeltaMemberText: + events <- StreamEvent{Type: "text", Text: delta.Value} + case *types.ContentBlockDeltaMemberReasoningContent: + // Mark that we're processing a thinking block + isThinkingBlock = true + + // Process reasoning delta types + switch reasoningDelta := delta.Value.(type) { + case *types.ReasoningContentBlockDeltaMemberText: + thinkingText += reasoningDelta.Value + // Stream text chunks as they arrive + events <- StreamEvent{ + Type: "thinking", + Thinking: &ThinkingContent{Text: reasoningDelta.Value}, + } + case *types.ReasoningContentBlockDeltaMemberSignature: + thinkingSignature = reasoningDelta.Value + case *types.ReasoningContentBlockDeltaMemberRedactedContent: + // Redacted content - ignore + } + case *types.ContentBlockDeltaMemberToolUse: + // Accumulate tool use input + if currentToolUse != nil { + toolInputBuffer += aws.ToString(delta.Value.Input) + } + } + + case *types.ConverseStreamOutputMemberContentBlockStop: + if currentToolUse != nil { + var input map[string]any + if err := json.Unmarshal([]byte(toolInputBuffer), &input); err != nil { + log.Debug("failed to parse tool input JSON", "error", err) + input = make(map[string]any) + currentToolUse.InputError = err.Error() + } + currentToolUse.Input = input + + events <- StreamEvent{ + Type: "tool_use", + ToolUse: currentToolUse, + } + + currentToolUse = nil + toolInputBuffer = "" + } + + // If we were processing a thinking block, send the complete version with signature + if isThinkingBlock { + // Send complete thinking event with both text and signature + events <- StreamEvent{ + Type: "thinking_complete", + Thinking: &ThinkingContent{ + Text: thinkingText, + Signature: thinkingSignature, + }, + } + + // Reset thinking state + thinkingText = "" + thinkingSignature = "" + isThinkingBlock = false + } + + case *types.ConverseStreamOutputMemberMessageStop: + events <- StreamEvent{ + Type: "done", + StopReason: convertStopReason(e.Value.StopReason), + } + return + } + } + + if err := stream.Err(); err != nil { + events <- StreamEvent{Type: "error", Error: err} + } +} + +func convertStopReason(reason types.StopReason) StopReason { + switch reason { + case types.StopReasonEndTurn: + return StopReasonEndTurn + case types.StopReasonToolUse: + return StopReasonToolUse + case types.StopReasonMaxTokens: + return StopReasonMaxTokens + default: + return StopReasonEndTurn + } +} + +// Helper functions for building messages + +// NewUserMessage creates a user message with text content. +func NewUserMessage(text string) Message { + return Message{ + Role: RoleUser, + Content: []ContentBlock{{Text: text}}, + } +} + +// NewAssistantMessage creates an assistant message with content blocks. +func NewAssistantMessage(blocks ...ContentBlock) Message { + return Message{ + Role: RoleAssistant, + Content: blocks, + } +} + +// NewToolResultMessage creates a user message with tool results. +func NewToolResultMessage(results ...ToolResultContent) Message { + blocks := make([]ContentBlock, len(results)) + for i, r := range results { + blocks[i] = ContentBlock{ToolResult: &r} + } + return Message{ + Role: RoleUser, + Content: blocks, + } +} diff --git a/internal/ai/bedrock_test.go b/internal/ai/bedrock_test.go new file mode 100644 index 00000000..9859d7a3 --- /dev/null +++ b/internal/ai/bedrock_test.go @@ -0,0 +1,287 @@ +package ai + +import ( + "testing" +) + +func TestNewUserMessage(t *testing.T) { + msg := NewUserMessage("hello world") + + if msg.Role != RoleUser { + t.Errorf("expected role %q, got %q", RoleUser, msg.Role) + } + if len(msg.Content) != 1 { + t.Fatalf("expected 1 content block, got %d", len(msg.Content)) + } + if msg.Content[0].Text != "hello world" { + t.Errorf("expected text %q, got %q", "hello world", msg.Content[0].Text) + } +} + +func TestNewAssistantMessage(t *testing.T) { + blocks := []ContentBlock{ + {Text: "response text"}, + {ToolUse: &ToolUseContent{ID: "123", Name: "test_tool", Input: map[string]any{"key": "value"}}}, + } + msg := NewAssistantMessage(blocks...) + + if msg.Role != RoleAssistant { + t.Errorf("expected role %q, got %q", RoleAssistant, msg.Role) + } + if len(msg.Content) != 2 { + t.Fatalf("expected 2 content blocks, got %d", len(msg.Content)) + } + if msg.Content[0].Text != "response text" { + t.Errorf("expected text %q, got %q", "response text", msg.Content[0].Text) + } + if msg.Content[1].ToolUse == nil { + t.Fatal("expected tool use block") + } + if msg.Content[1].ToolUse.Name != "test_tool" { + t.Errorf("expected tool name %q, got %q", "test_tool", msg.Content[1].ToolUse.Name) + } +} + +func TestNewToolResultMessage(t *testing.T) { + results := []ToolResultContent{ + {ToolUseID: "123", Content: "success result", IsError: false}, + {ToolUseID: "456", Content: "error message", IsError: true}, + } + msg := NewToolResultMessage(results...) + + if msg.Role != RoleUser { + t.Errorf("expected role %q, got %q", RoleUser, msg.Role) + } + if len(msg.Content) != 2 { + t.Fatalf("expected 2 content blocks, got %d", len(msg.Content)) + } + if msg.Content[0].ToolResult == nil { + t.Fatal("expected tool result block") + } + if msg.Content[0].ToolResult.ToolUseID != "123" { + t.Errorf("expected tool use ID %q, got %q", "123", msg.Content[0].ToolResult.ToolUseID) + } + if msg.Content[0].ToolResult.IsError { + t.Error("expected IsError to be false for first result") + } + if !msg.Content[1].ToolResult.IsError { + t.Error("expected IsError to be true for second result") + } +} + +func TestConvertStopReason(t *testing.T) { + tests := []struct { + name string + input string + expected StopReason + }{ + {"end_turn", "end_turn", StopReasonEndTurn}, + {"tool_use", "tool_use", StopReasonToolUse}, + {"max_tokens", "max_tokens", StopReasonMaxTokens}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if string(tt.expected) != tt.input { + t.Errorf("StopReason constant mismatch: expected %q, got %q", tt.input, tt.expected) + } + }) + } +} + +func TestContentBlockTypes(t *testing.T) { + t.Run("text block", func(t *testing.T) { + block := ContentBlock{Text: "hello"} + if block.Text != "hello" { + t.Errorf("expected text %q, got %q", "hello", block.Text) + } + if block.ToolUse != nil || block.ToolResult != nil { + t.Error("other fields should be nil for text block") + } + }) + + t.Run("tool use block", func(t *testing.T) { + block := ContentBlock{ + ToolUse: &ToolUseContent{ + ID: "tool-123", + Name: "query_resources", + Input: map[string]any{"service": "ec2", "region": "us-east-1"}, + }, + } + if block.ToolUse == nil { + t.Fatal("expected tool use") + } + if block.ToolUse.ID != "tool-123" { + t.Errorf("expected ID %q, got %q", "tool-123", block.ToolUse.ID) + } + if block.ToolUse.Input["service"] != "ec2" { + t.Errorf("expected service %q, got %v", "ec2", block.ToolUse.Input["service"]) + } + }) + + t.Run("tool result block", func(t *testing.T) { + block := ContentBlock{ + ToolResult: &ToolResultContent{ + ToolUseID: "tool-123", + Content: "Found 5 instances", + IsError: false, + }, + } + if block.ToolResult == nil { + t.Fatal("expected tool result") + } + if block.ToolResult.Content != "Found 5 instances" { + t.Errorf("expected content %q, got %q", "Found 5 instances", block.ToolResult.Content) + } + }) + + t.Run("reasoning block", func(t *testing.T) { + block := ContentBlock{ + Reasoning: "Let me think about this...", + ReasoningSignature: "sig123", + } + if block.Reasoning != "Let me think about this..." { + t.Errorf("expected reasoning %q, got %q", "Let me think about this...", block.Reasoning) + } + if block.ReasoningSignature != "sig123" { + t.Errorf("expected signature %q, got %q", "sig123", block.ReasoningSignature) + } + }) +} + +func TestToolDefinition(t *testing.T) { + tool := Tool{ + Name: "test_tool", + Description: "A test tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "param1": map[string]any{ + "type": "string", + "description": "First parameter", + }, + }, + "required": []string{"param1"}, + }, + } + + if tool.Name != "test_tool" { + t.Errorf("expected name %q, got %q", "test_tool", tool.Name) + } + if tool.Description != "A test tool" { + t.Errorf("expected description %q, got %q", "A test tool", tool.Description) + } + + props, ok := tool.InputSchema["properties"].(map[string]any) + if !ok { + t.Fatal("expected properties to be map") + } + if _, ok := props["param1"]; !ok { + t.Error("expected param1 in properties") + } +} + +func TestClientOptions(t *testing.T) { + t.Run("WithModel", func(t *testing.T) { + c := &Client{} + opt := WithModel("test-model") + opt(c) + if c.modelID != "test-model" { + t.Errorf("expected model %q, got %q", "test-model", c.modelID) + } + }) + + t.Run("WithMaxTokens", func(t *testing.T) { + c := &Client{} + opt := WithMaxTokens(1000) + opt(c) + if c.maxTokens != 1000 { + t.Errorf("expected maxTokens %d, got %d", 1000, c.maxTokens) + } + }) + + t.Run("WithThinkingBudget", func(t *testing.T) { + c := &Client{} + opt := WithThinkingBudget(5000) + opt(c) + if c.thinkingBudget != 5000 { + t.Errorf("expected thinkingBudget %d, got %d", 5000, c.thinkingBudget) + } + }) + + t.Run("WithTools", func(t *testing.T) { + c := &Client{} + tools := []Tool{ + {Name: "tool1"}, + {Name: "tool2"}, + } + opt := WithTools(tools) + opt(c) + if len(c.tools) != 2 { + t.Errorf("expected 2 tools, got %d", len(c.tools)) + } + }) +} + +func TestStreamEvent(t *testing.T) { + t.Run("text event", func(t *testing.T) { + event := StreamEvent{Type: "text", Text: "hello"} + if event.Type != "text" { + t.Errorf("expected type %q, got %q", "text", event.Type) + } + if event.Text != "hello" { + t.Errorf("expected text %q, got %q", "hello", event.Text) + } + }) + + t.Run("thinking event", func(t *testing.T) { + event := StreamEvent{ + Type: "thinking", + Thinking: &ThinkingContent{Text: "reasoning..."}, + } + if event.Thinking == nil { + t.Fatal("expected thinking content") + } + if event.Thinking.Text != "reasoning..." { + t.Errorf("expected thinking text %q, got %q", "reasoning...", event.Thinking.Text) + } + }) + + t.Run("tool_use event", func(t *testing.T) { + event := StreamEvent{ + Type: "tool_use", + ToolUse: &ToolUseContent{ID: "123", Name: "test"}, + } + if event.ToolUse == nil { + t.Fatal("expected tool use") + } + if event.ToolUse.Name != "test" { + t.Errorf("expected tool name %q, got %q", "test", event.ToolUse.Name) + } + }) + + t.Run("done event", func(t *testing.T) { + event := StreamEvent{Type: "done", StopReason: StopReasonEndTurn} + if event.StopReason != StopReasonEndTurn { + t.Errorf("expected stop reason %q, got %q", StopReasonEndTurn, event.StopReason) + } + }) + + t.Run("error event", func(t *testing.T) { + event := StreamEvent{Type: "error", Error: &testError{"test error"}} + if event.Error == nil { + t.Fatal("expected error") + } + if event.Error.Error() != "test error" { + t.Errorf("expected error %q, got %q", "test error", event.Error.Error()) + } + }) +} + +type testError struct { + msg string +} + +func (e *testError) Error() string { + return e.msg +} diff --git a/internal/ai/session.go b/internal/ai/session.go new file mode 100644 index 00000000..538779fb --- /dev/null +++ b/internal/ai/session.go @@ -0,0 +1,350 @@ +package ai + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "time" + + "github.com/google/uuid" + + "github.com/clawscli/claws/internal/config" + "github.com/clawscli/claws/internal/log" +) + +const ( + DefaultMaxSessions = 100 + sessionDir = "chat/sessions" + currentSessionFile = "chat/current.json" +) + +type Session struct { + ID string `json:"id"` + StartedAt time.Time `json:"started_at"` + UpdatedAt time.Time `json:"updated_at"` + Messages []Message `json:"messages"` + Context *Context `json:"context,omitempty"` +} + +type ContextMode string + +const ( + ContextModeSingle ContextMode = "single" + ContextModeList ContextMode = "list" + ContextModeDiff ContextMode = "diff" +) + +type ResourceRef struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Region string `json:"region,omitempty"` + Profile string `json:"profile,omitempty"` + Cluster string `json:"cluster,omitempty"` +} + +type Context struct { + Service string `json:"service,omitempty"` + ResourceType string `json:"resource_type,omitempty"` + UserRegions []string `json:"user_regions,omitempty"` + UserProfiles []string `json:"user_profiles,omitempty"` + Mode ContextMode `json:"mode,omitempty"` + + ResourceID string `json:"resource_id,omitempty"` + ResourceName string `json:"resource_name,omitempty"` + ResourceRegion string `json:"resource_region,omitempty"` + ResourceProfile string `json:"resource_profile,omitempty"` + Cluster string `json:"cluster,omitempty"` + LogGroup string `json:"log_group,omitempty"` + + ResourceCount int `json:"resource_count,omitempty"` + FilterText string `json:"filter_text,omitempty"` + Toggles map[string]bool `json:"toggles,omitempty"` + + DiffLeft *ResourceRef `json:"diff_left,omitempty"` + DiffRight *ResourceRef `json:"diff_right,omitempty"` +} + +type SessionManager struct { + maxSessions int + saveEnabled bool + currentID string +} + +func NewSessionManager(maxSessions int, saveEnabled bool) *SessionManager { + if maxSessions <= 0 { + maxSessions = DefaultMaxSessions + } + return &SessionManager{ + maxSessions: maxSessions, + saveEnabled: saveEnabled, + } +} + +func (m *SessionManager) sessionsDir() (string, error) { + dir, err := config.ConfigDir() + if err != nil { + return "", err + } + return filepath.Join(dir, sessionDir), nil +} + +func (m *SessionManager) currentPath() (string, error) { + dir, err := config.ConfigDir() + if err != nil { + return "", err + } + return filepath.Join(dir, currentSessionFile), nil +} + +func (m *SessionManager) NewSession(ctx *Context) (*Session, error) { + session := &Session{ + ID: generateSessionID(), + StartedAt: time.Now(), + UpdatedAt: time.Now(), + Messages: []Message{}, + Context: ctx, + } + + m.currentID = session.ID + return session, nil +} + +func (m *SessionManager) CurrentSession() (*Session, error) { + if m.currentID == "" { + path, err := m.currentPath() + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + var current struct { + ID string `json:"id"` + } + if err := json.Unmarshal(data, ¤t); err != nil { + return nil, err + } + m.currentID = current.ID + } + + if m.currentID == "" { + return nil, nil + } + + return m.LoadSession(m.currentID) +} + +func (m *SessionManager) LoadSession(id string) (*Session, error) { + dir, err := m.sessionsDir() + if err != nil { + return nil, err + } + + path := filepath.Join(dir, id+".json") + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var session Session + if err := json.Unmarshal(data, &session); err != nil { + return nil, err + } + + return &session, nil +} + +func (m *SessionManager) SaveMessages(session *Session) error { + session.UpdatedAt = time.Now() + return m.saveSession(session) +} + +func (m *SessionManager) AddMessage(session *Session, msg Message) error { + session.Messages = append(session.Messages, msg) + err := m.SaveMessages(session) + if err != nil { + return err + } + if len(session.Messages) == 1 { + // Check if pruning is needed before loading all sessions + shouldPrune, checkErr := m.shouldPrune() + if checkErr != nil { + log.Debug("failed to check prune status", "error", checkErr) + } else if shouldPrune { + if pruneErr := m.pruneOldSessions(); pruneErr != nil { + log.Debug("failed to prune old sessions", "error", pruneErr) + } + } + } + return nil +} + +func (m *SessionManager) shouldPrune() (bool, error) { + dir, err := m.sessionsDir() + if err != nil { + return false, err + } + + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + + // Count only .json files + count := 0 + for _, entry := range entries { + if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" { + count++ + } + } + + return count > m.maxSessions, nil +} + +func (m *SessionManager) ListSessions() ([]Session, error) { + dir, err := m.sessionsDir() + if err != nil { + return nil, err + } + + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + var sessions []Session + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + continue + } + + id := entry.Name()[:len(entry.Name())-5] + session, err := m.LoadSession(id) + if err != nil { + log.Debug("failed to load session", "id", id, "error", err) + continue + } + sessions = append(sessions, *session) + } + + sort.Slice(sessions, func(i, j int) bool { + return sessions[i].UpdatedAt.After(sessions[j].UpdatedAt) + }) + + return sessions, nil +} + +func (m *SessionManager) saveSession(session *Session) error { + if !m.saveEnabled { + return nil + } + + dir, err := m.sessionsDir() + if err != nil { + return err + } + + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + + path := filepath.Join(dir, session.ID+".json") + data, err := json.MarshalIndent(session, "", " ") + if err != nil { + return err + } + + if err := os.WriteFile(path, data, 0600); err != nil { + return err + } + + return m.saveCurrentID(session.ID) +} + +func (m *SessionManager) saveCurrentID(id string) error { + path, err := m.currentPath() + if err != nil { + return err + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + + data, _ := json.Marshal(struct { + ID string `json:"id"` + }{ID: id}) + + return os.WriteFile(path, data, 0600) +} + +func (m *SessionManager) pruneOldSessions() error { + dir, err := m.sessionsDir() + if err != nil { + return err + } + + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + // Collect .json files with modification times + type sessionFile struct { + name string + modTime time.Time + } + var files []sessionFile + for _, entry := range entries { + if !entry.IsDir() && filepath.Ext(entry.Name()) == ".json" { + info, err := entry.Info() + if err != nil { + continue + } + files = append(files, sessionFile{ + name: entry.Name(), + modTime: info.ModTime(), + }) + } + } + + if len(files) <= m.maxSessions { + return nil + } + + // Sort by modification time (oldest first) + sort.Slice(files, func(i, j int) bool { + return files[i].modTime.Before(files[j].modTime) + }) + + // Delete oldest sessions + deleteCount := len(files) - m.maxSessions + for i := 0; i < deleteCount; i++ { + _ = os.Remove(filepath.Join(dir, files[i].name)) + } + + return nil +} + +func generateSessionID() string { + now := time.Now() + return fmt.Sprintf("%s-%s", now.Format("20060102-150405"), uuid.New().String()[:8]) +} diff --git a/internal/ai/session_test.go b/internal/ai/session_test.go new file mode 100644 index 00000000..d9ec4e72 --- /dev/null +++ b/internal/ai/session_test.go @@ -0,0 +1,385 @@ +package ai + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestNewSessionManager(t *testing.T) { + t.Run("default max sessions", func(t *testing.T) { + sm := NewSessionManager(0, false) + if sm.maxSessions != DefaultMaxSessions { + t.Errorf("expected maxSessions %d, got %d", DefaultMaxSessions, sm.maxSessions) + } + }) + + t.Run("custom max sessions", func(t *testing.T) { + sm := NewSessionManager(50, false) + if sm.maxSessions != 50 { + t.Errorf("expected maxSessions %d, got %d", 50, sm.maxSessions) + } + }) + + t.Run("save disabled", func(t *testing.T) { + sm := NewSessionManager(10, false) + if sm.saveEnabled { + t.Error("expected saveEnabled to be false") + } + }) + + t.Run("save enabled", func(t *testing.T) { + sm := NewSessionManager(10, true) + if !sm.saveEnabled { + t.Error("expected saveEnabled to be true") + } + }) +} + +func TestGenerateSessionID(t *testing.T) { + id1 := generateSessionID() + time.Sleep(time.Millisecond) + id2 := generateSessionID() + + if id1 == "" { + t.Error("expected non-empty session ID") + } + if id1 == id2 { + t.Error("expected unique session IDs") + } + + // Check format: YYYY-MM-DD-xxxxxx + if len(id1) < 10 { + t.Errorf("session ID too short: %q", id1) + } +} + +func TestSessionManagerNewSession(t *testing.T) { + sm := NewSessionManager(10, false) // save disabled + + ctx := &Context{ + Service: "ec2", + ResourceType: "instances", + } + + session, err := sm.NewSession(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if session.ID == "" { + t.Error("expected non-empty session ID") + } + if session.Context == nil { + t.Error("expected context") + } + if session.Context.Service != "ec2" { + t.Errorf("expected service %q, got %q", "ec2", session.Context.Service) + } + if len(session.Messages) != 0 { + t.Errorf("expected 0 messages, got %d", len(session.Messages)) + } + if session.StartedAt.IsZero() { + t.Error("expected StartedAt to be set") + } + if session.UpdatedAt.IsZero() { + t.Error("expected UpdatedAt to be set") + } +} + +func TestSessionManagerAddMessage(t *testing.T) { + sm := NewSessionManager(10, false) + + session, err := sm.NewSession(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := NewUserMessage("test message") + err = sm.AddMessage(session, msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(session.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(session.Messages)) + } + if session.Messages[0].Content[0].Text != "test message" { + t.Errorf("expected message %q, got %q", "test message", session.Messages[0].Content[0].Text) + } +} + +func TestSessionManagerWithPersistence(t *testing.T) { + // Create temp directory for test + tmpDir, err := os.MkdirTemp("", "claws-session-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + t.Setenv("HOME", tmpDir) + + sm := NewSessionManager(10, true) // save enabled + + session, err := sm.NewSession(&Context{Service: "lambda"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Add a message + err = sm.AddMessage(session, NewUserMessage("hello")) + if err != nil { + t.Fatalf("failed to add message: %v", err) + } + + // Verify file was created + sessionsDir := filepath.Join(tmpDir, ".config", "claws", "chat", "sessions") + sessionFile := filepath.Join(sessionsDir, session.ID+".json") + if _, err := os.Stat(sessionFile); os.IsNotExist(err) { + t.Error("expected session file to be created") + } + + // Load session + loaded, err := sm.LoadSession(session.ID) + if err != nil { + t.Fatalf("failed to load session: %v", err) + } + + if loaded.ID != session.ID { + t.Errorf("expected ID %q, got %q", session.ID, loaded.ID) + } + if len(loaded.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(loaded.Messages)) + } + if loaded.Context.Service != "lambda" { + t.Errorf("expected service %q, got %q", "lambda", loaded.Context.Service) + } +} + +func TestContext(t *testing.T) { + t.Run("single mode", func(t *testing.T) { + ctx := &Context{ + Mode: ContextModeSingle, + Service: "ec2", + ResourceType: "instances", + ResourceID: "i-12345", + ResourceName: "my-instance", + ResourceRegion: "us-east-1", + } + if ctx.Mode != ContextModeSingle { + t.Errorf("expected mode %q, got %q", ContextModeSingle, ctx.Mode) + } + if ctx.ResourceID != "i-12345" { + t.Errorf("expected resource ID %q, got %q", "i-12345", ctx.ResourceID) + } + }) + + t.Run("list mode", func(t *testing.T) { + ctx := &Context{ + Mode: ContextModeList, + Service: "lambda", + ResourceType: "functions", + ResourceCount: 25, + FilterText: "prod", + } + if ctx.Mode != ContextModeList { + t.Errorf("expected mode %q, got %q", ContextModeList, ctx.Mode) + } + if ctx.ResourceCount != 25 { + t.Errorf("expected count %d, got %d", 25, ctx.ResourceCount) + } + }) + + t.Run("diff mode", func(t *testing.T) { + ctx := &Context{ + Mode: ContextModeDiff, + Service: "rds", + ResourceType: "instances", + DiffLeft: &ResourceRef{ + ID: "db-1", + Name: "prod-db", + Region: "us-east-1", + }, + DiffRight: &ResourceRef{ + ID: "db-2", + Name: "staging-db", + Region: "us-west-2", + }, + } + if ctx.Mode != ContextModeDiff { + t.Errorf("expected mode %q, got %q", ContextModeDiff, ctx.Mode) + } + if ctx.DiffLeft.ID != "db-1" { + t.Errorf("expected left ID %q, got %q", "db-1", ctx.DiffLeft.ID) + } + if ctx.DiffRight.Region != "us-west-2" { + t.Errorf("expected right region %q, got %q", "us-west-2", ctx.DiffRight.Region) + } + }) +} + +func TestResourceRef(t *testing.T) { + ref := ResourceRef{ + ID: "i-12345", + Name: "my-instance", + Region: "us-east-1", + Profile: "prod", + Cluster: "my-cluster", + } + + if ref.ID != "i-12345" { + t.Errorf("expected ID %q, got %q", "i-12345", ref.ID) + } + if ref.Cluster != "my-cluster" { + t.Errorf("expected cluster %q, got %q", "my-cluster", ref.Cluster) + } +} + +func TestSessionListEmpty(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "claws-session-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + t.Setenv("HOME", tmpDir) + + sm := NewSessionManager(10, true) + + sessions, err := sm.ListSessions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(sessions) != 0 { + t.Errorf("expected 0 sessions, got %d", len(sessions)) + } +} + +func TestSessionPruning(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "claws-session-test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + t.Setenv("HOME", tmpDir) + + sm := NewSessionManager(3, true) // Max 3 sessions + + // Create 5 sessions with messages (files only saved on AddMessage) + for i := 0; i < 5; i++ { + sess, err := sm.NewSession(nil) + if err != nil { + t.Fatalf("failed to create session %d: %v", i, err) + } + // AddMessage triggers file save and pruning (on first message) + if err := sm.AddMessage(sess, NewUserMessage("test")); err != nil { + t.Fatalf("failed to add message to session %d: %v", i, err) + } + time.Sleep(time.Millisecond) // Ensure different timestamps + } + + // List sessions - should be pruned to 3 + sessions, err := sm.ListSessions() + if err != nil { + t.Fatalf("failed to list sessions: %v", err) + } + + if len(sessions) != 3 { + t.Errorf("expected 3 sessions after pruning, got %d", len(sessions)) + } +} + +func TestShouldPrune(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + defer os.Setenv("HOME", origHome) + os.Setenv("HOME", tmpDir) + + sm := NewSessionManager(5, true) + + t.Run("no sessions", func(t *testing.T) { + should, err := sm.shouldPrune() + if err != nil { + t.Fatalf("shouldPrune failed: %v", err) + } + if should { + t.Error("empty dir should not need pruning") + } + }) + + t.Run("under limit", func(t *testing.T) { + // Create 4 sessions (below limit of 5) + for i := 0; i < 4; i++ { + sess, err := sm.NewSession(nil) + if err != nil { + t.Fatalf("NewSession failed: %v", err) + } + if err := sm.SaveMessages(sess); err != nil { + t.Fatalf("SaveMessages failed: %v", err) + } + } + + should, err := sm.shouldPrune() + if err != nil { + t.Fatalf("shouldPrune failed: %v", err) + } + if should { + t.Error("under limit should not need pruning") + } + }) + + t.Run("at limit", func(t *testing.T) { + // Add one more session (total 5, at limit) + sess, err := sm.NewSession(nil) + if err != nil { + t.Fatalf("NewSession failed: %v", err) + } + if err := sm.SaveMessages(sess); err != nil { + t.Fatalf("SaveMessages failed: %v", err) + } + + should, err := sm.shouldPrune() + if err != nil { + t.Fatalf("shouldPrune failed: %v", err) + } + if should { + t.Error("at limit should not need pruning") + } + }) + + t.Run("over limit", func(t *testing.T) { + // Add one more session (total 6, over limit) + sess, err := sm.NewSession(nil) + if err != nil { + t.Fatalf("NewSession failed: %v", err) + } + if err := sm.SaveMessages(sess); err != nil { + t.Fatalf("SaveMessages failed: %v", err) + } + + should, err := sm.shouldPrune() + if err != nil { + t.Fatalf("shouldPrune failed: %v", err) + } + if !should { + t.Error("over limit should need pruning") + } + }) + + t.Run("non-existent directory", func(t *testing.T) { + tmpDir2 := t.TempDir() + os.Setenv("HOME", tmpDir2) + defer os.Setenv("HOME", tmpDir) + + sm2 := NewSessionManager(5, true) + should, err := sm2.shouldPrune() + if err != nil { + t.Fatalf("shouldPrune failed: %v", err) + } + if should { + t.Error("non-existent dir should not need pruning") + } + }) +} diff --git a/internal/ai/tools.go b/internal/ai/tools.go new file mode 100644 index 00000000..19ac28a1 --- /dev/null +++ b/internal/ai/tools.go @@ -0,0 +1,769 @@ +package ai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" + + appaws "github.com/clawscli/claws/internal/aws" + appconfig "github.com/clawscli/claws/internal/config" + "github.com/clawscli/claws/internal/dao" + "github.com/clawscli/claws/internal/log" + "github.com/clawscli/claws/internal/registry" + + apigatewayStages "github.com/clawscli/claws/custom/apigateway/stages" + apigatewayStagesV2 "github.com/clawscli/claws/custom/apigateway/stages-v2" + cloudtrailtrails "github.com/clawscli/claws/custom/cloudtrail/trails" + codebuildbuilds "github.com/clawscli/claws/custom/codebuild/builds" + codebuildprojects "github.com/clawscli/claws/custom/codebuild/projects" + ecsservices "github.com/clawscli/claws/custom/ecs/services" + taskdefinitions "github.com/clawscli/claws/custom/ecs/task-definitions" + ecstasks "github.com/clawscli/claws/custom/ecs/tasks" + sfnStateMachines "github.com/clawscli/claws/custom/stepfunctions/state-machines" +) + +type ToolExecutor struct { + registry *registry.Registry +} + +func NewToolExecutor(_ context.Context, reg *registry.Registry) (*ToolExecutor, error) { + return &ToolExecutor{ + registry: reg, + }, nil +} + +func (e *ToolExecutor) Tools() []Tool { + return []Tool{ + { + Name: "list_resources", + Description: "List resource types available for a specific AWS service", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "service": map[string]any{ + "type": "string", + "description": "AWS service name (e.g., ec2, lambda, s3)", + }, + }, + "required": []string{"service"}, + }, + }, + { + Name: "query_resources", + Description: "List AWS resources. You MUST provide service, resource_type, and region parameters.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "service": map[string]any{ + "type": "string", + "description": "AWS service name. Examples: ec2, lambda, s3, rds, ecs, dynamodb", + }, + "resource_type": map[string]any{ + "type": "string", + "description": "Resource type. Examples: instances (for ec2), functions (for lambda), buckets (for s3), tables (for dynamodb)", + }, + "region": map[string]any{ + "type": "string", + "description": "AWS region. Examples: us-east-1, us-west-2, ap-northeast-1", + }, + "profile": map[string]any{ + "type": "string", + "description": "AWS profile name (optional, uses current profile if not specified)", + }, + "include_resolved": map[string]any{ + "type": "boolean", + "description": "Include resolved/archived items (securityhub/findings only, default: false)", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum resources to return (default: 100, max: 2000)", + }, + "offset": map[string]any{ + "type": "integer", + "description": "Skip first N resources for pagination (default: 0)", + }, + }, + "required": []string{"service", "resource_type", "region"}, + }, + }, + { + Name: "get_resource_detail", + Description: "Get detailed information about a specific AWS resource. NOTE: For ecs/services and ecs/tasks, cluster parameter is required.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "service": map[string]any{ + "type": "string", + "description": "AWS service name", + }, + "resource_type": map[string]any{ + "type": "string", + "description": "Resource type", + }, + "region": map[string]any{ + "type": "string", + "description": "AWS region (e.g., us-east-1, us-west-2)", + }, + "id": map[string]any{ + "type": "string", + "description": "Resource ID", + }, + "cluster": map[string]any{ + "type": "string", + "description": "ECS cluster name (required for ecs/services and ecs/tasks)", + }, + "profile": map[string]any{ + "type": "string", + "description": "AWS profile name (optional, uses current profile if not specified)", + }, + }, + "required": []string{"service", "resource_type", "region", "id"}, + }, + }, + { + Name: "tail_logs", + Description: "Fetch recent CloudWatch logs for an AWS resource. Automatically extracts log group from resource configuration. Supported: lambda/functions, ecs/services, ecs/tasks, ecs/task-definitions, codebuild/projects, codebuild/builds, cloudtrail/trails, apigateway/stages, apigateway/stages-v2, stepfunctions/state-machines. NOTE: For ecs/services and ecs/tasks, cluster parameter is required.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "service": map[string]any{ + "type": "string", + "description": "AWS service name (e.g., lambda, ecs, codebuild)", + }, + "resource_type": map[string]any{ + "type": "string", + "description": "Resource type (e.g., functions, services, tasks, task-definitions)", + }, + "region": map[string]any{ + "type": "string", + "description": "AWS region (e.g., us-east-1, ap-northeast-1)", + }, + "id": map[string]any{ + "type": "string", + "description": "Resource ID", + }, + "cluster": map[string]any{ + "type": "string", + "description": "ECS cluster name (required for ecs/services and ecs/tasks)", + }, + "profile": map[string]any{ + "type": "string", + "description": "AWS profile name (optional, uses current profile if not specified)", + }, + "filter": map[string]any{ + "type": "string", + "description": "Optional filter pattern for log messages", + }, + "since": map[string]any{ + "type": "string", + "description": "Time range (e.g., 5m, 1h, 24h). Default: 15m", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of log events. Default: 100", + }, + }, + "required": []string{"service", "resource_type", "region", "id"}, + }, + }, + { + Name: "search_aws_docs", + Description: "Search AWS documentation for information", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Search query for AWS documentation", + }, + }, + "required": []string{"query"}, + }, + }, + } +} + +func (e *ToolExecutor) Execute(ctx context.Context, call *ToolUseContent) ToolResultContent { + if call.InputError != "" { + return ToolResultContent{ + ToolUseID: call.ID, + Content: fmt.Sprintf("Error: malformed tool input: %s", call.InputError), + IsError: true, + } + } + + var content string + var isError bool + + switch call.Name { + case "list_resources": + service, _ := call.Input["service"].(string) + content = e.listResources(service) + case "query_resources": + service, _ := call.Input["service"].(string) + resourceType, _ := call.Input["resource_type"].(string) + region, _ := call.Input["region"].(string) + profile, _ := call.Input["profile"].(string) + includeResolved, _ := call.Input["include_resolved"].(bool) + limit, _ := call.Input["limit"].(float64) + offset, _ := call.Input["offset"].(float64) + content, isError = e.queryResources(ctx, service, resourceType, region, profile, includeResolved, int(limit), int(offset)) + case "get_resource_detail": + service, _ := call.Input["service"].(string) + resourceType, _ := call.Input["resource_type"].(string) + region, _ := call.Input["region"].(string) + id, _ := call.Input["id"].(string) + cluster, _ := call.Input["cluster"].(string) + profile, _ := call.Input["profile"].(string) + content, isError = e.getResourceDetail(ctx, service, resourceType, region, id, cluster, profile) + case "tail_logs": + service, _ := call.Input["service"].(string) + resourceType, _ := call.Input["resource_type"].(string) + region, _ := call.Input["region"].(string) + id, _ := call.Input["id"].(string) + cluster, _ := call.Input["cluster"].(string) + profile, _ := call.Input["profile"].(string) + filter, _ := call.Input["filter"].(string) + since, _ := call.Input["since"].(string) + limit, _ := call.Input["limit"].(float64) + content, isError = e.tailLogs(ctx, service, resourceType, region, id, cluster, profile, filter, since, int(limit)) + case "search_aws_docs": + query, _ := call.Input["query"].(string) + content = e.searchDocs(ctx, query) + default: + content = fmt.Sprintf("Unknown tool: %s", call.Name) + isError = true + } + + return ToolResultContent{ + ToolUseID: call.ID, + Content: content, + IsError: isError, + } +} + +func (e *ToolExecutor) listResources(service string) string { + resources := e.registry.ListResources(service) + if len(resources) == 0 { + return fmt.Sprintf("No resources found for service: %s", service) + } + + displayName := e.registry.GetDisplayName(service) + result := fmt.Sprintf("Resource types for %s (%s):\n", displayName, service) + for _, r := range resources { + result += fmt.Sprintf("- %s\n", r) + } + return result +} + +func (e *ToolExecutor) queryResources(ctx context.Context, service, resourceType, region, profile string, includeResolved bool, limit, offset int) (string, bool) { + if service == "" { + return "Error: service parameter is required", true + } + if resourceType == "" { + return "Error: resource_type parameter is required", true + } + if region == "" { + return "Error: region parameter is required", true + } + + // Validate and apply limit + if limit <= 0 { + limit = 100 // default changed from 50 + } + if limit > 2000 { + limit = 2000 // max 2000 + } + + // Validate offset + if offset < 0 { + offset = 0 + } + + if profile != "" { + ctx = appaws.WithSelectionOverride(ctx, appconfig.ProfileSelectionFromID(profile)) + } + ctx = appaws.WithRegionOverride(ctx, region) + if includeResolved { + ctx = dao.WithFilter(ctx, "ShowResolved", "true") + } + d, err := e.registry.GetDAO(ctx, service, resourceType) + if err != nil { + return fmt.Sprintf("Error: %s/%s not found. Use list_resources(service=\"%s\") to see available types.", service, resourceType, service), true + } + + resources, err := d.List(ctx) + if err != nil { + return fmt.Sprintf("Error listing %s/%s: %v", service, resourceType, err), true + } + + if len(resources) == 0 { + return fmt.Sprintf("No %s/%s resources found in %s", service, resourceType, region), false + } + + filterNote := "" + if service == "securityhub" && resourceType == "findings" { + if includeResolved { + filterNote = " (including resolved)" + } else { + filterNote = " (active only, use include_resolved=true for all)" + } + } + + // Apply offset + start := offset + if start >= len(resources) { + return fmt.Sprintf("Offset %d exceeds total count %d", offset, len(resources)), true + } + + end := start + limit + if end > len(resources) { + end = len(resources) + } + + viewResources := resources[start:end] + + result := fmt.Sprintf("Found %d %s/%s resources in %s%s (showing %d-%d):\n\n", + len(resources), service, resourceType, region, filterNote, start+1, end) + + for _, r := range viewResources { + result += formatResourceSummary(r) + } + + if end < len(resources) { + result += fmt.Sprintf("\n... and %d more (use offset=%d to see next page)\n", len(resources)-end, end) + } + + return result, false +} + +func (e *ToolExecutor) getResourceDetail(ctx context.Context, service, resourceType, region, id, cluster, profile string) (string, bool) { + if region == "" { + return "Error: region parameter is required", true + } + + if profile != "" { + ctx = appaws.WithSelectionOverride(ctx, appconfig.ProfileSelectionFromID(profile)) + } + ctx = appaws.WithRegionOverride(ctx, region) + + if service == "ecs" && (resourceType == "services" || resourceType == "tasks") { + if cluster == "" { + err := "Error: cluster parameter is required for ecs/services and ecs/tasks" + log.Warn("getResourceDetail failed", "error", err) + return err, true + } + ctx = dao.WithFilter(ctx, "ClusterName", cluster) + } + + d, err := e.registry.GetDAO(ctx, service, resourceType) + if err != nil { + log.Warn("getResourceDetail GetDAO failed", "error", err) + return fmt.Sprintf("Error getting DAO: %v", err), true + } + + resource, err := d.Get(ctx, id) + if err != nil { + log.Warn("getResourceDetail Get failed", "service", service, "resourceType", resourceType, "id", id, "error", err) + return fmt.Sprintf("Error getting resource: %v", err), true + } + + return formatResourceDetail(resource), false +} + +func (e *ToolExecutor) tailLogs(ctx context.Context, service, resourceType, region, id, cluster, profile, filter, since string, limit int) (string, bool) { + if region == "" { + return "Error: region parameter is required", true + } + if limit <= 0 { + limit = 100 + } + if limit > 500 { + limit = 500 + } + + if profile != "" { + ctx = appaws.WithSelectionOverride(ctx, appconfig.ProfileSelectionFromID(profile)) + } + ctx = appaws.WithRegionOverride(ctx, region) + + logGroup, err := e.extractLogGroup(ctx, service, resourceType, id, cluster) + if err != nil { + log.Warn("tailLogs extractLogGroup failed", "service", service, "resourceType", resourceType, "id", id, "error", err) + return fmt.Sprintf("Error extracting log group for %s/%s/%s: %v", service, resourceType, id, err), true + } + + cfg, err := appaws.NewConfigWithRegion(ctx, region) + if err != nil { + return fmt.Sprintf("Error creating config for region %s: %v", region, err), true + } + cwClient := cloudwatchlogs.NewFromConfig(cfg) + + startTime := time.Now().Add(-15 * time.Minute) + if since != "" { + if d, err := time.ParseDuration(since); err == nil { + startTime = time.Now().Add(-d) + } + } + + input := &cloudwatchlogs.FilterLogEventsInput{ + LogGroupName: aws.String(logGroup), + StartTime: aws.Int64(startTime.UnixMilli()), + Limit: aws.Int32(int32(limit)), + } + + if filter != "" { + input.FilterPattern = aws.String(filter) + } + + output, err := cwClient.FilterLogEvents(ctx, input) + if err != nil { + return fmt.Sprintf("Error fetching logs from %s: %v", logGroup, err), true + } + + if len(output.Events) == 0 { + sinceStr := "15m" + if since != "" { + sinceStr = since + } + return fmt.Sprintf("No logs found in %s (since %s)", logGroup, sinceStr), false + } + + result := fmt.Sprintf("Logs from %s (%d events):\n\n", logGroup, len(output.Events)) + for _, event := range output.Events { + ts := time.UnixMilli(aws.ToInt64(event.Timestamp)) + result += fmt.Sprintf("[%s] %s\n", ts.Format("15:04:05"), aws.ToString(event.Message)) + } + + return result, false +} + +func (e *ToolExecutor) extractLogGroup(ctx context.Context, service, resourceType, id, cluster string) (string, error) { + key := service + "/" + resourceType + + switch key { + case "lambda/functions": + return "/aws/lambda/" + id, nil + + case "ecs/task-definitions": + resource, err := e.getResource(ctx, service, resourceType, id) + if err != nil { + return "", err + } + td, ok := resource.(*taskdefinitions.TaskDefinitionResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for task-definitions") + } + if logGroup := td.GetCloudWatchLogGroup(""); logGroup != "" { + return logGroup, nil + } + return "", fmt.Errorf("no CloudWatch logs configured for task definition %s", id) + + case "ecs/services": + if cluster == "" { + return "", fmt.Errorf("cluster parameter is required for ecs/services") + } + ctxWithCluster := dao.WithFilter(ctx, "ClusterName", cluster) + resource, err := e.getResource(ctxWithCluster, service, resourceType, id) + if err != nil { + return "", err + } + svc, ok := resource.(*ecsservices.ServiceResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for ecs services") + } + taskDefArn := svc.TaskDefinition() + if taskDefArn == "" { + return "", fmt.Errorf("no task definition found for service %s", id) + } + return e.extractLogGroupFromTaskDef(ctx, taskDefArn) + + case "ecs/tasks": + if cluster == "" { + return "", fmt.Errorf("cluster parameter is required for ecs/tasks") + } + ctxWithCluster := dao.WithFilter(ctx, "ClusterName", cluster) + resource, err := e.getResource(ctxWithCluster, service, resourceType, id) + if err != nil { + return "", err + } + task, ok := resource.(*ecstasks.TaskResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for ecs tasks") + } + taskDefArn := task.TaskDefinitionArn() + if taskDefArn == "" { + return "", fmt.Errorf("no task definition found for task %s", id) + } + return e.extractLogGroupFromTaskDef(ctx, taskDefArn) + + case "codebuild/projects": + resource, err := e.getResource(ctx, service, resourceType, id) + if err != nil { + return "", err + } + proj, ok := resource.(*codebuildprojects.ProjectResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for codebuild projects") + } + if proj.Project.LogsConfig != nil && + proj.Project.LogsConfig.CloudWatchLogs != nil && + proj.Project.LogsConfig.CloudWatchLogs.GroupName != nil { + return *proj.Project.LogsConfig.CloudWatchLogs.GroupName, nil + } + return "/aws/codebuild/" + id, nil + + case "codebuild/builds": + resource, err := e.getResource(ctx, service, resourceType, id) + if err != nil { + return "", err + } + build, ok := resource.(*codebuildbuilds.BuildResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for codebuild builds") + } + if build.LogsGroupName() != "" { + return build.LogsGroupName(), nil + } + return "", fmt.Errorf("no CloudWatch logs configured for build %s", id) + + case "cloudtrail/trails": + resource, err := e.getResource(ctx, service, resourceType, id) + if err != nil { + return "", err + } + trail, ok := resource.(*cloudtrailtrails.TrailResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for cloudtrail trails") + } + logGroupArn := trail.CloudWatchLogsLogGroupArn() + if logGroupArn == "" { + return "", fmt.Errorf("no CloudWatch logs configured for trail %s", id) + } + return extractLogGroupNameFromArn(logGroupArn), nil + + case "apigateway/stages": + resource, err := e.getResource(ctx, service, resourceType, id) + if err != nil { + return "", err + } + stage, ok := resource.(*apigatewayStages.StageResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for apigateway stages") + } + destArn := stage.AccessLogDestination() + if destArn == "" { + return "", fmt.Errorf("no access logs configured for stage %s", id) + } + return extractLogGroupNameFromArn(destArn), nil + + case "apigateway/stages-v2": + resource, err := e.getResource(ctx, service, resourceType, id) + if err != nil { + return "", err + } + stage, ok := resource.(*apigatewayStagesV2.StageV2Resource) + if !ok { + return "", fmt.Errorf("unexpected resource type for apigateway stages-v2") + } + destArn := stage.AccessLogDestination() + if destArn == "" { + return "", fmt.Errorf("no access logs configured for stage %s", id) + } + return extractLogGroupNameFromArn(destArn), nil + + case "stepfunctions/state-machines": + resource, err := e.getResource(ctx, service, resourceType, id) + if err != nil { + return "", err + } + sm, ok := resource.(*sfnStateMachines.StateMachineResource) + if !ok { + return "", fmt.Errorf("unexpected resource type for stepfunctions state-machines") + } + if sm.Detail != nil && sm.Detail.LoggingConfiguration != nil { + for _, dest := range sm.Detail.LoggingConfiguration.Destinations { + if dest.CloudWatchLogsLogGroup != nil && dest.CloudWatchLogsLogGroup.LogGroupArn != nil { + return extractLogGroupNameFromArn(*dest.CloudWatchLogsLogGroup.LogGroupArn), nil + } + } + } + return "", fmt.Errorf("no CloudWatch logs configured for state machine %s", id) + + default: + return "", fmt.Errorf("log extraction not supported for %s/%s. Supported: lambda/functions, ecs/services, ecs/tasks, ecs/task-definitions, codebuild/projects, codebuild/builds, cloudtrail/trails, apigateway/stages, apigateway/stages-v2, stepfunctions/state-machines", service, resourceType) + } +} + +func (e *ToolExecutor) extractLogGroupFromTaskDef(ctx context.Context, taskDefArn string) (string, error) { + taskDefID := appaws.ExtractResourceName(taskDefArn) + resource, err := e.getResource(ctx, "ecs", "task-definitions", taskDefID) + if err != nil { + return "", fmt.Errorf("failed to get task definition %s: %w", taskDefArn, err) + } + + td, ok := resource.(*taskdefinitions.TaskDefinitionResource) + if !ok { + return "", fmt.Errorf("unexpected resource type") + } + + if logGroup := td.GetCloudWatchLogGroup(""); logGroup != "" { + return logGroup, nil + } + + return "", fmt.Errorf("no CloudWatch logs configured in task definition %s", taskDefArn) +} + +func extractLogGroupNameFromArn(arn string) string { + parts := strings.Split(arn, ":") + if len(parts) >= 7 { + logGroupPart := parts[6] + if strings.HasPrefix(logGroupPart, "log-group:") { + return strings.TrimPrefix(logGroupPart, "log-group:") + } + return logGroupPart + } + return arn +} + +func (e *ToolExecutor) searchDocs(ctx context.Context, query string) string { + if query == "" { + return "Error: query parameter is required" + } + + reqBody := map[string]any{ + "textQuery": map[string]string{ + "input": query, + }, + "contextAttributes": []map[string]string{ + {"key": "domain", "value": "docs.aws.amazon.com"}, + }, + "acceptSuggestionBody": "RawText", + "locales": []string{"en_us"}, + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return fmt.Sprintf("Error creating request: %v", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, appconfig.File().DocsSearchTimeout()) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, "POST", "https://proxy.search.docs.aws.amazon.com/search", bytes.NewBuffer(jsonBody)) + if err != nil { + return fmt.Sprintf("Error creating request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Sprintf("Error searching documentation: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + return fmt.Sprintf("Error: received status %d from AWS documentation search", resp.StatusCode) + } + + var result struct { + Suggestions []struct { + TextExcerptSuggestion struct { + Link string `json:"link"` + Title string `json:"title"` + Metadata struct { + SeoAbstract string `json:"seo_abstract"` + Abstract string `json:"abstract"` + } `json:"metadata"` + Summary string `json:"summary"` + } `json:"textExcerptSuggestion"` + } `json:"suggestions"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return fmt.Sprintf("Error parsing response: %v", err) + } + + if len(result.Suggestions) == 0 { + return fmt.Sprintf("No documentation found for: %s", query) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("AWS Documentation results for '%s':\n\n", query)) + for i, s := range result.Suggestions { + if i >= 5 { + break + } + suggestion := s.TextExcerptSuggestion + sb.WriteString(fmt.Sprintf("%d. %s\n", i+1, suggestion.Title)) + sb.WriteString(fmt.Sprintf(" URL: %s\n", suggestion.Link)) + context := suggestion.Metadata.SeoAbstract + if context == "" { + context = suggestion.Metadata.Abstract + } + if context == "" { + context = suggestion.Summary + } + if context != "" { + sb.WriteString(fmt.Sprintf(" %s\n", context)) + } + sb.WriteString("\n") + } + return sb.String() +} + +func (e *ToolExecutor) getResource(ctx context.Context, service, resourceType, id string) (dao.Resource, error) { + d, err := e.registry.GetDAO(ctx, service, resourceType) + if err != nil { + return nil, err + } + resource, err := d.Get(ctx, id) + if err != nil { + return nil, err + } + return dao.UnwrapResource(resource), nil +} + +func formatResourceSummary(r dao.Resource) string { + result := fmt.Sprintf("- ID: %s", r.GetID()) + if name := r.GetName(); name != "" && name != r.GetID() { + result += fmt.Sprintf(", Name: %s", name) + } + result += "\n" + return result +} + +func formatResourceDetail(r dao.Resource) string { + result := fmt.Sprintf("ID: %s\n", r.GetID()) + + if name := r.GetName(); name != "" { + result += fmt.Sprintf("Name: %s\n", name) + } + + if arn := r.GetARN(); arn != "" { + result += fmt.Sprintf("ARN: %s\n", arn) + } + + if tags := r.GetTags(); len(tags) > 0 { + result += "\nTags:\n" + for k, v := range tags { + result += fmt.Sprintf(" %s: %s\n", k, v) + } + } + + if raw := r.Raw(); raw != nil { + data, err := json.MarshalIndent(raw, "", " ") + if err == nil { + result += fmt.Sprintf("\nRaw Data:\n%s\n", string(data)) + } + } + + return result +} diff --git a/internal/ai/tools_test.go b/internal/ai/tools_test.go new file mode 100644 index 00000000..3cf0c34e --- /dev/null +++ b/internal/ai/tools_test.go @@ -0,0 +1,326 @@ +package ai + +import ( + "context" + "strings" + "testing" +) + +func TestToolExecutorTools(t *testing.T) { + executor := &ToolExecutor{} + tools := executor.Tools() + + expectedTools := []string{ + "list_resources", + "query_resources", + "get_resource_detail", + "tail_logs", + "search_aws_docs", + } + + if len(tools) != len(expectedTools) { + t.Errorf("expected %d tools, got %d", len(expectedTools), len(tools)) + } + + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.Name] = true + } + + for _, name := range expectedTools { + if !toolNames[name] { + t.Errorf("missing tool: %s", name) + } + } +} + +func TestToolSchemas(t *testing.T) { + executor := &ToolExecutor{} + tools := executor.Tools() + + for _, tool := range tools { + t.Run(tool.Name, func(t *testing.T) { + if tool.Name == "" { + t.Error("tool name is empty") + } + if tool.Description == "" { + t.Error("tool description is empty") + } + if tool.InputSchema == nil { + t.Error("tool input schema is nil") + } + + schemaType, ok := tool.InputSchema["type"].(string) + if !ok || schemaType != "object" { + t.Errorf("expected schema type 'object', got %v", tool.InputSchema["type"]) + } + + props, ok := tool.InputSchema["properties"].(map[string]any) + if !ok { + t.Error("schema properties is not a map") + } + + if len(props) == 0 { + t.Error("schema has no properties") + } + }) + } +} + +func TestQueryResourcesRequiredParams(t *testing.T) { + executor := &ToolExecutor{} + tools := executor.Tools() + + var queryTool *Tool + for i := range tools { + if tools[i].Name == "query_resources" { + queryTool = &tools[i] + break + } + } + + if queryTool == nil { + t.Fatal("query_resources tool not found") + } + + required, ok := queryTool.InputSchema["required"].([]string) + if !ok { + t.Fatal("required field is not []string") + } + + expectedRequired := map[string]bool{ + "service": true, + "resource_type": true, + "region": true, + } + + for _, r := range required { + if !expectedRequired[r] { + t.Errorf("unexpected required field: %s", r) + } + delete(expectedRequired, r) + } + + for missing := range expectedRequired { + t.Errorf("missing required field: %s", missing) + } +} + +func TestToolExecuteUnknownTool(t *testing.T) { + executor := &ToolExecutor{registry: nil} + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "unknown_tool", + Input: map[string]any{}, + }) + + if result.ToolUseID != "test-123" { + t.Errorf("expected tool use ID %q, got %q", "test-123", result.ToolUseID) + } + if !result.IsError { + t.Error("expected IsError to be true") + } + if !strings.Contains(result.Content, "Unknown tool") { + t.Errorf("expected error message about unknown tool, got %q", result.Content) + } +} + +func TestToolExecuteQueryResourcesMissingParams(t *testing.T) { + executor := &ToolExecutor{registry: nil} + + tests := []struct { + name string + input map[string]any + expectedError string + }{ + { + name: "missing service", + input: map[string]any{"resource_type": "instances", "region": "us-east-1"}, + expectedError: "service parameter is required", + }, + { + name: "missing resource_type", + input: map[string]any{"service": "ec2", "region": "us-east-1"}, + expectedError: "resource_type parameter is required", + }, + { + name: "missing region", + input: map[string]any{"service": "ec2", "resource_type": "instances"}, + expectedError: "region parameter is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "query_resources", + Input: tt.input, + }) + + if !result.IsError { + t.Error("expected IsError to be true") + } + if !strings.Contains(result.Content, tt.expectedError) { + t.Errorf("expected error %q, got %q", tt.expectedError, result.Content) + } + }) + } +} + +func TestToolExecuteGetResourceDetailMissingRegion(t *testing.T) { + executor := &ToolExecutor{registry: nil} + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "get_resource_detail", + Input: map[string]any{ + "service": "ec2", + "resource_type": "instances", + "id": "i-12345", + }, + }) + + if !result.IsError { + t.Error("expected IsError to be true") + } + if !strings.Contains(result.Content, "region parameter is required") { + t.Errorf("expected region error, got %q", result.Content) + } +} + +func TestToolExecuteTailLogsMissingRegion(t *testing.T) { + executor := &ToolExecutor{registry: nil} + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "tail_logs", + Input: map[string]any{ + "service": "lambda", + "resource_type": "functions", + "id": "my-function", + }, + }) + + if !result.IsError { + t.Error("expected IsError to be true") + } + if !strings.Contains(result.Content, "region parameter is required") { + t.Errorf("expected region error, got %q", result.Content) + } +} + +func TestToolExecuteSearchDocsEmptyQuery(t *testing.T) { + executor := &ToolExecutor{registry: nil} + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "search_aws_docs", + Input: map[string]any{}, + }) + + if !strings.Contains(result.Content, "query parameter is required") { + t.Errorf("expected query error, got %q", result.Content) + } +} + +func TestExtractLogGroupNameFromArn(t *testing.T) { + tests := []struct { + arn string + expected string + }{ + { + arn: "arn:aws:logs:us-east-1:123456789012:log-group:/aws/lambda/my-function", + expected: "/aws/lambda/my-function", + }, + { + arn: "arn:aws:logs:us-west-2:123456789012:log-group:/ecs/my-service", + expected: "/ecs/my-service", + }, + { + arn: "/aws/lambda/simple", + expected: "/aws/lambda/simple", + }, + } + + for _, tt := range tests { + t.Run(tt.arn, func(t *testing.T) { + result := extractLogGroupNameFromArn(tt.arn) + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestFormatResourceSummary(t *testing.T) { + resource := &mockResource{ + id: "i-12345", + name: "my-instance", + } + + result := formatResourceSummary(resource) + + if !strings.Contains(result, "i-12345") { + t.Errorf("expected ID in summary, got %q", result) + } + if !strings.Contains(result, "my-instance") { + t.Errorf("expected name in summary, got %q", result) + } +} + +func TestFormatResourceSummarySameIDAndName(t *testing.T) { + resource := &mockResource{ + id: "my-bucket", + name: "my-bucket", + } + + result := formatResourceSummary(resource) + + if strings.Count(result, "my-bucket") != 1 { + t.Errorf("expected ID only once when same as name, got %q", result) + } +} + +func TestFormatResourceDetail(t *testing.T) { + resource := &mockResource{ + id: "i-12345", + name: "my-instance", + arn: "arn:aws:ec2:us-east-1:123456789012:instance/i-12345", + tags: map[string]string{"Environment": "prod", "Team": "platform"}, + raw: map[string]string{"InstanceType": "t3.micro"}, + } + + result := formatResourceDetail(resource) + + if !strings.Contains(result, "i-12345") { + t.Errorf("expected ID in detail, got %q", result) + } + if !strings.Contains(result, "my-instance") { + t.Errorf("expected name in detail, got %q", result) + } + if !strings.Contains(result, "arn:aws:ec2") { + t.Errorf("expected ARN in detail, got %q", result) + } + if !strings.Contains(result, "Environment") { + t.Errorf("expected tags in detail, got %q", result) + } + if !strings.Contains(result, "InstanceType") { + t.Errorf("expected raw data in detail, got %q", result) + } +} + +type mockResource struct { + id string + name string + arn string + tags map[string]string + raw any +} + +func (m *mockResource) GetID() string { return m.id } +func (m *mockResource) GetName() string { return m.name } +func (m *mockResource) GetARN() string { return m.arn } +func (m *mockResource) GetTags() map[string]string { return m.tags } +func (m *mockResource) Raw() any { return m.raw } diff --git a/internal/app/app.go b/internal/app/app.go index 67cc2a8a..19b75006 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -11,6 +11,7 @@ import ( tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" + "github.com/clawscli/claws/internal/ai" "github.com/clawscli/claws/internal/aws" "github.com/clawscli/claws/internal/clipboard" "github.com/clawscli/claws/internal/config" @@ -331,6 +332,15 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { profileSelector.Init(), a.modal.SetSize(a.width, a.height), ) + + case key.Matches(msg, a.keys.AI): + aiCtx := a.buildAIContext() + chatOverlay := view.NewChatOverlay(a.ctx, a.registry, aiCtx) + a.modal = &view.Modal{Content: chatOverlay, Width: view.ModalWidthChat} + return a, tea.Batch( + chatOverlay.Init(), + a.modal.SetSize(a.width, a.height), + ) } case view.ShowModalMsg: @@ -773,6 +783,7 @@ type keyMap struct { Command key.Binding Region key.Binding Profile key.Binding + AI key.Binding Help key.Binding Quit key.Binding } @@ -811,6 +822,10 @@ func defaultKeyMap() keyMap { key.WithKeys("P"), key.WithHelp("P", "profile"), ), + AI: key.NewBinding( + key.WithKeys("A"), + key.WithHelp("A", "ai chat"), + ), Help: key.NewBinding( key.WithKeys("?"), key.WithHelp("?", "help"), @@ -854,3 +869,88 @@ func (a *App) resolveStartupView(viewName string) view.View { return view.NewResourceBrowserWithType(a.ctx, a.registry, service, resourceType) } } + +func (a *App) buildAIContext() *ai.Context { + regions := config.Global().Regions() + selections := config.Global().Selections() + var profiles []string + for _, sel := range selections { + if id := sel.ID(); id != "" { + profiles = append(profiles, id) + } + } + + switch v := a.currentView.(type) { + case *view.ResourceBrowser: + return &ai.Context{ + Mode: ai.ContextModeList, + Service: v.Service(), + ResourceType: v.ResourceType(), + ResourceCount: v.ResourceCount(), + FilterText: v.FilterText(), + Toggles: v.ToggleStates(), + UserRegions: regions, + UserProfiles: profiles, + } + + case *view.DiffView: + return &ai.Context{ + Mode: ai.ContextModeDiff, + Service: v.Service(), + ResourceType: v.ResourceType(), + DiffLeft: buildResourceRef(v.Left()), + DiffRight: buildResourceRef(v.Right()), + UserRegions: regions, + UserProfiles: profiles, + } + + case *view.DetailView: + r := v.Resource() + if r != nil { + unwrapped := dao.UnwrapResource(r) + resourceRegion := dao.GetResourceRegion(r) + log.Debug("buildAIContext DetailView", "service", v.Service(), "resourceType", v.ResourceType(), + "id", unwrapped.GetID(), "resourceRegion", resourceRegion, "regions", regions) + ctx := &ai.Context{ + Mode: ai.ContextModeSingle, + Service: v.Service(), + ResourceType: v.ResourceType(), + ResourceID: unwrapped.GetID(), + ResourceName: unwrapped.GetName(), + ResourceRegion: resourceRegion, + ResourceProfile: dao.GetResourceProfile(r), + UserRegions: regions, + UserProfiles: profiles, + } + if v.Service() == "lambda" && v.ResourceType() == "functions" { + ctx.LogGroup = "/aws/lambda/" + unwrapped.GetName() + } + if clusterArn := dao.GetResourceClusterArn(r); clusterArn != "" { + ctx.Cluster = aws.ExtractResourceName(clusterArn) + } + return ctx + } + + case *view.LogView: + return &ai.Context{ + LogGroup: v.LogGroupName(), + UserRegions: regions, + UserProfiles: profiles, + } + } + return &ai.Context{UserRegions: regions, UserProfiles: profiles} +} + +func buildResourceRef(r dao.Resource) *ai.ResourceRef { + unwrapped := dao.UnwrapResource(r) + ref := &ai.ResourceRef{ + ID: unwrapped.GetID(), + Name: unwrapped.GetName(), + Region: dao.GetResourceRegion(r), + Profile: dao.GetResourceProfile(r), + } + if clusterArn := dao.GetResourceClusterArn(r); clusterArn != "" { + ref.Cluster = aws.ExtractResourceName(clusterArn) + } + return ref +} diff --git a/internal/config/file.go b/internal/config/file.go index d2ca68d1..bccb5a38 100644 --- a/internal/config/file.go +++ b/internal/config/file.go @@ -10,6 +10,8 @@ import ( "time" "gopkg.in/yaml.v3" + + "github.com/clawscli/claws/internal/log" ) const ( @@ -18,9 +20,11 @@ const ( DefaultTagSearchTimeout = 30 * time.Second DefaultMetricsLoadTimeout = 30 * time.Second DefaultLogFetchTimeout = 10 * time.Second + DefaultDocsSearchTimeout = 10 * time.Second DefaultMetricsWindow = 15 * time.Minute DefaultMaxConcurrentFetches = 50 DefaultMaxStackSize = 100 + DefaultAIMaxToolCallsPerQuery = 50 ) func ConfigDir() (string, error) { @@ -45,6 +49,7 @@ type TimeoutConfig struct { TagSearch Duration `yaml:"tag_search,omitempty"` MetricsLoad Duration `yaml:"metrics_load,omitempty"` LogFetch Duration `yaml:"log_fetch,omitempty"` + DocsSearch Duration `yaml:"docs_search,omitempty"` } type CloudWatchConfig struct { @@ -78,11 +83,22 @@ func (s StartupConfig) GetProfiles() []string { return nil } -// NavigationConfig controls navigation behavior. type NavigationConfig struct { MaxStackSize int `yaml:"max_stack_size,omitempty"` } +type AIConfig struct { + Profile string `yaml:"profile,omitempty"` + Region string `yaml:"region,omitempty"` + Model string `yaml:"model,omitempty"` + MaxSessions int `yaml:"max_sessions,omitempty"` + MaxTokens int `yaml:"max_tokens,omitempty"` + ThinkingBudget *int `yaml:"thinking_budget,omitempty"` + MaxToolRounds int `yaml:"max_tool_rounds,omitempty"` + MaxToolCallsPerQuery int `yaml:"max_tool_calls_per_query,omitempty"` + SaveSessions *bool `yaml:"save_sessions,omitempty"` +} + // ThemeConfig holds theme configuration. // Can be specified as: // - A preset name string: "dark", "light", "nord", "dracula", "gruvbox", "catppuccin" @@ -134,6 +150,7 @@ type FileConfig struct { Startup StartupConfig `yaml:"startup,omitempty"` Theme ThemeConfig `yaml:"theme,omitempty"` Navigation NavigationConfig `yaml:"navigation,omitempty"` + AI AIConfig `yaml:"ai,omitempty"` } // Duration wraps time.Duration for YAML marshal/unmarshal as string (e.g., "5s", "30s") @@ -172,6 +189,7 @@ func DefaultFileConfig() *FileConfig { TagSearch: Duration(DefaultTagSearchTimeout), MetricsLoad: Duration(DefaultMetricsLoadTimeout), LogFetch: Duration(DefaultLogFetchTimeout), + DocsSearch: Duration(DefaultDocsSearchTimeout), }, Concurrency: ConcurrencyConfig{ MaxFetches: DefaultMaxConcurrentFetches, @@ -240,6 +258,9 @@ func (c *FileConfig) applyDefaults() { if c.Timeouts.LogFetch <= 0 { c.Timeouts.LogFetch = Duration(DefaultLogFetchTimeout) } + if c.Timeouts.DocsSearch <= 0 { + c.Timeouts.DocsSearch = Duration(DefaultDocsSearchTimeout) + } if c.CloudWatch.Window <= 0 { c.CloudWatch.Window = Duration(DefaultMetricsWindow) } @@ -296,6 +317,15 @@ func (c *FileConfig) LogFetchTimeout() time.Duration { }) } +func (c *FileConfig) DocsSearchTimeout() time.Duration { + return withRLock(&c.mu, func() time.Duration { + if c.Timeouts.DocsSearch == 0 { + return DefaultDocsSearchTimeout + } + return c.Timeouts.DocsSearch.Duration() + }) +} + func (c *FileConfig) MaxConcurrentFetches() int { return withRLock(&c.mu, func() int { if c.Concurrency.MaxFetches == 0 { @@ -362,6 +392,92 @@ func (c *FileConfig) GetTheme() ThemeConfig { return withRLock(&c.mu, func() ThemeConfig { return c.Theme }) } +const DefaultAIModel = "global.anthropic.claude-haiku-4-5-20251001-v1:0" +const DefaultAIMaxSessions = 100 +const DefaultAIMaxTokens = 16000 +const DefaultAIThinkingBudget = 8000 +const DefaultAIMaxToolRounds = 15 + +func (c *FileConfig) GetAIProfile() string { + return withRLock(&c.mu, func() string { + return c.AI.Profile + }) +} + +func (c *FileConfig) GetAIRegion() string { + return withRLock(&c.mu, func() string { + return c.AI.Region + }) +} + +func (c *FileConfig) GetAIModel() string { + return withRLock(&c.mu, func() string { + if c.AI.Model == "" { + return DefaultAIModel + } + return c.AI.Model + }) +} + +func (c *FileConfig) GetAIMaxSessions() int { + return withRLock(&c.mu, func() int { + if c.AI.MaxSessions <= 0 { + return DefaultAIMaxSessions + } + return c.AI.MaxSessions + }) +} + +func (c *FileConfig) GetAIMaxTokens() int { + return withRLock(&c.mu, func() int { + if c.AI.MaxTokens <= 0 { + return DefaultAIMaxTokens + } + return c.AI.MaxTokens + }) +} + +func (c *FileConfig) GetAIThinkingBudget() int { + return withRLock(&c.mu, func() int { + if c.AI.ThinkingBudget == nil { + return DefaultAIThinkingBudget + } + v := *c.AI.ThinkingBudget + if v < 0 { + log.Warn("ai.thinking_budget is negative, treating as disabled", "value", v) + return 0 + } + return v + }) +} + +func (c *FileConfig) GetAIMaxToolRounds() int { + return withRLock(&c.mu, func() int { + if c.AI.MaxToolRounds <= 0 { + return DefaultAIMaxToolRounds + } + return c.AI.MaxToolRounds + }) +} + +func (c *FileConfig) GetAIMaxToolCallsPerQuery() int { + return withRLock(&c.mu, func() int { + if c.AI.MaxToolCallsPerQuery <= 0 { + return DefaultAIMaxToolCallsPerQuery + } + return c.AI.MaxToolCallsPerQuery + }) +} + +func (c *FileConfig) GetAISaveSessions() bool { + return withRLock(&c.mu, func() bool { + if c.AI.SaveSessions == nil { + return false + } + return *c.AI.SaveSessions + }) +} + func (c *FileConfig) SaveRegions(regions []string) error { if len(regions) == 0 { return nil diff --git a/internal/config/file_test.go b/internal/config/file_test.go index 9a28f417..b77f068c 100644 --- a/internal/config/file_test.go +++ b/internal/config/file_test.go @@ -667,6 +667,27 @@ func TestConcurrentSaves(t *testing.T) { } } +func TestGetAIMaxToolCallsPerQuery(t *testing.T) { + tests := []struct { + name string + config AIConfig + want int + }{ + {"default", AIConfig{}, DefaultAIMaxToolCallsPerQuery}, + {"custom", AIConfig{MaxToolCallsPerQuery: 25}, 25}, + {"zero defaults", AIConfig{MaxToolCallsPerQuery: 0}, DefaultAIMaxToolCallsPerQuery}, + {"negative defaults", AIConfig{MaxToolCallsPerQuery: -1}, DefaultAIMaxToolCallsPerQuery}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &FileConfig{AI: tt.config} + if got := cfg.GetAIMaxToolCallsPerQuery(); got != tt.want { + t.Errorf("GetAIMaxToolCallsPerQuery() = %d, want %d", got, tt.want) + } + }) + } +} + func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) } diff --git a/internal/dao/dao.go b/internal/dao/dao.go index 7556c7a8..06c478aa 100644 --- a/internal/dao/dao.go +++ b/internal/dao/dao.go @@ -158,6 +158,10 @@ type profiledResource interface { GetAccountID() string } +type clusterAwareResource interface { + ClusterArn() string +} + func GetResourceRegion(res Resource) string { if rr, ok := res.(regionalResource); ok { return rr.GetRegion() @@ -179,6 +183,14 @@ func GetResourceAccountID(res Resource) string { return "" } +func GetResourceClusterArn(res Resource) string { + unwrapped := UnwrapResource(res) + if cr, ok := unwrapped.(clusterAwareResource); ok { + return cr.ClusterArn() + } + return "" +} + func UnwrapResource(res Resource) Resource { if pr, ok := res.(*ProfiledResource); ok { return pr.Resource diff --git a/internal/render/render.go b/internal/render/render.go index 5194ce37..43ae68b7 100644 --- a/internal/render/render.go +++ b/internal/render/render.go @@ -71,6 +71,19 @@ type Navigator interface { Navigations(resource dao.Resource) []Navigation } +// Toggle defines a list-level toggle for filtering or view modes +type Toggle struct { + Key string // Key to press (e.g., "r") + ContextKey string // Context key for DAO filtering (e.g., "ShowResolved") + LabelOn string // Label when toggle is ON (e.g., "all") + LabelOff string // Label when toggle is OFF (e.g., "active only") +} + +// Toggler is an optional interface for renderers that support list-level toggles +type Toggler interface { + ListToggles() []Toggle +} + // MetricSpecProvider is an optional interface for renderers that support inline metrics. type MetricSpecProvider interface { MetricSpec() *MetricSpec diff --git a/internal/ui/theme.go b/internal/ui/theme.go index 098b24e9..60022708 100644 --- a/internal/ui/theme.go +++ b/internal/ui/theme.go @@ -409,6 +409,24 @@ func FaintStyle() lipgloss.Style { return lipgloss.NewStyle().Faint(true) } +// DimItalicStyle returns a dim italic style (for AI context/thinking) +func DimItalicStyle() lipgloss.Style { + return lipgloss.NewStyle().Foreground(Current().TextDim).Italic(true) +} + +// ItalicStyle returns an italic style +func ItalicStyle() lipgloss.Style { + return lipgloss.NewStyle().Italic(true) +} + +// ChatInputStyle returns a style for chat input with rounded border +func ChatInputStyle() lipgloss.Style { + return lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(Current().Border). + Padding(0, 1) +} + func BoxStyle() lipgloss.Style { return lipgloss.NewStyle(). Border(lipgloss.RoundedBorder()). diff --git a/internal/view/chat_overlay.go b/internal/view/chat_overlay.go new file mode 100644 index 00000000..e2210016 --- /dev/null +++ b/internal/view/chat_overlay.go @@ -0,0 +1,775 @@ +package view + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + + "github.com/clawscli/claws/internal/ai" + "github.com/clawscli/claws/internal/config" + apperrors "github.com/clawscli/claws/internal/errors" + "github.com/clawscli/claws/internal/log" + "github.com/clawscli/claws/internal/registry" + "github.com/clawscli/claws/internal/ui" +) + +type chatStyles struct { + title lipgloss.Style + context lipgloss.Style + userMsg lipgloss.Style + assistantMsg lipgloss.Style + toolCall lipgloss.Style + toolError lipgloss.Style + thinking lipgloss.Style + input lipgloss.Style + errorMsg lipgloss.Style + mdBold lipgloss.Style + mdCode lipgloss.Style + mdItalic lipgloss.Style +} + +func newChatStyles() chatStyles { + return chatStyles{ + title: ui.TitleStyle(), + context: ui.DimItalicStyle(), + userMsg: ui.TextStyle(), + assistantMsg: ui.SecondaryStyle(), + toolCall: ui.DimStyle(), + toolError: ui.DangerStyle(), + thinking: ui.DimItalicStyle(), + input: ui.ChatInputStyle(), + errorMsg: ui.DangerStyle(), + mdBold: ui.TitleStyle(), + mdCode: ui.SuccessStyle(), + mdItalic: ui.ItalicStyle(), + } +} + +type ChatOverlay struct { + ctx context.Context + registry *registry.Registry + aiCtx *ai.Context + styles chatStyles + + client *ai.Client + executor *ai.ToolExecutor + session *ai.Session + sessMgr *ai.SessionManager + + input textinput.Model + vp ViewportState + + messages []chatMessage + streamingMsg string + streamingThinking string + collapsedThinking map[int]bool + collapsedToolCalls map[int]bool + thinkingLineRanges map[int][2]int + toolCallLineRanges map[int][2]int + isStreaming bool + err error + + // Streaming state - accumulates ContentBlocks for the current assistant turn + pendingToolUses []*ai.ToolUseContent + currentReasoning string + reasoningSignature string + streamMessages []ai.Message + toolRound int + toolCallCount int // Counts tool calls within current query (reset per query) + + width int + height int + + showingHistory bool + sessionHistory *SessionHistory + + statusMsg string + statusMsgTime time.Time + + contextExpanded bool + + // Stream cancellation - prevents goroutine leaks when overlay closes mid-stream + streamCancel context.CancelFunc + streamCancelMu sync.Mutex +} + +// chatMessage is a UI-level message for display purposes. +// It stores extracted text/thinking for rendering. +type chatMessage struct { + role ai.Role + content string + thinkingContent string + toolUse *ai.ToolUseContent + toolResult *ai.ToolResultContent + toolError bool +} + +type chatStreamMsg struct { + event ai.StreamEvent + eventCh <-chan ai.StreamEvent +} + +type chatToolExecuteMsg struct { + // The assistant message with ToolUse blocks that triggered this execution + assistantBlocks []ai.ContentBlock + toolUses []*ai.ToolUseContent + messages []ai.Message + toolRound int +} + +type chatInitMsg struct { + client *ai.Client + executor *ai.ToolExecutor + session *ai.Session + err error +} + +func NewChatOverlay(ctx context.Context, reg *registry.Registry, aiCtx *ai.Context) *ChatOverlay { + cfg := config.File() + + ti := textinput.New() + ti.Placeholder = "Ask about AWS resources..." + ti.Focus() + ti.CharLimit = 500 + + return &ChatOverlay{ + ctx: ctx, + registry: reg, + aiCtx: aiCtx, + styles: newChatStyles(), + input: ti, + sessMgr: ai.NewSessionManager(cfg.GetAIMaxSessions(), cfg.GetAISaveSessions()), + messages: []chatMessage{}, + collapsedThinking: make(map[int]bool), + collapsedToolCalls: make(map[int]bool), + } +} + +func (c *ChatOverlay) Init() tea.Cmd { + return tea.Batch( + textinput.Blink, + c.initClient, + ) +} + +func (c *ChatOverlay) initClient() tea.Msg { + executor, err := ai.NewToolExecutor(c.ctx, c.registry) + if err != nil { + return chatInitMsg{err: apperrors.Wrap(err, "init tool executor")} + } + + client, err := ai.NewClient( + c.ctx, + ai.WithModel(config.File().GetAIModel()), + ai.WithTools(executor.Tools()), + ai.WithMaxTokens(config.File().GetAIMaxTokens()), + ai.WithThinkingBudget(config.File().GetAIThinkingBudget()), + ) + if err != nil { + return chatInitMsg{err: apperrors.Wrap(err, "init ai client")} + } + + session, err := c.sessMgr.NewSession(c.aiCtx) + if err != nil { + return chatInitMsg{err: apperrors.Wrap(err, "create session")} + } + + return chatInitMsg{client: client, executor: executor, session: session} +} + +func (c *ChatOverlay) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + if c.showingHistory { + return c.handleHistoryUpdate(msg) + } + + switch msg := msg.(type) { + case chatInitMsg: + if msg.err != nil { + c.err = msg.err + } else { + c.client = msg.client + c.executor = msg.executor + c.session = msg.session + } + return c, nil + + case tea.KeyPressMsg: + return c.handleKeyPress(msg) + + case chatStreamMsg: + return c.handleStreamEvent(msg) + + case chatToolExecuteMsg: + return c.handleToolExecute(msg) + + case tea.MouseClickMsg: + return c.handleMouseClick(msg) + } + + var cmds []tea.Cmd + + if c.vp.Ready { + var vpCmd tea.Cmd + c.vp.Model, vpCmd = c.vp.Model.Update(msg) + cmds = append(cmds, vpCmd) + } + + var inputCmd tea.Cmd + c.input, inputCmd = c.input.Update(msg) + cmds = append(cmds, inputCmd) + + return c, tea.Batch(cmds...) +} + +func (c *ChatOverlay) cancelStream() { + c.streamCancelMu.Lock() + defer c.streamCancelMu.Unlock() + if c.streamCancel != nil { + c.streamCancel() + c.streamCancel = nil + } +} + +func (c *ChatOverlay) handleKeyPress(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { + if IsEscKey(msg) { + c.cancelStream() + return c, func() tea.Msg { return HideModalMsg{} } + } + + switch msg.String() { + case "ctrl+c": + c.cancelStream() + return c, func() tea.Msg { return HideModalMsg{} } + case "ctrl+h": + return c.showHistory() + case "enter": + if c.isStreaming { + return c, nil + } + + text := strings.TrimSpace(c.input.Value()) + if text == "" { + return c, nil + } + + c.input.SetValue("") + c.messages = append(c.messages, chatMessage{role: ai.RoleUser, content: text}) + c.isStreaming = true + c.streamingMsg = "" + c.streamingThinking = "" + c.pendingToolUses = nil + c.currentReasoning = "" + c.reasoningSignature = "" + c.toolRound = 0 + c.toolCallCount = 0 // Reset per-query tool call counter + c.err = nil + c.updateViewport() + + userMsg := ai.NewUserMessage(text) + c.streamMessages = append(c.streamMessages, userMsg) + if c.session != nil { + if err := c.sessMgr.AddMessage(c.session, userMsg); err != nil { + log.Warn("failed to save user message", "error", err) + c.statusMsg = "Failed to save message" + c.statusMsgTime = time.Now() + } + } + return c, c.startStream(c.streamMessages) + } + + var kpCmd tea.Cmd + c.input, kpCmd = c.input.Update(msg) + return c, kpCmd +} + +func (c *ChatOverlay) handleMouseClick(msg tea.MouseClickMsg) (tea.Model, tea.Cmd) { + if c.aiCtx != nil && c.aiCtx.Service != "" && msg.Y == 1 { + c.contextExpanded = !c.contextExpanded + c.updateViewport() + return c, nil + } + + if !c.vp.Ready { + return c, nil + } + + headerHeight := c.headerHeight() + + contentLine := msg.Y - headerHeight + c.vp.Model.YOffset() + if contentLine < 0 { + return c, nil + } + + for msgIdx, lineRange := range c.thinkingLineRanges { + if contentLine >= lineRange[0] && contentLine < lineRange[1] { + wasCollapsed := c.collapsedThinking[msgIdx] + c.collapsedThinking[msgIdx] = !wasCollapsed + c.scrollToCollapsible(lineRange[0], wasCollapsed) + return c, nil + } + } + + for msgIdx, lineRange := range c.toolCallLineRanges { + if contentLine >= lineRange[0] && contentLine < lineRange[1] { + wasCollapsed := c.collapsedToolCalls[msgIdx] + c.collapsedToolCalls[msgIdx] = !wasCollapsed + c.scrollToCollapsible(lineRange[0], wasCollapsed) + return c, nil + } + } + + return c, nil +} + +func (c *ChatOverlay) startStream(messages []ai.Message) tea.Cmd { + c.cancelStream() + streamCtx, cancel := context.WithCancel(c.ctx) + + c.streamCancelMu.Lock() + c.streamCancel = cancel + c.streamCancelMu.Unlock() + + return func() tea.Msg { + if c.client == nil || c.executor == nil { + return chatStreamMsg{event: ai.StreamEvent{Type: "error", Error: errors.New("client not initialized")}} + } + + systemPrompt := c.buildSystemPrompt() + + eventCh, err := c.client.ConverseStream(streamCtx, messages, systemPrompt) + if err != nil { + return chatStreamMsg{event: ai.StreamEvent{Type: "error", Error: err}} + } + + event, ok := <-eventCh + if !ok { + return chatStreamMsg{event: ai.StreamEvent{Type: "done"}} + } + return chatStreamMsg{event: event, eventCh: eventCh} + } +} + +func (c *ChatOverlay) waitForStream(eventCh <-chan ai.StreamEvent) tea.Cmd { + return func() tea.Msg { + event, ok := <-eventCh + if !ok { + return chatStreamMsg{event: ai.StreamEvent{Type: "done"}} + } + return chatStreamMsg{event: event, eventCh: eventCh} + } +} + +func (c *ChatOverlay) handleStreamEvent(msg chatStreamMsg) (tea.Model, tea.Cmd) { + event := msg.event + + switch event.Type { + case "text": + c.streamingMsg += event.Text + c.updateViewport() + return c, c.waitForStream(msg.eventCh) + + case "thinking": + if event.Thinking != nil { + c.streamingThinking += event.Thinking.Text + } + c.updateViewport() + return c, c.waitForStream(msg.eventCh) + + case "thinking_complete": + // Capture the complete thinking with signature for API replay + if event.Thinking != nil { + c.currentReasoning = event.Thinking.Text + c.reasoningSignature = event.Thinking.Signature + } + return c, c.waitForStream(msg.eventCh) + + case "tool_use": + if event.ToolUse != nil { + c.pendingToolUses = append(c.pendingToolUses, event.ToolUse) + } + return c, c.waitForStream(msg.eventCh) + + case "done": + return c.handleStreamDone(msg.eventCh) + + case "error": + c.err = event.Error + c.isStreaming = false + c.updateViewport() + return c, nil + } + + return c, c.waitForStream(msg.eventCh) +} + +func (c *ChatOverlay) handleStreamDone(_ <-chan ai.StreamEvent) (tea.Model, tea.Cmd) { + // Build the assistant's ContentBlocks from accumulated state + var assistantBlocks []ai.ContentBlock + + // Add reasoning block if present + if c.currentReasoning != "" { + assistantBlocks = append(assistantBlocks, ai.ContentBlock{ + Reasoning: c.currentReasoning, + ReasoningSignature: c.reasoningSignature, + }) + } + + // Add text block if present + if c.streamingMsg != "" { + assistantBlocks = append(assistantBlocks, ai.ContentBlock{Text: c.streamingMsg}) + } + + // Add tool use blocks + for _, tu := range c.pendingToolUses { + assistantBlocks = append(assistantBlocks, ai.ContentBlock{ToolUse: tu}) + } + + // Save to UI messages for display + if c.streamingMsg != "" || c.streamingThinking != "" { + c.messages = append(c.messages, chatMessage{ + role: ai.RoleAssistant, + content: c.streamingMsg, + thinkingContent: c.streamingThinking, + }) + if c.streamingThinking != "" { + c.collapsedThinking[len(c.messages)-1] = true + } + } + + // If there are tool uses, execute them + if len(c.pendingToolUses) > 0 && c.toolRound < config.File().GetAIMaxToolRounds() { + c.updateViewport() + + // Save assistant message with tool uses to session + if c.session != nil && len(assistantBlocks) > 0 { + assistantMsg := ai.Message{ + Role: ai.RoleAssistant, + Content: assistantBlocks, + } + c.streamMessages = append(c.streamMessages, assistantMsg) + if err := c.sessMgr.AddMessage(c.session, assistantMsg); err != nil { + log.Warn("failed to save assistant message with tool uses", "error", err) + } + } + + // Clear streaming state before tool execution + toolUses := c.pendingToolUses + c.pendingToolUses = nil + c.streamingMsg = "" + c.streamingThinking = "" + c.currentReasoning = "" + c.reasoningSignature = "" + c.toolRound++ + + return c, func() tea.Msg { + return chatToolExecuteMsg{ + assistantBlocks: assistantBlocks, + toolUses: toolUses, + messages: c.streamMessages, + toolRound: c.toolRound, + } + } + } + + // No tool uses or max rounds reached - done + if len(assistantBlocks) > 0 { + assistantMsg := ai.Message{ + Role: ai.RoleAssistant, + Content: assistantBlocks, + } + c.streamMessages = append(c.streamMessages, assistantMsg) + if c.session != nil { + if err := c.sessMgr.AddMessage(c.session, assistantMsg); err != nil { + log.Warn("failed to save assistant message", "error", err) + c.statusMsg = "Failed to save message" + c.statusMsgTime = time.Now() + } + } + } + + if len(c.pendingToolUses) > 0 && c.toolRound >= config.File().GetAIMaxToolRounds() { + c.messages = append(c.messages, chatMessage{ + role: ai.RoleAssistant, + content: "(tool limit reached)", + }) + } + + c.streamingMsg = "" + c.streamingThinking = "" + c.currentReasoning = "" + c.reasoningSignature = "" + c.pendingToolUses = nil + c.isStreaming = false + c.updateViewport() + return c, nil +} + +func (c *ChatOverlay) handleToolExecute(msg chatToolExecuteMsg) (tea.Model, tea.Cmd) { + maxCalls := config.File().GetAIMaxToolCallsPerQuery() + + // Execute each tool and collect results + var toolResults []ai.ToolResultContent + for _, tu := range msg.toolUses { + // Check tool call limit before executing each tool + if c.toolCallCount >= maxCalls { + c.err = fmt.Errorf("Tool call limit reached (%d calls). Start new query to continue.", maxCalls) + c.isStreaming = false + c.updateViewport() + return c, nil + } + + result := c.executor.Execute(c.ctx, tu) + toolResults = append(toolResults, result) + c.toolCallCount++ + + c.messages = append(c.messages, chatMessage{ + content: result.Content, + toolUse: tu, + toolResult: &result, + toolError: result.IsError, + }) + c.collapsedToolCalls[len(c.messages)-1] = true + } + c.updateViewport() + + // Build the new messages to send to API: + // 1. Previous messages (including assistant message with tool uses from handleStreamDone) + // 2. User message with tool results + + messages := make([]ai.Message, len(msg.messages), len(msg.messages)+1) + copy(messages, msg.messages) + + // Add user message with tool results + var resultBlocks []ai.ContentBlock + for _, tr := range toolResults { + resultBlocks = append(resultBlocks, ai.ContentBlock{ToolResult: &tr}) + } + messages = append(messages, ai.Message{ + Role: ai.RoleUser, + Content: resultBlocks, + }) + + c.streamMessages = messages + c.isStreaming = true + + // Save tool result message to session + if c.session != nil { + toolResultMsg := messages[len(messages)-1] // Last message with tool results + if err := c.sessMgr.AddMessage(c.session, toolResultMsg); err != nil { + log.Warn("failed to save tool result message", "error", err) + } + } + + return c, c.startStream(messages) +} + +func (c *ChatOverlay) View() tea.View { + return tea.NewView(c.ViewString()) +} + +func (c *ChatOverlay) ViewString() string { + if c.showingHistory && c.sessionHistory != nil { + return c.sessionHistory.ViewString() + } + + var sb strings.Builder + + title := c.styles.title.Render("AI Chat") + hint := c.styles.context.Render("Ctrl+h: history") + titleWidth := lipgloss.Width(title) + hintWidth := lipgloss.Width(hint) + padding := c.width - titleWidth - hintWidth + if padding < 1 { + padding = 1 + } + sb.WriteString(title + strings.Repeat(" ", padding) + hint) + sb.WriteString("\n") + + if c.aiCtx != nil && c.aiCtx.Service != "" { + indicator := "▶" + if c.contextExpanded { + indicator = "▼" + } + ctx := fmt.Sprintf("Context: %s", c.aiCtx.Service) + if c.aiCtx.ResourceType != "" { + ctx += "/" + c.aiCtx.ResourceType + } + if c.aiCtx.ResourceName != "" { + ctx += " - " + c.aiCtx.ResourceName + } + ctx += " [" + indicator + "]" + sb.WriteString(c.styles.context.Render(ctx)) + sb.WriteString("\n") + } + sb.WriteString("\n") + + if c.vp.Ready { + sb.WriteString(c.vp.Model.View()) + } else { + sb.WriteString(c.renderMessages()) + } + + sb.WriteString("\n") + sb.WriteString(c.styles.input.Render(c.input.View())) + + return sb.String() +} + +func (c *ChatOverlay) SetSize(width, height int) tea.Cmd { + c.width = width + c.height = height + + vpHeight := height - 8 + if vpHeight < 5 { + vpHeight = 5 + } + + c.vp.SetSize(width, vpHeight) + c.updateViewport() + + return nil +} + +func (c *ChatOverlay) StatusLine() string { + if c.statusMsg != "" && time.Since(c.statusMsgTime) < 3*time.Second { + return c.statusMsg + } + return "AI Chat | Enter: send | Esc: close" +} + +func (c *ChatOverlay) headerHeight() int { + lines := 2 + if c.aiCtx != nil && c.aiCtx.Service != "" { + ctx := fmt.Sprintf("Context: %s", c.aiCtx.Service) + if c.aiCtx.ResourceType != "" { + ctx += "/" + c.aiCtx.ResourceType + } + if c.aiCtx.ResourceName != "" { + ctx += " - " + c.aiCtx.ResourceName + } + rendered := c.styles.context.Render(ctx) + lines += strings.Count(rendered, "\n") + 1 + } + return lines +} + +func (c *ChatOverlay) HasActiveInput() bool { + return true +} + +func (c *ChatOverlay) scrollToCollapsible(startLine int, wasCollapsed bool) { + if !c.vp.Ready { + return + } + content := c.renderMessages() + c.vp.Model.SetContent(content) + if wasCollapsed { + c.vp.Model.SetYOffset(startLine) + } +} + +func (c *ChatOverlay) showHistory() (tea.Model, tea.Cmd) { + sessions, _ := c.sessMgr.ListSessions() + currentID := "" + if c.session != nil { + currentID = c.session.ID + } + c.sessionHistory = NewSessionHistory(sessions, currentID) + c.sessionHistory.SetSize(c.width, c.height) + c.showingHistory = true + return c, nil +} + +func (c *ChatOverlay) handleHistoryUpdate(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case SessionSelectedMsg: + c.showingHistory = false + c.sessionHistory = nil + if msg.Session != nil { + return c.loadSession(msg.Session) + } + return c, nil + + case NewSessionMsg: + c.showingHistory = false + c.sessionHistory = nil + return c.newSession() + + case CloseHistoryMsg: + c.showingHistory = false + c.sessionHistory = nil + return c, nil + } + + if c.sessionHistory != nil { + model, cmd := c.sessionHistory.Update(msg) + if sh, ok := model.(*SessionHistory); ok { + c.sessionHistory = sh + } + return c, cmd + } + return c, nil +} + +func (c *ChatOverlay) loadSession(sess *ai.Session) (tea.Model, tea.Cmd) { + if sess == nil { + return c, nil + } + + c.cancelStream() + if c.isStreaming { + c.isStreaming = false + c.streamingMsg = "" + c.streamingThinking = "" + c.pendingToolUses = nil + c.currentReasoning = "" + c.reasoningSignature = "" + } + + c.session = sess + c.messages = []chatMessage{} + c.streamMessages = []ai.Message{} + c.collapsedThinking = make(map[int]bool) + c.collapsedToolCalls = make(map[int]bool) + c.toolCallCount = 0 // Reset per-query counter + + for _, msg := range sess.Messages { + cm := chatMessage{role: msg.Role} + for _, block := range msg.Content { + if block.Text != "" { + cm.content = block.Text + } + if block.Reasoning != "" { + cm.thinkingContent = block.Reasoning + } + } + c.messages = append(c.messages, cm) + c.streamMessages = append(c.streamMessages, msg) + } + + c.updateViewport() + return c, nil +} + +func (c *ChatOverlay) newSession() (tea.Model, tea.Cmd) { + session, err := c.sessMgr.NewSession(c.aiCtx) + if err != nil { + c.err = err + return c, nil + } + c.session = session + c.messages = []chatMessage{} + c.streamMessages = []ai.Message{} + c.collapsedThinking = make(map[int]bool) + c.collapsedToolCalls = make(map[int]bool) + c.toolCallCount = 0 // Reset per-query counter + c.updateViewport() + return c, nil +} diff --git a/internal/view/chat_overlay_prompt.go b/internal/view/chat_overlay_prompt.go new file mode 100644 index 00000000..2e7d1eb9 --- /dev/null +++ b/internal/view/chat_overlay_prompt.go @@ -0,0 +1,230 @@ +package view + +import ( + "fmt" + "strings" + + "github.com/clawscli/claws/internal/ai" + "github.com/clawscli/claws/internal/config" +) + +// formatProfileName converts internal profile ID to display name +func formatProfileName(profileID string) string { + sel := config.ProfileSelectionFromID(profileID) + if sel.Mode == config.ModeNamedProfile { + return sel.ProfileName + } + return sel.Mode.String() +} + +func (c *ChatOverlay) buildSystemPrompt() string { + services := c.registry.ListServices() + serviceList := strings.Join(services, ", ") + + prompt := fmt.Sprintf(`You are an AWS resource assistant in claws TUI. + + +%s + + + +When a user asks about AWS resources, you MUST call the appropriate tool. Do not just describe what you would do - actually call the tool. +Use ONLY the service names listed in available_services above. Do not guess or use similar names. +All resource tools require a region parameter. Use profile parameter when querying cross-profile resources. + +Available tools: +- list_resources(service): Lists resource types for a service +- query_resources(service, resource_type, region, profile?, limit?, offset?): Lists resources (default: 100, max: 2000, supports pagination) +- get_resource_detail(service, resource_type, region, id, cluster?, profile?): Gets resource details +- tail_logs(service, resource_type, region, id, cluster?, profile?): Fetches CloudWatch logs for a resource + - Supported: lambda/functions, ecs/services, ecs/tasks, ecs/task-definitions, codebuild/projects, codebuild/builds, cloudtrail/trails, apigateway/stages, apigateway/stages-v2, stepfunctions/state-machines + - cluster parameter required for ecs/services and ecs/tasks +- search_aws_docs(query): Search AWS documentation + + + +Be concise. Use markdown for formatting. +`, serviceList) + + if c.aiCtx != nil { + if len(c.aiCtx.UserRegions) > 0 { + prompt += "\n\n" + prompt += strings.Join(c.aiCtx.UserRegions, ", ") + prompt += "\nThese are ALL regions the user is currently browsing." + prompt += "\nIn list mode, query resources across ALL these regions (call query_resources for each)." + prompt += "\nFor specific resources (detail/diff mode), use the region from current_context instead." + prompt += "\n" + } + + if len(c.aiCtx.UserProfiles) > 0 { + prompt += "\n\n" + prompt += strings.Join(c.aiCtx.UserProfiles, ", ") + prompt += "\nThese are ALL profile IDs the user is currently browsing." + prompt += "\nIn list mode, query resources across ALL these profiles (call query_resources for each)." + prompt += "\nFor specific resources (detail/diff mode), use the profile from current_context instead." + prompt += "\n" + } + + switch c.aiCtx.Mode { + case ai.ContextModeList: + prompt += c.buildListContextPrompt() + case ai.ContextModeDiff: + prompt += c.buildDiffContextPrompt() + default: + prompt += c.buildSingleContextPrompt() + } + } + + return prompt +} + +func (c *ChatOverlay) buildListContextPrompt() string { + ctx := c.aiCtx + if ctx.Service == "" { + return "" + } + + prompt := fmt.Sprintf("\n\nservice=%s, resource_type=%s", ctx.Service, ctx.ResourceType) + prompt += fmt.Sprintf(", count=%d", ctx.ResourceCount) + if ctx.FilterText != "" { + prompt += fmt.Sprintf(", filter=\"%s\"", ctx.FilterText) + } + if ctx.Service == "securityhub" && ctx.ResourceType == "findings" { + if ctx.Toggles["ShowResolved"] { + prompt += ", show_resolved=true" + } else { + prompt += ", show_resolved=false (use include_resolved=true in query_resources for all)" + } + } + prompt += "\n" + prompt += "\nIMPORTANT: When the user asks to list or analyze resources, call query_resources for EACH combination of user_selected_regions and user_selected_profiles to get the complete view across all selected contexts." + return prompt +} + +func (c *ChatOverlay) buildDiffContextPrompt() string { + ctx := c.aiCtx + if ctx.DiffLeft == nil || ctx.DiffRight == nil { + return "" + } + + prompt := fmt.Sprintf("\n\nservice=%s, resource_type=%s", ctx.Service, ctx.ResourceType) + prompt += fmt.Sprintf("\nleft: id=%s, name=%s", ctx.DiffLeft.ID, ctx.DiffLeft.Name) + if ctx.DiffLeft.Region != "" { + prompt += fmt.Sprintf(", region=%s", ctx.DiffLeft.Region) + } + if ctx.DiffLeft.Profile != "" { + prompt += fmt.Sprintf(", profile=%s", ctx.DiffLeft.Profile) + } + if ctx.DiffLeft.Cluster != "" { + prompt += fmt.Sprintf(", cluster=%s", ctx.DiffLeft.Cluster) + } + prompt += fmt.Sprintf("\nright: id=%s, name=%s", ctx.DiffRight.ID, ctx.DiffRight.Name) + if ctx.DiffRight.Region != "" { + prompt += fmt.Sprintf(", region=%s", ctx.DiffRight.Region) + } + if ctx.DiffRight.Profile != "" { + prompt += fmt.Sprintf(", profile=%s", ctx.DiffRight.Profile) + } + if ctx.DiffRight.Cluster != "" { + prompt += fmt.Sprintf(", cluster=%s", ctx.DiffRight.Cluster) + } + prompt += "\n" + prompt += "\nIMPORTANT: Call get_resource_detail twice (once for left, once for right) using each resource's specific region and profile." + return prompt +} + +func (c *ChatOverlay) buildSingleContextPrompt() string { + ctx := c.aiCtx + if ctx.Service == "" { + return "" + } + + prompt := fmt.Sprintf("\nservice=%s", ctx.Service) + if ctx.ResourceType != "" { + prompt += ", resource_type=" + ctx.ResourceType + } + if ctx.ResourceRegion != "" { + prompt += ", region=" + ctx.ResourceRegion + } + if ctx.ResourceID != "" { + prompt += ", id=" + ctx.ResourceID + } + if ctx.ResourceProfile != "" { + prompt += ", profile=" + ctx.ResourceProfile + } + if ctx.Cluster != "" { + prompt += ", cluster=" + ctx.Cluster + } + prompt += "" + prompt += "\nIMPORTANT: Use the region and profile from current_context when querying this resource." + return prompt +} + +func (c *ChatOverlay) renderContextParams() string { + ctx := c.aiCtx + if ctx == nil { + return "" + } + + var lines []string + lines = append(lines, fmt.Sprintf(" mode: %s", ctx.Mode)) + + if len(ctx.UserRegions) > 0 { + lines = append(lines, fmt.Sprintf(" regions: %s", strings.Join(ctx.UserRegions, ", "))) + } + if len(ctx.UserProfiles) > 0 { + var profileNames []string + for _, pid := range ctx.UserProfiles { + profileNames = append(profileNames, formatProfileName(pid)) + } + lines = append(lines, fmt.Sprintf(" profiles: %s", strings.Join(profileNames, ", "))) + } + if ctx.ResourceCount > 0 { + lines = append(lines, fmt.Sprintf(" count: %d", ctx.ResourceCount)) + } + if ctx.FilterText != "" { + lines = append(lines, fmt.Sprintf(" filter: %s", ctx.FilterText)) + } + if ctx.ResourceID != "" { + lines = append(lines, fmt.Sprintf(" id: %s", ctx.ResourceID)) + } + if ctx.ResourceRegion != "" { + lines = append(lines, fmt.Sprintf(" region: %s", ctx.ResourceRegion)) + } + if ctx.ResourceProfile != "" { + profileName := formatProfileName(ctx.ResourceProfile) + lines = append(lines, fmt.Sprintf(" profile: %s", profileName)) + } + if ctx.Cluster != "" { + lines = append(lines, fmt.Sprintf(" cluster: %s", ctx.Cluster)) + } + if ctx.DiffLeft != nil && ctx.DiffRight != nil { + left := fmt.Sprintf("%s/%s", ctx.DiffLeft.ID, ctx.DiffLeft.Name) + if ctx.DiffLeft.Profile != "" { + profileName := formatProfileName(ctx.DiffLeft.Profile) + left += fmt.Sprintf(" [%s]", profileName) + } + if ctx.DiffLeft.Region != "" { + left += fmt.Sprintf(" (%s)", ctx.DiffLeft.Region) + } + right := fmt.Sprintf("%s/%s", ctx.DiffRight.ID, ctx.DiffRight.Name) + if ctx.DiffRight.Profile != "" { + profileName := formatProfileName(ctx.DiffRight.Profile) + right += fmt.Sprintf(" [%s]", profileName) + } + if ctx.DiffRight.Region != "" { + right += fmt.Sprintf(" (%s)", ctx.DiffRight.Region) + } + lines = append(lines, fmt.Sprintf(" left: %s", left)) + lines = append(lines, fmt.Sprintf(" right: %s", right)) + } + if ctx.Service == "securityhub" && ctx.ResourceType == "findings" { + showResolved := "false" + if ctx.Toggles["ShowResolved"] { + showResolved = "true" + } + lines = append(lines, fmt.Sprintf(" show_resolved: %s", showResolved)) + } + + return strings.Join(lines, "\n") + "\n" +} diff --git a/internal/view/chat_overlay_render.go b/internal/view/chat_overlay_render.go new file mode 100644 index 00000000..2ef3a588 --- /dev/null +++ b/internal/view/chat_overlay_render.go @@ -0,0 +1,237 @@ +package view + +import ( + "fmt" + "regexp" + "sort" + "strings" + + "github.com/mattn/go-runewidth" + + "github.com/clawscli/claws/internal/ai" +) + +func (c *ChatOverlay) updateViewport() { + if !c.vp.Ready { + return + } + content := c.renderMessages() + c.vp.Model.SetContent(content) + c.vp.Model.GotoBottom() +} + +func (c *ChatOverlay) renderMessages() string { + var sb strings.Builder + w := c.wrapWidth() + lineNum := 0 + c.thinkingLineRanges = make(map[int][2]int) + c.toolCallLineRanges = make(map[int][2]int) + + if c.contextExpanded && c.aiCtx != nil { + params := c.renderContextParams() + for _, line := range strings.Split(strings.TrimSuffix(params, "\n"), "\n") { + sb.WriteString(c.styles.context.Render(line)) + sb.WriteString("\n") + lineNum++ + } + sb.WriteString("\n") + lineNum++ + } + + for i, msg := range c.messages { + if msg.toolUse != nil { + startLine := lineNum + toolStr := c.renderToolCall(i, msg.toolUse, msg.toolError, w) + sb.WriteString(toolStr) + lineNum += strings.Count(toolStr, "\n") + c.toolCallLineRanges[i] = [2]int{startLine, lineNum} + } else { + switch msg.role { + case ai.RoleUser: + userText := c.styles.userMsg.Render(wrapText("You: "+msg.content, w)) + sb.WriteString(userText) + sb.WriteString("\n") + lineNum += strings.Count(userText, "\n") + 1 + case ai.RoleAssistant: + if msg.thinkingContent != "" { + startLine := lineNum + thinkingStr := c.renderThinking(i, msg.thinkingContent, w) + sb.WriteString(thinkingStr) + lineNum += strings.Count(thinkingStr, "\n") + c.thinkingLineRanges[i] = [2]int{startLine, lineNum} + if msg.content != "" { + sb.WriteString("\n") + lineNum++ + } + } + if msg.content != "" { + rendered := c.renderMarkdown(msg.content, w) + contentStr := c.styles.assistantMsg.Render("AI: ") + "\n" + rendered + sb.WriteString(contentStr) + sb.WriteString("\n") + lineNum += strings.Count(contentStr, "\n") + 1 + } + } + } + sb.WriteString("\n") + lineNum++ + } + + if c.streamingThinking != "" { + sb.WriteString(c.styles.thinking.Render("💭 ▶ Thinking...")) + sb.WriteString("\n") + if c.streamingMsg != "" { + sb.WriteString("\n") + } + } + if c.streamingMsg != "" { + sb.WriteString(c.styles.assistantMsg.Render("AI: ")) + sb.WriteString("\n") + sb.WriteString(wrapText(c.streamingMsg, w)) + sb.WriteString("\n") + } else if c.isStreaming && c.streamingThinking == "" { + sb.WriteString(c.styles.thinking.Render("⏳ Waiting...")) + sb.WriteString("\n") + } + + if c.err != nil { + sb.WriteString(c.styles.errorMsg.Render(wrapText("Error: "+c.err.Error(), w))) + sb.WriteString("\n") + } + + return sb.String() +} + +func (c *ChatOverlay) renderThinking(idx int, content string, width int) string { + collapsed := c.collapsedThinking[idx] + var sb strings.Builder + + if collapsed { + sb.WriteString(c.styles.thinking.Render("💭 ▶ [click to expand]")) + sb.WriteString("\n") + } else { + sb.WriteString(c.styles.thinking.Render("💭 ▼ Thinking:")) + sb.WriteString("\n") + wrapped := wrapText(content, width-2) + for _, line := range strings.Split(wrapped, "\n") { + sb.WriteString(c.styles.thinking.Render(" " + line)) + sb.WriteString("\n") + } + } + return sb.String() +} + +func (c *ChatOverlay) renderToolCall(idx int, tu *ai.ToolUseContent, isError bool, width int) string { + collapsed := c.collapsedToolCalls[idx] + style := c.styles.toolCall + if isError { + style = c.styles.toolError + } + + var sb strings.Builder + paramCount := len(tu.Input) + + if collapsed { + summary := fmt.Sprintf("🔧 %s ▶ [%d params]", tu.Name, paramCount) + sb.WriteString(style.Render(wrapText(summary, width))) + sb.WriteString("\n") + } else { + header := fmt.Sprintf("🔧 %s ▼", tu.Name) + sb.WriteString(style.Render(header)) + sb.WriteString("\n") + + keys := make([]string, 0, len(tu.Input)) + for k := range tu.Input { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + v := tu.Input[k] + line := fmt.Sprintf(" %s: %v", k, v) + sb.WriteString(style.Render(wrapText(line, width))) + sb.WriteString("\n") + } + } + return sb.String() +} + +func (c *ChatOverlay) wrapWidth() int { + if c.width > 4 { + return c.width - 4 + } + return 76 +} + +var ( + mdBold = regexp.MustCompile(`\*\*([^*]+)\*\*`) + mdItalic = regexp.MustCompile(`\*([^*]+)\*`) + mdCode = regexp.MustCompile("`([^`]+)`") +) + +func (c *ChatOverlay) renderMarkdown(text string, width int) string { + wrapped := wrapText(text, width) + + wrapped = mdBold.ReplaceAllStringFunc(wrapped, func(m string) string { + inner := mdBold.FindStringSubmatch(m)[1] + return c.styles.mdBold.Render(inner) + }) + wrapped = mdCode.ReplaceAllStringFunc(wrapped, func(m string) string { + inner := mdCode.FindStringSubmatch(m)[1] + return c.styles.mdCode.Render(inner) + }) + wrapped = mdItalic.ReplaceAllStringFunc(wrapped, func(m string) string { + inner := mdItalic.FindStringSubmatch(m)[1] + return c.styles.mdItalic.Render(inner) + }) + + return wrapped +} + +func wrapText(text string, width int) string { + if width <= 0 { + width = 76 + } + var lines []string + for _, line := range strings.Split(text, "\n") { + lines = append(lines, wrapLine(line, width)...) + } + return strings.Join(lines, "\n") +} + +func wrapLine(line string, width int) []string { + if len(line) == 0 { + return []string{""} + } + runes := []rune(line) + lineWidth := 0 + for _, r := range runes { + lineWidth += runeWidth(r) + } + if lineWidth <= width { + return []string{line} + } + + var lines []string + var current []rune + currentWidth := 0 + + for _, r := range runes { + rw := runeWidth(r) + if currentWidth+rw > width && len(current) > 0 { + lines = append(lines, string(current)) + current = nil + currentWidth = 0 + } + current = append(current, r) + currentWidth += rw + } + if len(current) > 0 { + lines = append(lines, string(current)) + } + return lines +} + +func runeWidth(r rune) int { + return runewidth.RuneWidth(r) +} diff --git a/internal/view/detail_view.go b/internal/view/detail_view.go index b7694368..2b1d1a65 100644 --- a/internal/view/detail_view.go +++ b/internal/view/detail_view.go @@ -225,7 +225,6 @@ func (d *DetailView) SetSize(width, height int) tea.Cmd { return nil } -// StatusLine implements View func (d *DetailView) StatusLine() string { parts := []string{d.resource.GetID()} @@ -251,6 +250,18 @@ func (d *DetailView) StatusLine() string { return strings.Join(parts, " • ") } +func (d *DetailView) Resource() dao.Resource { + return d.resource +} + +func (d *DetailView) Service() string { + return d.service +} + +func (d *DetailView) ResourceType() string { + return d.resType +} + // getNavigationShortcuts returns a string of navigation shortcuts for the current resource func (d *DetailView) getNavigationShortcuts() string { if d.renderer == nil { @@ -309,19 +320,3 @@ func (d *DetailView) renderGenericDetail() string { return out } - -// mergeResources merges the refreshed resource with the original to preserve -// fields that are only available from List() but not from Get(). -func mergeResources(original, refreshed dao.Resource) dao.Resource { - if original == nil { - return refreshed - } - if refreshed == nil { - return original - } - // If refreshed resource implements Mergeable, let it copy fields from original - if m, ok := refreshed.(dao.Mergeable); ok { - m.MergeFrom(original) - } - return refreshed -} diff --git a/internal/view/diff_view.go b/internal/view/diff_view.go index 90abe26b..99f66273 100644 --- a/internal/view/diff_view.go +++ b/internal/view/diff_view.go @@ -108,7 +108,7 @@ func (d *DiffView) SetSize(width, height int) tea.Cmd { // StatusLine implements View func (d *DiffView) StatusLine() string { - return d.left.GetName() + " vs " + d.right.GetName() + " • ↑/↓:scroll • q/esc:back" + return dao.UnwrapResource(d.left).GetName() + " vs " + dao.UnwrapResource(d.right).GetName() + " • ↑/↓:scroll • q/esc:back" } // renderSideBySide generates the side-by-side view @@ -124,8 +124,8 @@ func (d *DiffView) renderSideBySide() string { leftDetail := "" rightDetail := "" if d.renderer != nil { - leftDetail = d.renderer.RenderDetail(d.left) - rightDetail = d.renderer.RenderDetail(d.right) + leftDetail = d.renderer.RenderDetail(dao.UnwrapResource(d.left)) + rightDetail = d.renderer.RenderDetail(dao.UnwrapResource(d.right)) } // Split into lines @@ -136,8 +136,8 @@ func (d *DiffView) renderSideBySide() string { colWidth := (d.width - 3) / 2 // Column headers - leftHeader := TruncateOrPadString("◀ "+d.left.GetName(), colWidth) - rightHeader := TruncateOrPadString(d.right.GetName()+" ▶", colWidth) + leftHeader := TruncateOrPadString("◀ "+dao.UnwrapResource(d.left).GetName(), colWidth) + rightHeader := TruncateOrPadString(dao.UnwrapResource(d.right).GetName()+" ▶", colWidth) out.WriteString(s.header.Render(leftHeader)) out.WriteString(s.separator.Render(" │ ")) out.WriteString(s.header.Render(rightHeader)) @@ -169,3 +169,8 @@ func (d *DiffView) renderSideBySide() string { return out.String() } + +func (d *DiffView) Left() dao.Resource { return d.left } +func (d *DiffView) Right() dao.Resource { return d.right } +func (d *DiffView) Service() string { return d.service } +func (d *DiffView) ResourceType() string { return d.resourceType } diff --git a/internal/view/log_view.go b/internal/view/log_view.go index b4d69a9b..90253f93 100644 --- a/internal/view/log_view.go +++ b/internal/view/log_view.go @@ -351,7 +351,7 @@ func (v *LogView) Update(msg tea.Msg) (tea.Model, tea.Cmd) { v.updateViewportContent() v.SetSize(v.width, v.height) // Recalculate viewport height } - return v, nil + return v, tea.ClearScreen } v.logs = v.logs[:0] v.oldestEventTime = 0 @@ -437,7 +437,7 @@ func (v *LogView) handleFilterInput(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { v.updateViewportContent() } - return v, cmd + return v, tea.Batch(cmd, tea.ClearScreen) } } @@ -566,3 +566,7 @@ func (v *LogView) StatusLine() string { func (v *LogView) HasActiveInput() bool { return v.filterActive } + +func (v *LogView) LogGroupName() string { + return v.logGroupName +} diff --git a/internal/view/modal.go b/internal/view/modal.go index 16b27966..3700b802 100644 --- a/internal/view/modal.go +++ b/internal/view/modal.go @@ -23,6 +23,7 @@ const ( ModalWidthProfile = 55 ModalWidthProfileDetail = 65 ModalWidthActionMenu = 60 + ModalWidthChat = 80 ) type Modal struct { diff --git a/internal/view/resource_browser.go b/internal/view/resource_browser.go index d7d71143..30810545 100644 --- a/internal/view/resource_browser.go +++ b/internal/view/resource_browser.go @@ -122,6 +122,9 @@ type ResourceBrowser struct { // Partial region errors (for multi-region queries) partialErrors []string + + // List-level toggles (e.g., show resolved findings) + toggleStates map[string]bool } // NewResourceBrowser creates a new ResourceBrowser @@ -175,8 +178,9 @@ func newResourceBrowser(ctx context.Context, reg *registry.Registry, service, re spinner: ui.NewSpinner(), styles: newResourceBrowserStyles(), pageSize: 100, - sortColumn: -1, // -1 = no sort + sortColumn: -1, sortAscending: true, + toggleStates: make(map[string]bool), } } @@ -350,7 +354,7 @@ func (r *ResourceBrowser) contextForResource(res dao.Resource) (context.Context, if region := dao.GetResourceRegion(res); region != "" { ctx = aws.WithRegionOverride(ctx, region) } - return ctx, dao.UnwrapResource(res) + return ctx, res } func (r *ResourceBrowser) renderTabs() string { diff --git a/internal/view/resource_browser_fetch.go b/internal/view/resource_browser_fetch.go index c0131f57..b5295a5f 100644 --- a/internal/view/resource_browser_fetch.go +++ b/internal/view/resource_browser_fetch.go @@ -25,7 +25,12 @@ type listResourcesResult struct { func (r *ResourceBrowser) listResourcesWithContext(ctx context.Context, d dao.DAO) listResourcesResult { listCtx := ctx if r.fieldFilter != "" && r.fieldFilterValue != "" { - listCtx = dao.WithFilter(ctx, r.fieldFilter, r.fieldFilterValue) + listCtx = dao.WithFilter(listCtx, r.fieldFilter, r.fieldFilterValue) + } + for key, val := range r.toggleStates { + if val { + listCtx = dao.WithFilter(listCtx, key, "true") + } } var resources []dao.Resource @@ -203,7 +208,12 @@ func (r *ResourceBrowser) fetchWithDAO(ctx context.Context, d dao.DAO, token str if pagDAO, ok := d.(dao.PaginatedDAO); ok { listCtx := ctx if r.fieldFilter != "" && r.fieldFilterValue != "" { - listCtx = dao.WithFilter(ctx, r.fieldFilter, r.fieldFilterValue) + listCtx = dao.WithFilter(listCtx, r.fieldFilter, r.fieldFilterValue) + } + for key, val := range r.toggleStates { + if val { + listCtx = dao.WithFilter(listCtx, key, "true") + } } resources, nextToken, err := pagDAO.ListPage(listCtx, r.pageSize, token) return listResourcesResult{resources: resources, nextToken: nextToken, err: err} @@ -411,7 +421,12 @@ func (r *ResourceBrowser) loadNextPage() tea.Msg { listCtx := r.ctx if r.fieldFilter != "" && r.fieldFilterValue != "" { - listCtx = dao.WithFilter(r.ctx, r.fieldFilter, r.fieldFilterValue) + listCtx = dao.WithFilter(listCtx, r.fieldFilter, r.fieldFilterValue) + } + for key, val := range r.toggleStates { + if val { + listCtx = dao.WithFilter(listCtx, key, "true") + } } resources, nextToken, err := pagDAO.ListPage(listCtx, r.pageSize, r.nextPageToken) diff --git a/internal/view/resource_browser_input.go b/internal/view/resource_browser_input.go index eade01bb..5cedc3f6 100644 --- a/internal/view/resource_browser_input.go +++ b/internal/view/resource_browser_input.go @@ -7,6 +7,7 @@ import ( "github.com/clawscli/claws/internal/action" "github.com/clawscli/claws/internal/clipboard" "github.com/clawscli/claws/internal/dao" + "github.com/clawscli/claws/internal/render" ) func (r *ResourceBrowser) handleKeyPress(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { @@ -20,6 +21,10 @@ func (r *ResourceBrowser) handleKeyPress(msg tea.KeyPressMsg) (tea.Model, tea.Cm } } + if model, cmd := r.handleToggleKey(msg.String()); cmd != nil { + return model, cmd + } + switch msg.String() { case "/": r.filterActive = true @@ -173,7 +178,7 @@ func (r *ResourceBrowser) handleEnter() (tea.Model, tea.Cmd) { if len(r.filtered) > 0 && cursor >= 0 && cursor < len(r.filtered) { ctx, resource := r.contextForResource(r.filtered[cursor]) if r.markedResource != nil && r.markedResource.GetID() != resource.GetID() { - diffView := NewDiffView(ctx, dao.UnwrapResource(r.markedResource), resource, r.renderer, r.service, r.resourceType) + diffView := NewDiffView(ctx, r.markedResource, resource, r.renderer, r.service, r.resourceType) return r, func() tea.Msg { return NavigateMsg{View: diffView} } @@ -343,3 +348,21 @@ func (r *ResourceBrowser) handleCopyARN() (tea.Model, tea.Cmd) { } return r, nil } + +func (r *ResourceBrowser) handleToggleKey(key string) (tea.Model, tea.Cmd) { + if r.renderer == nil { + return nil, nil + } + toggler, ok := r.renderer.(render.Toggler) + if !ok { + return nil, nil + } + for _, toggle := range toggler.ListToggles() { + if toggle.Key == key { + r.toggleStates[toggle.ContextKey] = !r.toggleStates[toggle.ContextKey] + r.loading = true + return r, tea.Batch(r.loadResources, r.spinner.Tick) + } + } + return nil, nil +} diff --git a/internal/view/resource_browser_nav.go b/internal/view/resource_browser_nav.go index f60f63f1..adc300b8 100644 --- a/internal/view/resource_browser_nav.go +++ b/internal/view/resource_browser_nav.go @@ -3,11 +3,13 @@ package view import ( "fmt" "slices" + "strings" tea "charm.land/bubbletea/v2" "github.com/clawscli/claws/internal/action" "github.com/clawscli/claws/internal/dao" + "github.com/clawscli/claws/internal/render" ) // handleNavigation processes navigation key shortcuts @@ -90,6 +92,7 @@ func (r *ResourceBrowser) StatusLine() string { } navInfo := r.getNavigationShortcuts() + toggleInfo := r.getToggleInfo() dHint := "d:describe" if r.markedResource != nil && markInFiltered { @@ -113,7 +116,7 @@ func (r *ResourceBrowser) StatusLine() string { } if r.filterText != "" || filterInfo != "" { - base := fmt.Sprintf("%s/%s%s%s%s%s%s • %d/%d items • c:clear", r.service, r.resourceType, filterInfo, sortInfo, markInfo, autoReloadInfo, partialWarn, shown, total) + base := fmt.Sprintf("%s/%s%s%s%s%s%s%s • %d/%d items • c:clear", r.service, r.resourceType, filterInfo, sortInfo, markInfo, toggleInfo, autoReloadInfo, partialWarn, shown, total) if hasActions { base += " a:actions" } @@ -124,7 +127,7 @@ func (r *ResourceBrowser) StatusLine() string { return base } - base := fmt.Sprintf("%s/%s%s%s%s%s • %d items • /:filter %s", r.service, r.resourceType, sortInfo, markInfo, autoReloadInfo, partialWarn, total, dHint) + base := fmt.Sprintf("%s/%s%s%s%s%s%s • %d items • /:filter %s", r.service, r.resourceType, sortInfo, markInfo, toggleInfo, autoReloadInfo, partialWarn, total, dHint) if hasActions { base += " a:actions" } @@ -160,12 +163,29 @@ func (r *ResourceBrowser) CanRefresh() bool { return true } -// Service returns the service name for this browser func (r *ResourceBrowser) Service() string { return r.service } -// getNavigationShortcuts returns a string of navigation shortcuts for the current resource +func (r *ResourceBrowser) ResourceType() string { + return r.resourceType +} + +func (r *ResourceBrowser) SelectedResource() dao.Resource { + if len(r.filtered) == 0 { + return nil + } + cursor := r.tc.Cursor() + if cursor < 0 || cursor >= len(r.filtered) { + return nil + } + return r.filtered[cursor] +} + +func (r *ResourceBrowser) ResourceCount() int { return len(r.filtered) } +func (r *ResourceBrowser) FilterText() string { return r.filterText } +func (r *ResourceBrowser) ToggleStates() map[string]bool { return r.toggleStates } + func (r *ResourceBrowser) getNavigationShortcuts() string { if r.renderer == nil || len(r.filtered) == 0 { return "" @@ -175,3 +195,26 @@ func (r *ResourceBrowser) getNavigationShortcuts() string { resource := dao.UnwrapResource(r.filtered[r.tc.Cursor()]) return helper.FormatShortcuts(resource) } + +func (r *ResourceBrowser) getToggleInfo() string { + if r.renderer == nil { + return "" + } + toggler, ok := r.renderer.(render.Toggler) + if !ok { + return "" + } + toggles := toggler.ListToggles() + if len(toggles) == 0 { + return "" + } + var parts []string + for _, t := range toggles { + label := t.LabelOff + if r.toggleStates[t.ContextKey] { + label = t.LabelOn + } + parts = append(parts, fmt.Sprintf("%s:%s", t.Key, label)) + } + return " [" + strings.Join(parts, " ") + "]" +} diff --git a/internal/view/service_browser.go b/internal/view/service_browser.go index 0b86c1fe..22ca865f 100644 --- a/internal/view/service_browser.go +++ b/internal/view/service_browser.go @@ -263,14 +263,13 @@ func (s *ServiceBrowser) handleFilterInput(msg tea.KeyPressMsg) (tea.Model, tea. s.rebuildFlatItems() s.cursor = 0 s.updateViewport() - return s, cmd + return s, tea.Batch(cmd, tea.ClearScreen) } func (s *ServiceBrowser) handleNavigation(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { // Handle special keys that work regardless of flatItems state switch msg.String() { case "~": - // Toggle to Dashboard dashboard := NewDashboardView(s.ctx, s.registry) return s, func() tea.Msg { return NavigateMsg{View: dashboard, ClearStack: false} @@ -279,6 +278,24 @@ func (s *ServiceBrowser) handleNavigation(msg tea.KeyPressMsg) (tea.Model, tea.C s.filterActive = true s.filterInput.Focus() return s, textinput.Blink + case "c": + if s.filterText != "" { + s.filterText = "" + s.filterInput.SetValue("") + s.rebuildFlatItems() + s.cursor = 0 + s.updateViewport() + return s, tea.ClearScreen + } + } + + if IsEscKey(msg) && s.filterText != "" { + s.filterText = "" + s.filterInput.SetValue("") + s.rebuildFlatItems() + s.cursor = 0 + s.updateViewport() + return s, tea.ClearScreen } // Navigation requires loaded services @@ -315,25 +332,8 @@ func (s *ServiceBrowser) handleNavigation(msg tea.KeyPressMsg) (tea.Model, tea.C case "enter": return s.selectCurrentService() - - case "c": - if s.filterText != "" { - s.filterText = "" - s.filterInput.SetValue("") - s.rebuildFlatItems() - s.cursor = 0 - } } - // Also allow esc to clear filter (handles various escape sequences) - if IsEscKey(msg) && s.filterText != "" { - s.filterText = "" - s.filterInput.SetValue("") - s.rebuildFlatItems() - s.cursor = 0 - } - - // Update viewport content and scroll to cursor s.updateViewport() return s, nil @@ -344,28 +344,17 @@ func (s *ServiceBrowser) updateViewport() { return } content := s.renderContent() - s.vp.Model.SetContent(content) + vpWidth := s.vp.Model.Width() + vpHeight := s.vp.Model.Height() lines := strings.Split(content, "\n") - totalLines := len(lines) - - if len(s.flatItems) == 0 { - return + emptyLine := strings.Repeat(" ", vpWidth) + for len(lines) < vpHeight { + lines = append(lines, emptyLine) } - cursorRatio := float64(s.cursor) / float64(len(s.flatItems)) - targetLine := int(cursorRatio * float64(totalLines)) - - vpHeight := s.vp.Model.Height() - currentTop := s.vp.Model.YOffset() - - if targetLine < currentTop { - s.vp.Model.SetYOffset(max(0, targetLine-2)) - } else if targetLine > currentTop+vpHeight-cellHeight { - newOffset := targetLine - vpHeight + cellHeight + 2 - newOffset = max(0, min(newOffset, totalLines-vpHeight)) - s.vp.Model.SetYOffset(newOffset) - } + s.vp.Model.SetContent(strings.Join(lines, "\n")) + s.vp.Model.GotoTop() } func (s *ServiceBrowser) moveToNextCategory() { diff --git a/internal/view/session_history.go b/internal/view/session_history.go new file mode 100644 index 00000000..441e9489 --- /dev/null +++ b/internal/view/session_history.go @@ -0,0 +1,147 @@ +package view + +import ( + "fmt" + "strings" + + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + + "github.com/clawscli/claws/internal/ai" + "github.com/clawscli/claws/internal/ui" +) + +const ModalWidthSessionHistory = 50 + +type sessionHistoryStyles struct { + title lipgloss.Style + item lipgloss.Style + selected lipgloss.Style + hint lipgloss.Style + current lipgloss.Style +} + +func newSessionHistoryStyles() sessionHistoryStyles { + return sessionHistoryStyles{ + title: ui.TableHeaderStyle().Padding(0, 1), + item: ui.TextStyle().PaddingLeft(2), + selected: ui.SelectedStyle().PaddingLeft(2), + hint: ui.DimStyle(), + current: ui.AccentStyle(), + } +} + +type SessionSelectedMsg struct { + Session *ai.Session +} + +type NewSessionMsg struct{} + +type CloseHistoryMsg struct{} + +type SessionHistory struct { + sessions []ai.Session + currentID string + cursor int + styles sessionHistoryStyles + width int + height int +} + +func NewSessionHistory(sessions []ai.Session, currentID string) *SessionHistory { + return &SessionHistory{ + sessions: sessions, + currentID: currentID, + cursor: 0, + styles: newSessionHistoryStyles(), + } +} + +func (s *SessionHistory) Init() tea.Cmd { + return nil +} + +func (s *SessionHistory) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyPressMsg: + switch msg.String() { + case "up", "k": + if s.cursor > 0 { + s.cursor-- + } + return s, nil + case "down", "j": + if s.cursor < len(s.sessions)-1 { + s.cursor++ + } + return s, nil + case "enter": + if s.cursor >= 0 && s.cursor < len(s.sessions) { + return s, func() tea.Msg { + return SessionSelectedMsg{Session: &s.sessions[s.cursor]} + } + } + return s, nil + case "n": + return s, func() tea.Msg { + return NewSessionMsg{} + } + case "esc", "q", "ctrl+c", "ctrl+h": + return s, func() tea.Msg { + return CloseHistoryMsg{} + } + } + } + return s, nil +} + +func (s *SessionHistory) View() tea.View { + return tea.NewView(s.ViewString()) +} + +func (s *SessionHistory) ViewString() string { + var b strings.Builder + + b.WriteString(s.styles.title.Render("Chat History")) + b.WriteString("\n\n") + + if len(s.sessions) == 0 { + b.WriteString(s.styles.hint.Render(" No saved sessions")) + b.WriteString("\n") + } else { + for i, sess := range s.sessions { + style := s.styles.item + prefix := " " + if i == s.cursor { + style = s.styles.selected + prefix = "> " + } + + dateStr := sess.UpdatedAt.Format("2006-01-02 15:04") + msgCount := len(sess.Messages) + line := fmt.Sprintf("%s%s (%d msgs)", prefix, dateStr, msgCount) + + if sess.ID == s.currentID { + line += " " + s.styles.current.Render("*") + } + + b.WriteString(style.Render(line)) + b.WriteString("\n") + } + } + + b.WriteString("\n") + b.WriteString(s.styles.hint.Render("j/k:select enter:load n:new esc:close")) + + return b.String() +} + +func (s *SessionHistory) SetSize(width, height int) tea.Cmd { + s.width = width + s.height = height + return nil +} + +func (s *SessionHistory) StatusLine() string { + return "" +} diff --git a/internal/view/view.go b/internal/view/view.go index 2a7bd99e..b6b910f8 100644 --- a/internal/view/view.go +++ b/internal/view/view.go @@ -230,3 +230,28 @@ func (h *NavigationHelper) createLogView(resource dao.Resource) tea.Cmd { return NavigateMsg{View: logView} } } + +// mergeResources merges the refreshed resource with the original to preserve +// fields that are only available from List() but not from Get(). +func mergeResources(original, refreshed dao.Resource) dao.Resource { + if original == nil { + return refreshed + } + if refreshed == nil { + return original + } + // If refreshed resource implements Mergeable, let it copy fields from original + if m, ok := refreshed.(dao.Mergeable); ok { + m.MergeFrom(original) + } + + // Preserve wrapping from original + if rr, ok := original.(*dao.RegionalResource); ok { + return dao.WrapWithRegion(refreshed, rr.Region) + } + if pr, ok := original.(*dao.ProfiledResource); ok { + return dao.WrapWithProfile(refreshed, pr.Profile, pr.AccountID, pr.Region) + } + + return refreshed +}