Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 257 additions & 3 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,152 @@ pub enum Role {
Assistant,
}

/// Tool selection mode (SEP-1577).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum ToolChoiceMode {
/// Model decides whether to use tools
Auto,
/// Model must use at least one tool
Required,
/// Model must not use tools
None,
}

impl Default for ToolChoiceMode {
fn default() -> Self {
Self::Auto
}
}

/// Tool choice configuration (SEP-1577).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ToolChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub mode: Option<ToolChoiceMode>,
}

impl ToolChoice {
pub fn auto() -> Self {
Self {
mode: Some(ToolChoiceMode::Auto),
}
}

pub fn required() -> Self {
Self {
mode: Some(ToolChoiceMode::Required),
}
}

pub fn none() -> Self {
Self {
mode: Some(ToolChoiceMode::None),
}
}
}

/// Single or array content wrapper (SEP-1577).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum SamplingContent<T> {
Single(T),
Multiple(Vec<T>),
}

impl<T> SamplingContent<T> {
/// Convert to a Vec regardless of whether it's single or multiple
pub fn into_vec(self) -> Vec<T> {
match self {
SamplingContent::Single(item) => vec![item],
SamplingContent::Multiple(items) => items,
}
}

/// Check if the content is empty
pub fn is_empty(&self) -> bool {
match self {
SamplingContent::Single(_) => false,
SamplingContent::Multiple(items) => items.is_empty(),
}
}

/// Get the number of content items
pub fn len(&self) -> usize {
match self {
SamplingContent::Single(_) => 1,
SamplingContent::Multiple(items) => items.len(),
}
}
}

impl<T> Default for SamplingContent<T> {
fn default() -> Self {
SamplingContent::Multiple(Vec::new())
}
}

impl<T> SamplingContent<T> {
/// Get the first item if present
pub fn first(&self) -> Option<&T> {
match self {
SamplingContent::Single(item) => Some(item),
SamplingContent::Multiple(items) => items.first(),
}
}

/// Iterate over all content items
pub fn iter(&self) -> impl Iterator<Item = &T> {
let items: Vec<&T> = match self {
SamplingContent::Single(item) => vec![item],
SamplingContent::Multiple(items) => items.iter().collect(),
};
items.into_iter()
}
}

impl SamplingMessageContent {
/// Get the text content if this is a Text variant
pub fn as_text(&self) -> Option<&RawTextContent> {
match self {
SamplingMessageContent::Text(text) => Some(text),
_ => None,
}
}

/// Get the tool use content if this is a ToolUse variant
pub fn as_tool_use(&self) -> Option<&ToolUseContent> {
match self {
SamplingMessageContent::ToolUse(tool_use) => Some(tool_use),
_ => None,
}
}

/// Get the tool result content if this is a ToolResult variant
pub fn as_tool_result(&self) -> Option<&ToolResultContent> {
match self {
SamplingMessageContent::ToolResult(tool_result) => Some(tool_result),
_ => None,
}
}
}

impl<T> From<T> for SamplingContent<T> {
fn from(item: T) -> Self {
SamplingContent::Single(item)
}
}

impl<T> From<Vec<T>> for SamplingContent<T> {
fn from(items: Vec<T>) -> Self {
SamplingContent::Multiple(items)
}
}

/// A message in a sampling conversation, containing a role and content.
///
/// This represents a single message in a conversation flow, used primarily
Expand All @@ -1219,8 +1365,106 @@ pub enum Role {
pub struct SamplingMessage {
/// The role of the message sender (User or Assistant)
pub role: Role,
/// The actual content of the message (text, image, etc.)
pub content: Content,
/// The actual content of the message (text, image, audio, tool use, or tool result)
pub content: SamplingContent<SamplingMessageContent>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<Meta>,
}

/// Content types for sampling messages (SEP-1577).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum SamplingMessageContent {
Text(RawTextContent),
Image(RawImageContent),
Audio(RawAudioContent),
/// Assistant only
ToolUse(ToolUseContent),
/// User only
ToolResult(ToolResultContent),
}

impl SamplingMessageContent {
/// Create a text content
pub fn text(text: impl Into<String>) -> Self {
Self::Text(RawTextContent {
text: text.into(),
meta: None,
})
}

pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: JsonObject) -> Self {
Self::ToolUse(ToolUseContent::new(id, name, input))
}

pub fn tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
Self::ToolResult(ToolResultContent::new(tool_use_id, content))
}
}

impl SamplingMessage {
pub fn new(role: Role, content: impl Into<SamplingMessageContent>) -> Self {
Self {
role,
content: SamplingContent::Single(content.into()),
meta: None,
}
}

pub fn new_multiple(role: Role, contents: Vec<SamplingMessageContent>) -> Self {
Self {
role,
content: SamplingContent::Multiple(contents),
meta: None,
}
}

pub fn user_text(text: impl Into<String>) -> Self {
Self::new(Role::User, SamplingMessageContent::text(text))
}

pub fn assistant_text(text: impl Into<String>) -> Self {
Self::new(Role::Assistant, SamplingMessageContent::text(text))
}

pub fn user_tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
Self::new(
Role::User,
SamplingMessageContent::tool_result(tool_use_id, content),
)
}

pub fn assistant_tool_use(
id: impl Into<String>,
name: impl Into<String>,
input: JsonObject,
) -> Self {
Self::new(
Role::Assistant,
SamplingMessageContent::tool_use(id, name, input),
)
}
}

// Conversion from RawTextContent to SamplingMessageContent
impl From<RawTextContent> for SamplingMessageContent {
fn from(text: RawTextContent) -> Self {
SamplingMessageContent::Text(text)
}
}

// Conversion from String to SamplingMessageContent (as text)
impl From<String> for SamplingMessageContent {
fn from(text: String) -> Self {
SamplingMessageContent::text(text)
}
}

impl From<&str> for SamplingMessageContent {
fn from(text: &str) -> Self {
SamplingMessageContent::text(text)
}
}

/// Specifies how much context should be included in sampling requests.
Expand Down Expand Up @@ -1281,6 +1525,12 @@ pub struct CreateMessageRequestParams {
/// Additional metadata for the request
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<Value>,
/// Tools available for the model to call (SEP-1577)
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Tool selection behavior (SEP-1577)
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}

impl RequestParamsMeta for CreateMessageRequestParams {
Expand Down Expand Up @@ -1926,6 +2176,7 @@ pub type CallToolRequestParam = CallToolRequestParams;
/// Request to call a specific tool
pub type CallToolRequest = Request<CallToolRequestMethod, CallToolRequestParams>;

/// Result of sampling/createMessage (SEP-1577).
/// The result of a sampling/createMessage request containing the generated response.
///
/// This structure contains the generated message along with metadata about
Expand All @@ -1948,6 +2199,7 @@ impl CreateMessageResult {
pub const STOP_REASON_END_TURN: &str = "endTurn";
pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence";
pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens";
pub const STOP_REASON_TOOL_USE: &str = "toolUse";
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
Expand Down Expand Up @@ -2476,7 +2728,9 @@ mod tests {
..
}) => {
assert_eq!(capabilities.roots.unwrap().list_changed, Some(true));
assert_eq!(capabilities.sampling.unwrap().len(), 0);
let sampling = capabilities.sampling.unwrap();
assert_eq!(sampling.tools, None);
assert_eq!(sampling.context, None);
assert_eq!(client_info.name, "ExampleClient");
assert_eq!(client_info.version, "1.0.0");
}
Expand Down
38 changes: 36 additions & 2 deletions crates/rmcp/src/model/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,19 @@ pub struct ElicitationCapability {
pub schema_validation: Option<bool>,
}

/// Sampling capability with optional sub-capabilities (SEP-1577).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct SamplingCapability {
/// Support for `tools` and `toolChoice` parameters
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<JsonObject>,
/// Support for `includeContext` (soft-deprecated)
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<JsonObject>,
}

///
/// # Builder
/// ```rust
Expand All @@ -189,8 +202,9 @@ pub struct ClientCapabilities {
pub experimental: Option<ExperimentalCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub roots: Option<RootsCapabilities>,
/// Capability for LLM sampling requests (SEP-1577)
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<JsonObject>,
pub sampling: Option<SamplingCapability>,
/// Capability to handle elicitation requests from servers for interactive user input
#[serde(skip_serializing_if = "Option::is_none")]
pub elicitation: Option<ElicitationCapability>,
Expand Down Expand Up @@ -392,7 +406,7 @@ builder! {
ClientCapabilities{
experimental: ExperimentalCapabilities,
roots: RootsCapabilities,
sampling: JsonObject,
sampling: SamplingCapability,
elicitation: ElicitationCapability,
tasks: TasksCapability,
}
Expand All @@ -409,6 +423,26 @@ impl<const E: bool, const S: bool, const EL: bool, const TASKS: bool>
}
}

impl<const E: bool, const R: bool, const EL: bool, const TASKS: bool>
ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<E, R, true, EL, TASKS>>
{
/// Enable tool calling in sampling requests
pub fn enable_sampling_tools(mut self) -> Self {
if let Some(c) = self.sampling.as_mut() {
c.tools = Some(JsonObject::default());
}
self
}

/// Enable context inclusion in sampling (soft-deprecated)
pub fn enable_sampling_context(mut self) -> Self {
if let Some(c) = self.sampling.as_mut() {
c.context = Some(JsonObject::default());
}
self
}
}

#[cfg(feature = "elicitation")]
impl<const E: bool, const R: bool, const S: bool, const TASKS: bool>
ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<E, R, S, true, TASKS>>
Expand Down
Loading