diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index e0f0bb24..74a64fa2 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -1209,6 +1209,152 @@ pub enum Role { Assistant, } +/// Tool selection mode (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum ToolChoiceMode { + /// Model decides whether to use tools + Auto, + /// Model must use at least one tool + Required, + /// Model must not use tools + None, +} + +impl Default for ToolChoiceMode { + fn default() -> Self { + Self::Auto + } +} + +/// Tool choice configuration (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ToolChoice { + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, +} + +impl ToolChoice { + pub fn auto() -> Self { + Self { + mode: Some(ToolChoiceMode::Auto), + } + } + + pub fn required() -> Self { + Self { + mode: Some(ToolChoiceMode::Required), + } + } + + pub fn none() -> Self { + Self { + mode: Some(ToolChoiceMode::None), + } + } +} + +/// Single or array content wrapper (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum SamplingContent { + Single(T), + Multiple(Vec), +} + +impl SamplingContent { + /// Convert to a Vec regardless of whether it's single or multiple + pub fn into_vec(self) -> Vec { + match self { + SamplingContent::Single(item) => vec![item], + SamplingContent::Multiple(items) => items, + } + } + + /// Check if the content is empty + pub fn is_empty(&self) -> bool { + match self { + SamplingContent::Single(_) => false, + SamplingContent::Multiple(items) => items.is_empty(), + } + } + + /// Get the number of content items + pub fn len(&self) -> usize { + match self { + SamplingContent::Single(_) => 1, + SamplingContent::Multiple(items) => items.len(), + } + } +} + +impl Default for SamplingContent { + fn default() -> Self { + SamplingContent::Multiple(Vec::new()) + } +} + +impl SamplingContent { + /// Get the first item if present + pub fn first(&self) -> Option<&T> { + match self { + SamplingContent::Single(item) => Some(item), + SamplingContent::Multiple(items) => items.first(), + } + } + + /// Iterate over all content items + pub fn iter(&self) -> impl Iterator { + let items: Vec<&T> = match self { + SamplingContent::Single(item) => vec![item], + SamplingContent::Multiple(items) => items.iter().collect(), + }; + items.into_iter() + } +} + +impl SamplingMessageContent { + /// Get the text content if this is a Text variant + pub fn as_text(&self) -> Option<&RawTextContent> { + match self { + SamplingMessageContent::Text(text) => Some(text), + _ => None, + } + } + + /// Get the tool use content if this is a ToolUse variant + pub fn as_tool_use(&self) -> Option<&ToolUseContent> { + match self { + SamplingMessageContent::ToolUse(tool_use) => Some(tool_use), + _ => None, + } + } + + /// Get the tool result content if this is a ToolResult variant + pub fn as_tool_result(&self) -> Option<&ToolResultContent> { + match self { + SamplingMessageContent::ToolResult(tool_result) => Some(tool_result), + _ => None, + } + } +} + +impl From for SamplingContent { + fn from(item: T) -> Self { + SamplingContent::Single(item) + } +} + +impl From> for SamplingContent { + fn from(items: Vec) -> Self { + SamplingContent::Multiple(items) + } +} + /// A message in a sampling conversation, containing a role and content. /// /// This represents a single message in a conversation flow, used primarily @@ -1219,8 +1365,106 @@ pub enum Role { pub struct SamplingMessage { /// The role of the message sender (User or Assistant) pub role: Role, - /// The actual content of the message (text, image, etc.) - pub content: Content, + /// The actual content of the message (text, image, audio, tool use, or tool result) + pub content: SamplingContent, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +/// Content types for sampling messages (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum SamplingMessageContent { + Text(RawTextContent), + Image(RawImageContent), + Audio(RawAudioContent), + /// Assistant only + ToolUse(ToolUseContent), + /// User only + ToolResult(ToolResultContent), +} + +impl SamplingMessageContent { + /// Create a text content + pub fn text(text: impl Into) -> Self { + Self::Text(RawTextContent { + text: text.into(), + meta: None, + }) + } + + pub fn tool_use(id: impl Into, name: impl Into, input: JsonObject) -> Self { + Self::ToolUse(ToolUseContent::new(id, name, input)) + } + + pub fn tool_result(tool_use_id: impl Into, content: Vec) -> Self { + Self::ToolResult(ToolResultContent::new(tool_use_id, content)) + } +} + +impl SamplingMessage { + pub fn new(role: Role, content: impl Into) -> Self { + Self { + role, + content: SamplingContent::Single(content.into()), + meta: None, + } + } + + pub fn new_multiple(role: Role, contents: Vec) -> Self { + Self { + role, + content: SamplingContent::Multiple(contents), + meta: None, + } + } + + pub fn user_text(text: impl Into) -> Self { + Self::new(Role::User, SamplingMessageContent::text(text)) + } + + pub fn assistant_text(text: impl Into) -> Self { + Self::new(Role::Assistant, SamplingMessageContent::text(text)) + } + + pub fn user_tool_result(tool_use_id: impl Into, content: Vec) -> Self { + Self::new( + Role::User, + SamplingMessageContent::tool_result(tool_use_id, content), + ) + } + + pub fn assistant_tool_use( + id: impl Into, + name: impl Into, + input: JsonObject, + ) -> Self { + Self::new( + Role::Assistant, + SamplingMessageContent::tool_use(id, name, input), + ) + } +} + +// Conversion from RawTextContent to SamplingMessageContent +impl From for SamplingMessageContent { + fn from(text: RawTextContent) -> Self { + SamplingMessageContent::Text(text) + } +} + +// Conversion from String to SamplingMessageContent (as text) +impl From for SamplingMessageContent { + fn from(text: String) -> Self { + SamplingMessageContent::text(text) + } +} + +impl From<&str> for SamplingMessageContent { + fn from(text: &str) -> Self { + SamplingMessageContent::text(text) + } } /// Specifies how much context should be included in sampling requests. @@ -1281,6 +1525,12 @@ pub struct CreateMessageRequestParams { /// Additional metadata for the request #[serde(skip_serializing_if = "Option::is_none")] pub metadata: Option, + /// Tools available for the model to call (SEP-1577) + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + /// Tool selection behavior (SEP-1577) + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, } impl RequestParamsMeta for CreateMessageRequestParams { @@ -1926,6 +2176,7 @@ pub type CallToolRequestParam = CallToolRequestParams; /// Request to call a specific tool pub type CallToolRequest = Request; +/// Result of sampling/createMessage (SEP-1577). /// The result of a sampling/createMessage request containing the generated response. /// /// This structure contains the generated message along with metadata about @@ -1948,6 +2199,7 @@ impl CreateMessageResult { pub const STOP_REASON_END_TURN: &str = "endTurn"; pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence"; pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens"; + pub const STOP_REASON_TOOL_USE: &str = "toolUse"; } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] @@ -2476,7 +2728,9 @@ mod tests { .. }) => { assert_eq!(capabilities.roots.unwrap().list_changed, Some(true)); - assert_eq!(capabilities.sampling.unwrap().len(), 0); + let sampling = capabilities.sampling.unwrap(); + assert_eq!(sampling.tools, None); + assert_eq!(sampling.context, None); assert_eq!(client_info.name, "ExampleClient"); assert_eq!(client_info.version, "1.0.0"); } diff --git a/crates/rmcp/src/model/capabilities.rs b/crates/rmcp/src/model/capabilities.rs index 80353216..cf4924ba 100644 --- a/crates/rmcp/src/model/capabilities.rs +++ b/crates/rmcp/src/model/capabilities.rs @@ -172,6 +172,19 @@ pub struct ElicitationCapability { pub schema_validation: Option, } +/// Sampling capability with optional sub-capabilities (SEP-1577). +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct SamplingCapability { + /// Support for `tools` and `toolChoice` parameters + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, + /// Support for `includeContext` (soft-deprecated) + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, +} + /// /// # Builder /// ```rust @@ -189,8 +202,9 @@ pub struct ClientCapabilities { pub experimental: Option, #[serde(skip_serializing_if = "Option::is_none")] pub roots: Option, + /// Capability for LLM sampling requests (SEP-1577) #[serde(skip_serializing_if = "Option::is_none")] - pub sampling: Option, + pub sampling: Option, /// Capability to handle elicitation requests from servers for interactive user input #[serde(skip_serializing_if = "Option::is_none")] pub elicitation: Option, @@ -392,7 +406,7 @@ builder! { ClientCapabilities{ experimental: ExperimentalCapabilities, roots: RootsCapabilities, - sampling: JsonObject, + sampling: SamplingCapability, elicitation: ElicitationCapability, tasks: TasksCapability, } @@ -409,6 +423,26 @@ impl } } +impl + ClientCapabilitiesBuilder> +{ + /// Enable tool calling in sampling requests + pub fn enable_sampling_tools(mut self) -> Self { + if let Some(c) = self.sampling.as_mut() { + c.tools = Some(JsonObject::default()); + } + self + } + + /// Enable context inclusion in sampling (soft-deprecated) + pub fn enable_sampling_context(mut self) -> Self { + if let Some(c) = self.sampling.as_mut() { + c.context = Some(JsonObject::default()); + } + self + } +} + #[cfg(feature = "elicitation")] impl ClientCapabilitiesBuilder> diff --git a/crates/rmcp/src/model/content.rs b/crates/rmcp/src/model/content.rs index fb82053d..297bc751 100644 --- a/crates/rmcp/src/model/content.rs +++ b/crates/rmcp/src/model/content.rs @@ -59,6 +59,137 @@ pub struct RawAudioContent { pub type AudioContent = Annotated; +/// Tool call request from assistant (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ToolUseContent { + /// Unique identifier for this tool call + pub id: String, + /// Name of the tool to call + pub name: String, + /// Input arguments for the tool + pub input: super::JsonObject, + /// Optional metadata (preserved for caching) + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +/// Tool execution result in user message (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ToolResultContent { + /// Optional metadata + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, + /// ID of the corresponding tool use + pub tool_use_id: String, + /// Content blocks returned by the tool + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub content: Vec, + /// Optional structured result + #[serde(skip_serializing_if = "Option::is_none")] + pub structured_content: Option, + /// Whether tool execution failed + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option, +} + +impl ToolUseContent { + pub fn new(id: impl Into, name: impl Into, input: super::JsonObject) -> Self { + Self { + id: id.into(), + name: name.into(), + input, + meta: None, + } + } +} + +impl ToolResultContent { + pub fn new(tool_use_id: impl Into, content: Vec) -> Self { + Self { + meta: None, + tool_use_id: tool_use_id.into(), + content, + structured_content: None, + is_error: None, + } + } + + pub fn error(tool_use_id: impl Into, content: Vec) -> Self { + Self { + meta: None, + tool_use_id: tool_use_id.into(), + content, + structured_content: None, + is_error: Some(true), + } + } +} + +/// Assistant message content types (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum AssistantMessageContent { + Text(RawTextContent), + Image(RawImageContent), + Audio(RawAudioContent), + ToolUse(ToolUseContent), +} + +/// User message content types (SEP-1577). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum UserMessageContent { + Text(RawTextContent), + Image(RawImageContent), + Audio(RawAudioContent), + ToolResult(ToolResultContent), +} + +impl AssistantMessageContent { + /// Create a text content + pub fn text(text: impl Into) -> Self { + Self::Text(RawTextContent { + text: text.into(), + meta: None, + }) + } + + /// Create a tool use content + pub fn tool_use( + id: impl Into, + name: impl Into, + input: super::JsonObject, + ) -> Self { + Self::ToolUse(ToolUseContent::new(id, name, input)) + } +} + +impl UserMessageContent { + /// Create a text content + pub fn text(text: impl Into) -> Self { + Self::Text(RawTextContent { + text: text.into(), + meta: None, + }) + } + + /// Create a tool result content + pub fn tool_result(tool_use_id: impl Into, content: Vec) -> Self { + Self::ToolResult(ToolResultContent::new(tool_use_id, content)) + } + + /// Create an error tool result content + pub fn tool_result_error(tool_use_id: impl Into, content: Vec) -> Self { + Self::ToolResult(ToolResultContent::error(tool_use_id, content)) + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] diff --git a/crates/rmcp/tests/common/handlers.rs b/crates/rmcp/tests/common/handlers.rs index 373d278c..654413fa 100644 --- a/crates/rmcp/tests/common/handlers.rs +++ b/crates/rmcp/tests/common/handlers.rs @@ -72,10 +72,7 @@ impl ClientHandler for TestClientHandler { }; Ok(CreateMessageResult { - message: SamplingMessage { - role: Role::Assistant, - content: Content::text(response.to_string()), - }, + message: SamplingMessage::assistant_text(response.to_string()), model: "test-model".to_string(), stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), }) diff --git a/crates/rmcp/tests/test_message_protocol.rs b/crates/rmcp/tests/test_message_protocol.rs index b2851b5a..7ec3258c 100644 --- a/crates/rmcp/tests/test_message_protocol.rs +++ b/crates/rmcp/tests/test_message_protocol.rs @@ -13,14 +13,8 @@ use tokio_util::sync::CancellationToken; #[tokio::test] async fn test_message_roles() { let messages = vec![ - SamplingMessage { - role: Role::User, - content: Content::text("user message"), - }, - SamplingMessage { - role: Role::Assistant, - content: Content::text("assistant message"), - }, + SamplingMessage::user_text("user message"), + SamplingMessage::assistant_text("assistant message"), ]; // Verify all roles can be serialized/deserialized correctly @@ -50,10 +44,7 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], + messages: vec![SamplingMessage::user_text("test message")], include_context: Some(ContextInclusion::ThisServer), model_preferences: None, system_prompt: None, @@ -61,6 +52,8 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -79,7 +72,15 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( text.contains("test context"), "Response should include context for ThisServer" @@ -94,10 +95,7 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], + messages: vec![SamplingMessage::user_text("test message")], include_context: Some(ContextInclusion::AllServers), model_preferences: None, system_prompt: None, @@ -105,6 +103,8 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -123,7 +123,15 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( text.contains("test context"), "Response should include context for AllServers" @@ -138,10 +146,7 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], + messages: vec![SamplingMessage::user_text("test message")], include_context: Some(ContextInclusion::None), model_preferences: None, system_prompt: None, @@ -149,6 +154,8 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -167,7 +174,15 @@ async fn test_context_inclusion_integration() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( !text.contains("test context"), "Response should not include context for None" @@ -202,10 +217,7 @@ async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], + messages: vec![SamplingMessage::user_text("test message")], include_context: Some(ContextInclusion::ThisServer), model_preferences: None, system_prompt: None, @@ -213,6 +225,8 @@ async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -231,7 +245,15 @@ async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( !text.contains("test context"), "Context should be ignored when client chooses not to honor requests" @@ -266,14 +288,8 @@ async fn test_message_sequence_integration() -> anyhow::Result<()> { meta: None, task: None, messages: vec![ - SamplingMessage { - role: Role::User, - content: Content::text("first message"), - }, - SamplingMessage { - role: Role::Assistant, - content: Content::text("second message"), - }, + SamplingMessage::user_text("first message"), + SamplingMessage::assistant_text("second message"), ], include_context: Some(ContextInclusion::ThisServer), model_preferences: None, @@ -282,6 +298,8 @@ async fn test_message_sequence_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -300,7 +318,15 @@ async fn test_message_sequence_integration() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( text.contains("test context"), "Response should include context when ThisServer is specified" @@ -339,18 +365,9 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { meta: None, task: None, messages: vec![ - SamplingMessage { - role: Role::User, - content: Content::text("first user message"), - }, - SamplingMessage { - role: Role::Assistant, - content: Content::text("first assistant response"), - }, - SamplingMessage { - role: Role::User, - content: Content::text("second user message"), - }, + SamplingMessage::user_text("first user message"), + SamplingMessage::assistant_text("first assistant response"), + SamplingMessage::user_text("second user message"), ], include_context: None, model_preferences: None, @@ -359,6 +376,8 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -384,10 +403,7 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::Assistant, - content: Content::text("assistant message"), - }], + messages: vec![SamplingMessage::assistant_text("assistant message")], include_context: None, model_preferences: None, system_prompt: None, @@ -395,6 +411,8 @@ async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -439,10 +457,7 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], + messages: vec![SamplingMessage::user_text("test message")], include_context: Some(ContextInclusion::ThisServer), model_preferences: None, system_prompt: None, @@ -450,6 +465,8 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -468,7 +485,15 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( text.contains("test context"), "ThisServer context request should be honored" @@ -481,10 +506,7 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], + messages: vec![SamplingMessage::user_text("test message")], include_context: Some(ContextInclusion::AllServers), model_preferences: None, system_prompt: None, @@ -492,6 +514,8 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -510,7 +534,15 @@ async fn test_selective_context_handling_integration() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( !text.contains("test context"), "AllServers context request should be ignored" @@ -540,10 +572,7 @@ async fn test_context_inclusion() -> anyhow::Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test"), - }], + messages: vec![SamplingMessage::user_text("test")], include_context: Some(ContextInclusion::ThisServer), model_preferences: None, system_prompt: None, @@ -551,6 +580,8 @@ async fn test_context_inclusion() -> anyhow::Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -569,7 +600,15 @@ async fn test_context_inclusion() -> anyhow::Result<()> { .await?; if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); + let text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!(text.contains("test context")); } diff --git a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json index e0d90fa8..b3fe82b1 100644 --- a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json +++ b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json @@ -307,11 +307,15 @@ ] }, "sampling": { - "type": [ - "object", - "null" - ], - "additionalProperties": true + "description": "Capability for LLM sampling requests (SEP-1577)", + "anyOf": [ + { + "$ref": "#/definitions/SamplingCapability" + }, + { + "type": "null" + } + ] }, "tasks": { "anyOf": [ @@ -420,14 +424,21 @@ ] }, "CreateMessageResult": { - "description": "The result of a sampling/createMessage request containing the generated response.\n\nThis structure contains the generated message along with metadata about\nhow the generation was performed and why it stopped.", + "description": "Result of sampling/createMessage (SEP-1577).\nThe result of a sampling/createMessage request containing the generated response.\n\nThis structure contains the generated message along with metadata about\nhow the generation was performed and why it stopped.", "type": "object", "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, "content": { - "description": "The actual content of the message (text, image, etc.)", + "description": "The actual content of the message (text, image, audio, tool use, or tool result)", "allOf": [ { - "$ref": "#/definitions/Annotated" + "$ref": "#/definitions/SamplingContent" } ] }, @@ -1730,6 +1741,134 @@ "format": "const", "const": "notifications/roots/list_changed" }, + "SamplingCapability": { + "description": "Sampling capability with optional sub-capabilities (SEP-1577).", + "type": "object", + "properties": { + "context": { + "description": "Support for `includeContext` (soft-deprecated)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "tools": { + "description": "Support for `tools` and `toolChoice` parameters", + "type": [ + "object", + "null" + ], + "additionalProperties": true + } + } + }, + "SamplingContent": { + "description": "Single or array content wrapper (SEP-1577).", + "anyOf": [ + { + "$ref": "#/definitions/SamplingMessageContent" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/SamplingMessageContent" + } + } + ] + }, + "SamplingMessageContent": { + "description": "Content types for sampling messages (SEP-1577).", + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "Assistant only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_use" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolUseContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "User only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_result" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolResultContent" + } + ], + "required": [ + "type" + ] + } + ] + }, "SamplingTaskCapability": { "type": "object", "properties": { @@ -1864,6 +2003,81 @@ } } }, + "ToolResultContent": { + "description": "Tool execution result in user message (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "content": { + "description": "Content blocks returned by the tool", + "type": "array", + "items": { + "$ref": "#/definitions/Annotated" + } + }, + "isError": { + "description": "Whether tool execution failed", + "type": [ + "boolean", + "null" + ] + }, + "structuredContent": { + "description": "Optional structured result", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "toolUseId": { + "description": "ID of the corresponding tool use", + "type": "string" + } + }, + "required": [ + "toolUseId" + ] + }, + "ToolUseContent": { + "description": "Tool call request from assistant (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata (preserved for caching)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "id": { + "description": "Unique identifier for this tool call", + "type": "string" + }, + "input": { + "description": "Input arguments for the tool", + "type": "object", + "additionalProperties": true + }, + "name": { + "description": "Name of the tool to call", + "type": "string" + } + }, + "required": [ + "id", + "name", + "input" + ] + }, "ToolsTaskCapability": { "type": "object", "properties": { diff --git a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json index e0d90fa8..b3fe82b1 100644 --- a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json +++ b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json @@ -307,11 +307,15 @@ ] }, "sampling": { - "type": [ - "object", - "null" - ], - "additionalProperties": true + "description": "Capability for LLM sampling requests (SEP-1577)", + "anyOf": [ + { + "$ref": "#/definitions/SamplingCapability" + }, + { + "type": "null" + } + ] }, "tasks": { "anyOf": [ @@ -420,14 +424,21 @@ ] }, "CreateMessageResult": { - "description": "The result of a sampling/createMessage request containing the generated response.\n\nThis structure contains the generated message along with metadata about\nhow the generation was performed and why it stopped.", + "description": "Result of sampling/createMessage (SEP-1577).\nThe result of a sampling/createMessage request containing the generated response.\n\nThis structure contains the generated message along with metadata about\nhow the generation was performed and why it stopped.", "type": "object", "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, "content": { - "description": "The actual content of the message (text, image, etc.)", + "description": "The actual content of the message (text, image, audio, tool use, or tool result)", "allOf": [ { - "$ref": "#/definitions/Annotated" + "$ref": "#/definitions/SamplingContent" } ] }, @@ -1730,6 +1741,134 @@ "format": "const", "const": "notifications/roots/list_changed" }, + "SamplingCapability": { + "description": "Sampling capability with optional sub-capabilities (SEP-1577).", + "type": "object", + "properties": { + "context": { + "description": "Support for `includeContext` (soft-deprecated)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "tools": { + "description": "Support for `tools` and `toolChoice` parameters", + "type": [ + "object", + "null" + ], + "additionalProperties": true + } + } + }, + "SamplingContent": { + "description": "Single or array content wrapper (SEP-1577).", + "anyOf": [ + { + "$ref": "#/definitions/SamplingMessageContent" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/SamplingMessageContent" + } + } + ] + }, + "SamplingMessageContent": { + "description": "Content types for sampling messages (SEP-1577).", + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "Assistant only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_use" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolUseContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "User only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_result" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolResultContent" + } + ], + "required": [ + "type" + ] + } + ] + }, "SamplingTaskCapability": { "type": "object", "properties": { @@ -1864,6 +2003,81 @@ } } }, + "ToolResultContent": { + "description": "Tool execution result in user message (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "content": { + "description": "Content blocks returned by the tool", + "type": "array", + "items": { + "$ref": "#/definitions/Annotated" + } + }, + "isError": { + "description": "Whether tool execution failed", + "type": [ + "boolean", + "null" + ] + }, + "structuredContent": { + "description": "Optional structured result", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "toolUseId": { + "description": "ID of the corresponding tool use", + "type": "string" + } + }, + "required": [ + "toolUseId" + ] + }, + "ToolUseContent": { + "description": "Tool call request from assistant (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata (preserved for caching)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "id": { + "description": "Unique identifier for this tool call", + "type": "string" + }, + "input": { + "description": "Input arguments for the tool", + "type": "object", + "additionalProperties": true + }, + "name": { + "description": "Name of the tool to call", + "type": "string" + } + }, + "required": [ + "id", + "name", + "input" + ] + }, "ToolsTaskCapability": { "type": "object", "properties": { diff --git a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json index b848d4ee..b995910a 100644 --- a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json +++ b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json @@ -641,6 +641,27 @@ "null" ], "format": "float" + }, + "toolChoice": { + "description": "Tool selection behavior (SEP-1577)", + "anyOf": [ + { + "$ref": "#/definitions/ToolChoice" + }, + { + "type": "null" + } + ] + }, + "tools": { + "description": "Tools available for the model to call (SEP-1577)", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Tool" + } } }, "required": [ @@ -2308,15 +2329,36 @@ } ] }, + "SamplingContent": { + "description": "Single or array content wrapper (SEP-1577).", + "anyOf": [ + { + "$ref": "#/definitions/SamplingMessageContent" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/SamplingMessageContent" + } + } + ] + }, "SamplingMessage": { "description": "A message in a sampling conversation, containing a role and content.\n\nThis represents a single message in a conversation flow, used primarily\nin LLM sampling requests where the conversation history is important\nfor generating appropriate responses.", "type": "object", "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, "content": { - "description": "The actual content of the message (text, image, etc.)", + "description": "The actual content of the message (text, image, audio, tool use, or tool result)", "allOf": [ { - "$ref": "#/definitions/Annotated" + "$ref": "#/definitions/SamplingContent" } ] }, @@ -2334,6 +2376,98 @@ "content" ] }, + "SamplingMessageContent": { + "description": "Content types for sampling messages (SEP-1577).", + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "Assistant only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_use" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolUseContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "User only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_result" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolResultContent" + } + ], + "required": [ + "type" + ] + } + ] + }, "SamplingTaskCapability": { "type": "object", "properties": { @@ -2965,11 +3099,122 @@ } } }, + "ToolChoice": { + "description": "Tool choice configuration (SEP-1577).", + "type": "object", + "properties": { + "mode": { + "anyOf": [ + { + "$ref": "#/definitions/ToolChoiceMode" + }, + { + "type": "null" + } + ] + } + } + }, + "ToolChoiceMode": { + "description": "Tool selection mode (SEP-1577).", + "oneOf": [ + { + "description": "Model decides whether to use tools", + "type": "string", + "const": "auto" + }, + { + "description": "Model must use at least one tool", + "type": "string", + "const": "required" + }, + { + "description": "Model must not use tools", + "type": "string", + "const": "none" + } + ] + }, "ToolListChangedNotificationMethod": { "type": "string", "format": "const", "const": "notifications/tools/list_changed" }, + "ToolResultContent": { + "description": "Tool execution result in user message (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "content": { + "description": "Content blocks returned by the tool", + "type": "array", + "items": { + "$ref": "#/definitions/Annotated" + } + }, + "isError": { + "description": "Whether tool execution failed", + "type": [ + "boolean", + "null" + ] + }, + "structuredContent": { + "description": "Optional structured result", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "toolUseId": { + "description": "ID of the corresponding tool use", + "type": "string" + } + }, + "required": [ + "toolUseId" + ] + }, + "ToolUseContent": { + "description": "Tool call request from assistant (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata (preserved for caching)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "id": { + "description": "Unique identifier for this tool call", + "type": "string" + }, + "input": { + "description": "Input arguments for the tool", + "type": "object", + "additionalProperties": true + }, + "name": { + "description": "Name of the tool to call", + "type": "string" + } + }, + "required": [ + "id", + "name", + "input" + ] + }, "ToolsCapability": { "type": "object", "properties": { diff --git a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json index b848d4ee..b995910a 100644 --- a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json +++ b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json @@ -641,6 +641,27 @@ "null" ], "format": "float" + }, + "toolChoice": { + "description": "Tool selection behavior (SEP-1577)", + "anyOf": [ + { + "$ref": "#/definitions/ToolChoice" + }, + { + "type": "null" + } + ] + }, + "tools": { + "description": "Tools available for the model to call (SEP-1577)", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/Tool" + } } }, "required": [ @@ -2308,15 +2329,36 @@ } ] }, + "SamplingContent": { + "description": "Single or array content wrapper (SEP-1577).", + "anyOf": [ + { + "$ref": "#/definitions/SamplingMessageContent" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/SamplingMessageContent" + } + } + ] + }, "SamplingMessage": { "description": "A message in a sampling conversation, containing a role and content.\n\nThis represents a single message in a conversation flow, used primarily\nin LLM sampling requests where the conversation history is important\nfor generating appropriate responses.", "type": "object", "properties": { + "_meta": { + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, "content": { - "description": "The actual content of the message (text, image, etc.)", + "description": "The actual content of the message (text, image, audio, tool use, or tool result)", "allOf": [ { - "$ref": "#/definitions/Annotated" + "$ref": "#/definitions/SamplingContent" } ] }, @@ -2334,6 +2376,98 @@ "content" ] }, + "SamplingMessageContent": { + "description": "Content types for sampling messages (SEP-1577).", + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawTextContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "image" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawImageContent" + } + ], + "required": [ + "type" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "audio" + } + }, + "allOf": [ + { + "$ref": "#/definitions/RawAudioContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "Assistant only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_use" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolUseContent" + } + ], + "required": [ + "type" + ] + }, + { + "description": "User only", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "tool_result" + } + }, + "allOf": [ + { + "$ref": "#/definitions/ToolResultContent" + } + ], + "required": [ + "type" + ] + } + ] + }, "SamplingTaskCapability": { "type": "object", "properties": { @@ -2965,11 +3099,122 @@ } } }, + "ToolChoice": { + "description": "Tool choice configuration (SEP-1577).", + "type": "object", + "properties": { + "mode": { + "anyOf": [ + { + "$ref": "#/definitions/ToolChoiceMode" + }, + { + "type": "null" + } + ] + } + } + }, + "ToolChoiceMode": { + "description": "Tool selection mode (SEP-1577).", + "oneOf": [ + { + "description": "Model decides whether to use tools", + "type": "string", + "const": "auto" + }, + { + "description": "Model must use at least one tool", + "type": "string", + "const": "required" + }, + { + "description": "Model must not use tools", + "type": "string", + "const": "none" + } + ] + }, "ToolListChangedNotificationMethod": { "type": "string", "format": "const", "const": "notifications/tools/list_changed" }, + "ToolResultContent": { + "description": "Tool execution result in user message (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "content": { + "description": "Content blocks returned by the tool", + "type": "array", + "items": { + "$ref": "#/definitions/Annotated" + } + }, + "isError": { + "description": "Whether tool execution failed", + "type": [ + "boolean", + "null" + ] + }, + "structuredContent": { + "description": "Optional structured result", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "toolUseId": { + "description": "ID of the corresponding tool use", + "type": "string" + } + }, + "required": [ + "toolUseId" + ] + }, + "ToolUseContent": { + "description": "Tool call request from assistant (SEP-1577).", + "type": "object", + "properties": { + "_meta": { + "description": "Optional metadata (preserved for caching)", + "type": [ + "object", + "null" + ], + "additionalProperties": true + }, + "id": { + "description": "Unique identifier for this tool call", + "type": "string" + }, + "input": { + "description": "Input arguments for the tool", + "type": "object", + "additionalProperties": true + }, + "name": { + "description": "Name of the tool to call", + "type": "string" + } + }, + "required": [ + "id", + "name", + "input" + ] + }, "ToolsCapability": { "type": "object", "properties": { diff --git a/crates/rmcp/tests/test_sampling.rs b/crates/rmcp/tests/test_sampling.rs index 83a4325c..f9295883 100644 --- a/crates/rmcp/tests/test_sampling.rs +++ b/crates/rmcp/tests/test_sampling.rs @@ -13,13 +13,8 @@ use tokio_util::sync::CancellationToken; #[tokio::test] async fn test_basic_sampling_message_creation() -> Result<()> { - // Test basic sampling message structure - let message = SamplingMessage { - role: Role::User, - content: Content::text("What is the capital of France?"), - }; + let message = SamplingMessage::user_text("What is the capital of France?"); - // Verify serialization/deserialization let json = serde_json::to_string(&message)?; let deserialized: SamplingMessage = serde_json::from_str(&json)?; assert_eq!(message, deserialized); @@ -30,14 +25,10 @@ async fn test_basic_sampling_message_creation() -> Result<()> { #[tokio::test] async fn test_sampling_request_params() -> Result<()> { - // Test sampling request parameters structure let params = CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("Hello, world!"), - }], + messages: vec![SamplingMessage::user_text("Hello, world!")], model_preferences: Some(ModelPreferences { hints: Some(vec![ModelHint { name: Some("claude".to_string()), @@ -52,14 +43,14 @@ async fn test_sampling_request_params() -> Result<()> { stop_sequences: Some(vec!["STOP".to_string()]), include_context: Some(ContextInclusion::None), metadata: Some(serde_json::json!({"test": "value"})), + tools: None, + tool_choice: None, }; - // Verify serialization/deserialization let json = serde_json::to_string(¶ms)?; let deserialized: CreateMessageRequestParams = serde_json::from_str(&json)?; assert_eq!(params, deserialized); - // Verify specific fields assert_eq!(params.messages.len(), 1); assert_eq!(params.max_tokens, 100); assert_eq!(params.temperature, Some(0.7)); @@ -69,22 +60,16 @@ async fn test_sampling_request_params() -> Result<()> { #[tokio::test] async fn test_sampling_result_structure() -> Result<()> { - // Test sampling result structure let result = CreateMessageResult { - message: SamplingMessage { - role: Role::Assistant, - content: Content::text("The capital of France is Paris."), - }, + message: SamplingMessage::assistant_text("The capital of France is Paris."), model: "test-model".to_string(), stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), }; - // Verify serialization/deserialization let json = serde_json::to_string(&result)?; let deserialized: CreateMessageResult = serde_json::from_str(&json)?; assert_eq!(result, deserialized); - // Verify specific fields assert_eq!(result.message.role, Role::Assistant); assert_eq!(result.model, "test-model"); assert_eq!( @@ -97,7 +82,6 @@ async fn test_sampling_result_structure() -> Result<()> { #[tokio::test] async fn test_sampling_context_inclusion_enum() -> Result<()> { - // Test context inclusion enum values let test_cases = vec![ (ContextInclusion::None, "none"), (ContextInclusion::ThisServer, "thisServer"), @@ -139,10 +123,7 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("What is the capital of France?"), - }], + messages: vec![SamplingMessage::user_text("What is the capital of France?")], include_context: Some(ContextInclusion::ThisServer), model_preferences: Some(ModelPreferences { hints: Some(vec![ModelHint { @@ -157,6 +138,8 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { max_tokens: 100, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -183,7 +166,15 @@ async fn test_sampling_integration_with_test_handlers() -> Result<()> { Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) ); - let response_text = result.message.content.as_text().unwrap().text.as_str(); + let response_text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( response_text.contains("test context"), "Response should include context for ThisServer inclusion" @@ -221,10 +212,7 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("Hello"), - }], + messages: vec![SamplingMessage::user_text("Hello")], include_context: Some(ContextInclusion::None), model_preferences: None, system_prompt: None, @@ -232,6 +220,8 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { max_tokens: 50, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -254,7 +244,15 @@ async fn test_sampling_no_context_inclusion() -> Result<()> { assert_eq!(result.message.role, Role::Assistant); assert_eq!(result.model, "test-model"); - let response_text = result.message.content.as_text().unwrap().text.as_str(); + let response_text = result + .message + .content + .first() + .unwrap() + .as_text() + .unwrap() + .text + .as_str(); assert!( !response_text.contains("test context"), "Response should not include context for None inclusion" @@ -292,10 +290,9 @@ async fn test_sampling_error_invalid_message_sequence() -> Result<()> { params: CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::Assistant, - content: Content::text("I'm an assistant message without a user message"), - }], + messages: vec![SamplingMessage::assistant_text( + "I'm an assistant message without a user message", + )], include_context: Some(ContextInclusion::None), model_preferences: None, system_prompt: None, @@ -303,6 +300,8 @@ async fn test_sampling_error_invalid_message_sequence() -> Result<()> { max_tokens: 50, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }, extensions: Default::default(), }); @@ -327,3 +326,212 @@ async fn test_sampling_error_invalid_message_sequence() -> Result<()> { server_handle.await??; Ok(()) } + +#[tokio::test] +async fn test_tool_choice_serialization() -> Result<()> { + let auto = ToolChoice::auto(); + let json = serde_json::to_string(&auto)?; + assert!(json.contains("auto")); + let deserialized: ToolChoice = serde_json::from_str(&json)?; + assert_eq!(auto, deserialized); + + let required = ToolChoice::required(); + let json = serde_json::to_string(&required)?; + assert!(json.contains("required")); + let deserialized: ToolChoice = serde_json::from_str(&json)?; + assert_eq!(required, deserialized); + + let none = ToolChoice::none(); + let json = serde_json::to_string(&none)?; + assert!(json.contains("none")); + let deserialized: ToolChoice = serde_json::from_str(&json)?; + assert_eq!(none, deserialized); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_with_tools() -> Result<()> { + use std::sync::Arc; + + let tool = Tool::new( + "get_weather", + "Get the current weather for a location", + Arc::new( + serde_json::json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + }) + .as_object() + .unwrap() + .clone(), + ), + ); + + let params = CreateMessageRequestParams { + meta: None, + task: None, + messages: vec![SamplingMessage::user_text( + "What's the weather in San Francisco?", + )], + model_preferences: None, + system_prompt: None, + include_context: None, + temperature: None, + max_tokens: 100, + stop_sequences: None, + metadata: None, + tools: Some(vec![tool]), + tool_choice: Some(ToolChoice::auto()), + }; + + let json = serde_json::to_string(¶ms)?; + let deserialized: CreateMessageRequestParams = serde_json::from_str(&json)?; + + assert!(deserialized.tools.is_some()); + assert_eq!(deserialized.tools.as_ref().unwrap().len(), 1); + assert_eq!(deserialized.tools.as_ref().unwrap()[0].name, "get_weather"); + assert!(deserialized.tool_choice.is_some()); + + Ok(()) +} + +#[tokio::test] +async fn test_tool_use_content_serialization() -> Result<()> { + let tool_use = ToolUseContent::new( + "call_123", + "get_weather", + serde_json::json!({ + "location": "San Francisco, CA" + }) + .as_object() + .unwrap() + .clone(), + ); + + let json = serde_json::to_string(&tool_use)?; + let deserialized: ToolUseContent = serde_json::from_str(&json)?; + assert_eq!(tool_use, deserialized); + assert_eq!(deserialized.id, "call_123"); + assert_eq!(deserialized.name, "get_weather"); + + Ok(()) +} + +#[tokio::test] +async fn test_tool_result_content_serialization() -> Result<()> { + let tool_result = ToolResultContent::new( + "call_123", + vec![Content::text( + "The weather in San Francisco is 72°F and sunny.", + )], + ); + + let json = serde_json::to_string(&tool_result)?; + let deserialized: ToolResultContent = serde_json::from_str(&json)?; + assert_eq!(tool_result, deserialized); + assert_eq!(deserialized.tool_use_id, "call_123"); + assert!(!deserialized.content.is_empty()); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_message_with_tool_use() -> Result<()> { + let message = SamplingMessage::assistant_tool_use( + "call_123", + "get_weather", + serde_json::json!({ + "location": "San Francisco, CA" + }) + .as_object() + .unwrap() + .clone(), + ); + + let json = serde_json::to_string(&message)?; + let deserialized: SamplingMessage = serde_json::from_str(&json)?; + assert_eq!(message, deserialized); + assert_eq!(deserialized.role, Role::Assistant); + + let tool_use = deserialized.content.first().unwrap().as_tool_use().unwrap(); + assert_eq!(tool_use.name, "get_weather"); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_message_with_tool_result() -> Result<()> { + let message = + SamplingMessage::user_tool_result("call_123", vec![Content::text("72°F and sunny")]); + + let json = serde_json::to_string(&message)?; + let deserialized: SamplingMessage = serde_json::from_str(&json)?; + assert_eq!(message, deserialized); + assert_eq!(deserialized.role, Role::User); + + let tool_result = deserialized + .content + .first() + .unwrap() + .as_tool_result() + .unwrap(); + assert_eq!(tool_result.tool_use_id, "call_123"); + + Ok(()) +} + +#[tokio::test] +async fn test_create_message_result_tool_use_stop_reason() -> Result<()> { + let result = CreateMessageResult { + message: SamplingMessage::assistant_tool_use( + "call_123", + "get_weather", + serde_json::json!({ + "location": "San Francisco" + }) + .as_object() + .unwrap() + .clone(), + ), + model: "test-model".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_TOOL_USE.to_string()), + }; + + let json = serde_json::to_string(&result)?; + let deserialized: CreateMessageResult = serde_json::from_str(&json)?; + assert_eq!(result, deserialized); + assert_eq!(deserialized.stop_reason, Some("toolUse".to_string())); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_capability() -> Result<()> { + let cap = SamplingCapability { + tools: Some(JsonObject::default()), + context: None, + }; + + let json = serde_json::to_string(&cap)?; + let deserialized: SamplingCapability = serde_json::from_str(&json)?; + assert_eq!(cap, deserialized); + assert!(deserialized.tools.is_some()); + assert!(deserialized.context.is_none()); + + let client_cap = ClientCapabilities::builder() + .enable_sampling() + .enable_sampling_tools() + .build(); + + assert!(client_cap.sampling.is_some()); + assert!(client_cap.sampling.as_ref().unwrap().tools.is_some()); + + Ok(()) +} diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs index cdefad58..e2a7a6d5 100644 --- a/examples/clients/src/sampling_stdio.rs +++ b/examples/clients/src/sampling_stdio.rs @@ -41,10 +41,7 @@ impl ClientHandler for SamplingDemoClient { self.mock_llm_response(¶ms.messages, params.system_prompt.as_deref()); Ok(CreateMessageResult { - message: SamplingMessage { - role: Role::Assistant, - content: Content::text(response_text), - }, + message: SamplingMessage::assistant_text(response_text), model: "mock_llm".to_string(), stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), }) diff --git a/examples/servers/src/sampling_stdio.rs b/examples/servers/src/sampling_stdio.rs index 29198bb8..1b57b353 100644 --- a/examples/servers/src/sampling_stdio.rs +++ b/examples/servers/src/sampling_stdio.rs @@ -51,10 +51,7 @@ impl ServerHandler for SamplingDemoServer { .create_message(CreateMessageRequestParams { meta: None, task: None, - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text(question), - }], + messages: vec![SamplingMessage::user_text(question)], model_preferences: Some(ModelPreferences { hints: Some(vec![ModelHint { name: Some("claude".to_string()), @@ -69,6 +66,8 @@ impl ServerHandler for SamplingDemoServer { max_tokens: 150, stop_sequences: None, metadata: None, + tools: None, + tool_choice: None, }) .await .map_err(|e| { @@ -85,7 +84,8 @@ impl ServerHandler for SamplingDemoServer { response .message .content - .as_text() + .first() + .and_then(|c| c.as_text()) .map(|t| &t.text) .unwrap_or(&"No text response".to_string()) ))]))