Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
496 changes: 291 additions & 205 deletions clients/agent-runtime/src/gateway/mod.rs

Large diffs are not rendered by default.

102 changes: 58 additions & 44 deletions clients/agent-runtime/src/gateway/webhook_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,40 @@ fn sanitize_sse_id(session_id: &str) -> String {
.replace(['\r', '\n'], "")
}

/// Converts a pre-execution ingress decision into a terminal webhook result.
///
/// Returns `None` for `HandledIngress::NotHandled`. Session-command success
/// and failure become `Completed` / `Failed` results carrying the command's
/// message, with no event frames and no tool calls. Blocking outcomes are
/// delegated to `map_canonical_result`.
fn handled_ingress_to_webhook_result(
    request: &WebhookTurnRequest,
    model: &str,
    handled: HandledIngress,
) -> Option<WebhookTurnResult> {
    let handled_outcome = match handled {
        HandledIngress::NotHandled => return None,
        HandledIngress::Handled(inner) => inner,
    };

    // Session commands differ only in terminal outcome and message text;
    // blocking outcomes take the canonical mapping path and return early.
    let (terminal, text) = match handled_outcome {
        HandledIngressOutcome::Blocking(blocking) => {
            let canonical = CanonicalWebhookResult::Blocking(blocking);
            return Some(map_canonical_result(request, model, canonical));
        }
        HandledIngressOutcome::SessionCommandSuccess(success) => {
            (WebhookTerminalOutcome::Completed, success.message.clone())
        }
        HandledIngressOutcome::SessionCommandFailure { class: _, failure } => {
            (WebhookTerminalOutcome::Failed, failure.message)
        }
    };

    Some(WebhookTurnResult {
        session_id: request.session_id.clone(),
        model: model.to_string(),
        outcome: terminal,
        response_text: Some(text),
        event_frames: Vec::new(),
        tools_called: Vec::new(),
    })
}

pub(crate) async fn execute(
config: &Config,
provider: Arc<dyn Provider>,
Expand All @@ -400,50 +434,10 @@ pub(crate) async fn execute(
Err(result) => return result,
};

match evaluate_webhook_ingress(memory.as_ref(), &tool_snapshot, &request, clamped_mode).await {
HandledIngress::Handled(HandledIngressOutcome::SessionCommandSuccess(success)) => {
return WebhookTurnResult {
session_id: request.session_id.clone(),
model: model.to_string(),
outcome: WebhookTerminalOutcome::Completed,
response_text: Some(success.message.clone()),
event_frames: Vec::new(),
tools_called: Vec::new(),
};
}
HandledIngress::Handled(HandledIngressOutcome::SessionCommandFailure {
class: _,
failure,
}) => {
return WebhookTurnResult {
session_id: request.session_id.clone(),
model: model.to_string(),
outcome: WebhookTerminalOutcome::Failed,
response_text: Some(failure.message),
event_frames: Vec::new(),
tools_called: Vec::new(),
};
}
HandledIngress::Handled(HandledIngressOutcome::Blocking(blocking)) => match blocking {
BlockingOutcome::ApprovalRequired { tool, reason } => {
return map_canonical_result(
&request,
model,
CanonicalWebhookResult::Blocking(BlockingOutcome::ApprovalRequired {
tool,
reason,
}),
);
}
other => {
return map_canonical_result(
&request,
model,
CanonicalWebhookResult::Blocking(other),
);
}
},
HandledIngress::NotHandled => {}
let handled_ingress =
evaluate_webhook_ingress(memory.as_ref(), &tool_snapshot, &request, clamped_mode).await;
if let Some(result) = handled_ingress_to_webhook_result(&request, model, handled_ingress) {
return result;
}

let mut effective_config = config.clone();
Expand Down Expand Up @@ -660,6 +654,26 @@ mod tests {
assert_eq!(frame.matches("id:").count(), 1);
}

/// A session-command failure must surface as a `Failed` webhook result whose
/// response text is the failure message.
#[test]
fn handled_ingress_failure_maps_to_failed_webhook_result() {
    let request = sample_request(WebhookSessionSource::Explicit);

    let command_failure = crate::session_commands::SessionCommandFailure {
        kind: crate::session_commands::SessionCommandFailureKind::InvalidState,
        message: "boom".into(),
        command: "/tldr",
        session_id: Some("webhook-123".into()),
    };
    let ingress = HandledIngress::Handled(HandledIngressOutcome::SessionCommandFailure {
        class: crate::pre_execution::SessionCommandFailureClass::Failed,
        failure: command_failure,
    });

    let mapped = handled_ingress_to_webhook_result(&request, "test-model", ingress)
        .expect("expected handled result");

    assert_eq!(mapped.outcome, WebhookTerminalOutcome::Failed);
    assert_eq!(mapped.response_text.as_deref(), Some("boom"));
}

#[test]
fn maps_completed_agent_turn_into_completed_webhook_result() {
let result = map_canonical_result(
Expand Down
84 changes: 46 additions & 38 deletions clients/agent-runtime/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,41 @@ async fn maybe_handle_cli_handled_ingress(
}
}

/// Runs the handled-ingress check for a one-shot `code` message and prints the
/// result if there is one.
///
/// Returns `Ok(true)` when `maybe_handle_cli_handled_ingress` yielded a message
/// (which is printed to stdout), and `Ok(false)` when no message was supplied
/// or nothing was yielded. Errors from the ingress check are propagated.
async fn maybe_print_code_fast_path(config: &Config, message: Option<&str>) -> Result<bool> {
    let Some(raw_message) = message else {
        return Ok(false);
    };

    match maybe_handle_cli_handled_ingress(config, raw_message).await? {
        Some(result_message) => {
            println!("{result_message}");
            Ok(true)
        }
        None => Ok(false),
    }
}

/// Runs either a single agent turn (when `message` is present) or the
/// interactive loop (when it is absent).
///
/// For a single turn, the final text (if any) is printed before the result is
/// checked with `cli_blocking_error_from_turn_result`; an error returned by
/// that check becomes this function's error. A failed turn is returned as-is.
async fn run_code_message_or_interactive(
    agent: &mut crate::agent::Agent,
    message: Option<String>,
) -> Result<()> {
    let prompt = match message {
        Some(text) => text,
        None => return agent.run_interactive().await,
    };

    // A failed turn short-circuits here: nothing is printed for it.
    let completed = agent
        .turn_with_context(&prompt, crate::agent::TurnContext::default())
        .await?;

    if let Some(response) = completed.final_text.as_deref() {
        println!("{response}");
    }

    match cli_blocking_error_from_turn_result(&completed) {
        Some(err) => Err(err),
        None => Ok(()),
    }
}

async fn handle_code_command(
config: Config,
message: Option<String>,
Expand All @@ -1643,13 +1678,8 @@ async fn handle_code_command(
) -> Result<()> {
let config = apply_code_session_config(config, provider, model, temperature, plan);

// Intercept shared handled-ingress commands before agent initialization.
if let Some(raw_message) = message.as_deref() {
if let Some(result_message) = maybe_handle_cli_handled_ingress(&config, raw_message).await?
{
println!("{result_message}");
return Ok(());
}
if maybe_print_code_fast_path(&config, message.as_deref()).await? {
return Ok(());
}

info!("Starting code-specialist session (profile=code)");
Expand All @@ -1672,38 +1702,16 @@ async fn handle_code_command(

agent.record_agent_start_event(&provider_name, &model_name);

let run_result = if let Some(msg) = message {
let turn_result = agent
.turn_with_context(&msg, crate::agent::TurnContext::default())
.await;
if let Ok(turn_result) = &turn_result {
let blocking_err = cli_blocking_error_from_turn_result(turn_result);
if let Some(response) = turn_result.final_text.as_deref() {
println!("{response}");
}
if let Some(err) = blocking_err {
let summary_result = agent.session_cost_summary(chrono::Utc::now());
agent.record_agent_end_event(&provider_name, &model_name, session_start.elapsed());
match summary_result {
Ok(summary) => print_cli_session_summary(summary, CliSessionSurface::Code),
Err(error) => {
tracing::warn!("Failed to load code session cost summary: {error}");
}
}
return Err(err);
}
}
turn_result.map(|_| ())
} else {
agent.run_interactive().await
};
let run_result = run_code_message_or_interactive(&mut agent, message).await;

let summary_result = agent.session_cost_summary(chrono::Utc::now());
agent.record_agent_end_event(&provider_name, &model_name, session_start.elapsed());
match summary_result {
Ok(summary) => print_cli_session_summary(summary, CliSessionSurface::Code),
Err(error) => tracing::warn!("Failed to load code session cost summary: {error}"),
}
finish_cli_session(
&agent,
&provider_name,
&model_name,
session_start,
CliSessionSurface::Code,
"code",
);

run_result
}
Expand Down
89 changes: 50 additions & 39 deletions clients/agent-runtime/src/security/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,43 @@ impl SecurityPolicy {
.any(|w| w == "tee" || w.ends_with("/tee"))
}

/// Heuristically decides whether a command-line token looks like a filesystem path.
///
/// A token counts as a path when it starts with `~` or `.`, when it contains
/// `/` but no `:`, or when it contains the platform's main path separator.
///
/// NOTE(review): on targets where `std::path::MAIN_SEPARATOR` is `/`, the last
/// rule matches every token containing `/`, including ones with a `:`. Confirm
/// that this is the intended behaviour.
fn is_likely_path(arg: &str) -> bool {
    if arg.starts_with(['~', '.']) {
        return true;
    }

    let slash_without_colon = arg.contains('/') && !arg.contains(':');
    slash_without_colon || arg.contains(std::path::MAIN_SEPARATOR)
}

/// Extracts the portion of an argument that may carry a path.
///
/// * `--name=value` yields `value` (everything after the first `=`); a long
///   flag without `=` is returned unchanged.
/// * A token starting with `-` and longer than two bytes yields everything
///   after its first two characters (e.g. `-C/path` gives `/path`), or `""`
///   when it has fewer than three characters.
/// * Anything else is returned unchanged.
fn effective_path_arg(arg: &str) -> &str {
    if arg.starts_with("--") {
        return match arg.split_once('=') {
            Some((_, value)) => value,
            None => arg,
        };
    }

    if arg.starts_with('-') && arg.len() > 2 {
        // Slice on a character boundary: the length check above is in bytes,
        // so a multi-byte second character can leave nothing after it.
        return match arg.char_indices().nth(2) {
            Some((tail_start, _)) => &arg[tail_start..],
            None => "",
        };
    }

    arg
}

/// Checks a (flag-stripped) argument against the policy's path rules.
///
/// Tokens that `is_likely_path` does not consider paths are accepted. A
/// path-like token is rejected when it contains a `..` component, when
/// `workspace_only` is set and it starts with `/` or `~`, or when
/// `matches_any_forbidden_path` reports a hit against `forbidden_paths`.
fn is_path_argument_safe(&self, effective_arg: &str) -> bool {
    if !Self::is_likely_path(effective_arg) {
        return true;
    }

    let has_parent_component = Path::new(effective_arg)
        .components()
        .any(|component| matches!(component, std::path::Component::ParentDir));
    if has_parent_component {
        return false;
    }

    let leaves_workspace = self.workspace_only && effective_arg.starts_with(['/', '~']);
    if leaves_workspace {
        return false;
    }

    !matches_any_forbidden_path(effective_arg, &self.forbidden_paths)
}

fn validate_command_segments(&self, command: &str) -> bool {
let normalized = self.normalize_command(command);

Expand Down Expand Up @@ -494,50 +531,18 @@ impl SecurityPolicy {
.map(|arg| arg.to_ascii_lowercase())
.collect();

// Helper to identify tokens that likely represent paths
fn is_likely_path(arg: &str) -> bool {
(arg.contains('/') && !arg.contains(':'))
|| arg.starts_with('~')
|| arg.starts_with('.')
|| arg.contains(std::path::MAIN_SEPARATOR)
}

// Ensure no argument is a forbidden path or a traversal attempt.
// We only check arguments that look like paths to avoid false positives
// on non-path tokens (e.g., git diff patterns, grep globs, brace literals).
for (_raw_arg, arg) in raw_args.iter().zip(normalized_args.iter()) {
// Extract potential path from flags (e.g. --file=/path or -C/path)
let effective_arg = if arg.starts_with("--") {
arg.split_once('=').map(|(_, v)| v).unwrap_or(arg)
} else if arg.starts_with('-') && arg.len() > 2 {
arg.char_indices()
.nth(2)
.map(|(idx, _)| &arg[idx..])
.unwrap_or("")
} else {
arg
};

if !is_likely_path(effective_arg) {
continue;
}

if Path::new(effective_arg)
.components()
.any(|c| matches!(c, std::path::Component::ParentDir))
|| (self.workspace_only
&& (effective_arg.starts_with('/') || effective_arg.starts_with('~')))
{
for arg in &normalized_args {
let effective_arg = Self::effective_path_arg(arg);
if !self.is_path_argument_safe(effective_arg) {
return false;
}
}

// Check against forbidden paths (e.g. /etc, ~/.ssh)
if matches_any_forbidden_path(effective_arg, &self.forbidden_paths) {
return false;
}
if !self.is_args_safe(base_raw, &args) {
return false;
}

self.is_args_safe(base_raw, &args)
true
}

fn is_allowed_command(&self, base_raw: &str) -> bool {
Expand Down Expand Up @@ -1434,6 +1439,12 @@ mod tests {

// ── Edge cases: path traversal ──────────────────────────

/// A path smuggled in through a `--flag=value` argument must be rejected
/// under the default policy.
#[test]
fn command_with_flag_embedded_absolute_path_is_blocked() {
    let policy = default_policy();
    let allowed = policy.is_command_allowed("grep --file=/etc/passwd foo.txt");
    assert!(!allowed);
}

#[test]
fn path_traversal_encoded_dots() {
let p = default_policy();
Expand Down
Loading
Loading