Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

152 changes: 103 additions & 49 deletions crates/coglet/src/input_validation.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Input validation against the OpenAPI schema.
//!
//! Validates prediction inputs before dispatching to the Python worker,
//! catching missing required fields and unknown fields early with clear
//! error messages (matching the format users expect from pydantic).
//! Validates prediction inputs before dispatching to the Python worker.
//! Strips unknown fields silently and catches missing required fields
//! with clear error messages (matching the format users expect from pydantic).

use std::collections::HashSet;

Expand Down Expand Up @@ -31,8 +31,9 @@ pub struct InputValidator {
impl InputValidator {
/// Build a validator from a full OpenAPI schema document.
///
/// Extracts `components.schemas.Input`, injects `additionalProperties: false`
/// (for pydantic parity), and compiles a JSON Schema validator.
/// Extracts `components.schemas.Input` and compiles a JSON Schema validator.
/// Unknown input fields should be stripped via `strip_unknown()` before
/// calling `validate()`.
///
/// Returns None if the schema doesn't contain an Input component.
pub fn from_openapi_schema(schema: &Value) -> Option<Self> {
Expand Down Expand Up @@ -62,11 +63,7 @@ impl InputValidator {
})
.unwrap_or_default();

// Clone and inject additionalProperties: false for pydantic parity
let mut resolved = input_schema.clone();
if let Some(obj) = resolved.as_object_mut() {
obj.insert("additionalProperties".to_string(), Value::Bool(false));
}

// Inline $ref pointers so the validator can resolve them without
// the full OpenAPI document context. cog-schema-gen emits $ref for
Expand All @@ -91,6 +88,22 @@ impl InputValidator {
self.required.len()
}

/// Remove every top-level key of `input` that is not listed in the
/// validator's known `properties`, mutating `input` in place.
///
/// Returns the names of the removed keys, in the order the object yielded
/// them. If `input` is not a JSON object nothing is touched and the
/// returned list is empty.
pub fn strip_unknown(&self, input: &mut Value) -> Vec<String> {
    let mut removed: Vec<String> = Vec::new();
    if let Some(fields) = input.as_object_mut() {
        // First pass: record the offending names. The map cannot be
        // mutated while its keys are being iterated.
        for name in fields.keys() {
            if !self.properties.contains(name) {
                removed.push(name.clone());
            }
        }
        // Second pass: drop each recorded key from the object.
        for name in &removed {
            fields.remove(name);
        }
    }
    removed
}

/// Validate an input value against the schema.
///
/// Returns Ok(()) on success, or a list of per-field validation errors
Expand All @@ -102,7 +115,6 @@ impl InputValidator {

let mut errors = Vec::new();
let mut seen_required = false;
let mut seen_additional = false;

for error in self.validator.iter_errors(input) {
let msg = error.to_string();
Expand All @@ -126,30 +138,10 @@ impl InputValidator {
continue;
}

// "additionalProperties" errors: emit one entry per unknown field
if msg.contains("Additional properties") && !seen_additional {
seen_additional = true;
if let Some(input_obj) = input.as_object() {
for key in input_obj.keys() {
if !self.properties.contains(key) {
errors.push(ValidationError {
field: key.clone(),
msg: format!("Unexpected field '{key}'"),
error_type: "value_error.extra".to_string(),
});
}
}
}
continue;
}

// Skip duplicate required/additional messages
// Skip duplicate required messages
if seen_required && msg.contains("is a required property") {
continue;
}
if seen_additional && msg.contains("Additional properties") {
continue;
}

// Type/constraint errors on specific fields
let path = error.instance_path.to_string();
Expand Down Expand Up @@ -250,7 +242,7 @@ mod tests {
}

#[test]
fn rejects_additional_properties() {
fn allows_additional_properties_in_validate() {
let schema = make_schema(json!({
"type": "object",
"properties": {
Expand All @@ -261,17 +253,17 @@ mod tests {

let validator = InputValidator::from_openapi_schema(&schema).unwrap();

// Extra field should fail
let errs = validator
.validate(&json!({"s": "hello", "extra": "bad"}))
.unwrap_err();
assert_eq!(errs.len(), 1);
assert_eq!(errs[0].field, "extra");
assert!(errs[0].msg.contains("Unexpected"));
// Extra fields should NOT cause validation failure — they get stripped separately
assert!(
validator
.validate(&json!({"s": "hello", "extra": "bad"}))
.is_ok(),
"unknown inputs should not cause validation errors"
);
}

#[test]
fn missing_and_extra_fields() {
fn strip_unknown_removes_extra_fields() {
let schema = make_schema(json!({
"type": "object",
"properties": {
Expand All @@ -282,15 +274,77 @@ mod tests {

let validator = InputValidator::from_openapi_schema(&schema).unwrap();

// wrong=value with missing s
let errs = validator.validate(&json!({"wrong": "value"})).unwrap_err();
assert!(errs.len() >= 2);
let fields: Vec<&str> = errs.iter().map(|e| e.field.as_str()).collect();
assert!(fields.contains(&"s"), "should report missing s: {fields:?}");
assert!(
fields.contains(&"wrong"),
"should report extra wrong: {fields:?}"
);
let mut input = json!({"s": "hello", "guidance_scale": 7.5, "extra": "bad"});
let stripped = validator.strip_unknown(&mut input);

// Should have removed the unknown fields
assert_eq!(stripped.len(), 2);
assert!(stripped.contains(&"guidance_scale".to_string()));
assert!(stripped.contains(&"extra".to_string()));

// Known field should remain
assert_eq!(input, json!({"s": "hello"}));
}

#[test]
fn strip_unknown_preserves_known_fields() {
    // Both `s` and `n` are declared by the schema, so stripping must be a no-op.
    let input_schema = json!({
        "type": "object",
        "properties": {
            "s": {"type": "string", "title": "S"},
            "n": {"type": "integer"}
        },
        "required": ["s"]
    });
    let schema = make_schema(input_schema);
    let validator = InputValidator::from_openapi_schema(&schema).unwrap();

    let mut payload = json!({"s": "hello", "n": 42});
    let removed = validator.strip_unknown(&mut payload);

    // Nothing reported as removed, and the payload is untouched.
    assert!(removed.is_empty());
    assert_eq!(payload, json!({"s": "hello", "n": 42}));
}

#[test]
fn strip_unknown_returns_empty_for_no_extra_fields() {
    // Schema declares a single required string field `s`.
    let input_schema = json!({
        "type": "object",
        "properties": {
            "s": {"type": "string", "title": "S"}
        },
        "required": ["s"]
    });
    let schema = make_schema(input_schema);
    let validator = InputValidator::from_openapi_schema(&schema).unwrap();

    // The payload carries only the declared field, so no names come back.
    let mut payload = json!({"s": "hello"});
    let removed = validator.strip_unknown(&mut payload);
    assert!(removed.is_empty());
}

#[test]
fn missing_required_with_extra_fields() {
    let input_schema = json!({
        "type": "object",
        "properties": {
            "s": {"type": "string", "title": "S"}
        },
        "required": ["s"]
    });
    let schema = make_schema(input_schema);
    let validator = InputValidator::from_openapi_schema(&schema).unwrap();

    // Step 1: the unknown key is stripped and reported by name.
    let mut payload = json!({"wrong": "value"});
    let removed = validator.strip_unknown(&mut payload);
    assert_eq!(removed, vec!["wrong".to_string()]);

    // Step 2: validating the stripped payload reports exactly one error,
    // for the missing required field, and nothing for the extra field.
    let errors = validator.validate(&payload).unwrap_err();
    assert_eq!(errors.len(), 1);
    let only = &errors[0];
    assert_eq!(only.field, "s");
    assert_eq!(only.msg, "Field required");
}

#[test]
Expand Down
52 changes: 34 additions & 18 deletions crates/coglet/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,36 +337,52 @@ impl PredictionService {
self.schema.read().await.clone()
}

/// Validate prediction input against the OpenAPI schema.
/// Strip unknown fields from input and validate in one pass.
///
/// Returns Ok(()) if no schema is loaded or if validation passes.
/// Returns Err with per-field validation errors on failure.
pub async fn validate_input(
/// Unknown inputs are silently dropped to match Replicate's historical API
/// behavior. Returns the stripped field names and the validation result
/// under a single lock acquisition.
pub async fn strip_and_validate_input(
&self,
input: &serde_json::Value,
) -> Result<(), Vec<crate::input_validation::ValidationError>> {
input: &mut serde_json::Value,
) -> (
Vec<String>,
Result<(), Vec<crate::input_validation::ValidationError>>,
) {
let guard = self.input_validator.read().await;
if let Some(ref validator) = *guard {
validator.validate(input)
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
(stripped, result)
} else {
Ok(())
(Vec::new(), Ok(()))
}
}

/// Validate training input against the TrainingInput schema.
/// Strip unknown fields from training input and validate in one pass.
///
/// Falls back to the predict validator if no training schema is present.
pub async fn validate_train_input(
pub async fn strip_and_validate_train_input(
Comment on lines +362 to +365
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Race condition: This method drops the read lock on train_validator before calling strip_and_validate_input. Between dropping the guard and acquiring a new one in strip_and_validate_input, the schema could change, causing the second validation to potentially use a different validator.

Consider restructuring to hold both locks or perform both operations under a single lock acquisition.

Suggested change
/// Strip unknown fields from training input and validate in one pass.
///
/// Falls back to the predict validator if no training schema is present.
pub async fn validate_train_input(
pub async fn strip_and_validate_train_input(
/// Strip unknown fields from training input and validate in one pass.
///
/// Falls back to the predict validator if no training schema is present.
pub async fn strip_and_validate_train_input(
&self,
input: &mut serde_json::Value,
) -> (
Vec<String>,
Result<(), Vec<crate::input_validation::ValidationError>>,
) {
let train_guard = self.train_validator.read().await;
if let Some(ref validator) = *train_guard {
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
return (stripped, result);
}
drop(train_guard);
// Try the predict validator as fallback
let predict_guard = self.input_validator.read().await;
if let Some(ref validator) = *predict_guard {
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
return (stripped, result);
}
(Vec::new(), Ok(()))
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion seems legit

&self,
input: &serde_json::Value,
) -> Result<(), Vec<crate::input_validation::ValidationError>> {
let guard = self.train_validator.read().await;
if let Some(ref validator) = *guard {
return validator.validate(input);
input: &mut serde_json::Value,
) -> (
Vec<String>,
Result<(), Vec<crate::input_validation::ValidationError>>,
) {
let train_guard = self.train_validator.read().await;
if let Some(ref validator) = *train_guard {
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
return (stripped, result);
}
drop(train_guard);
let predict_guard = self.input_validator.read().await;
if let Some(ref validator) = *predict_guard {
let stripped = validator.strip_unknown(input);
let result = validator.validate(input);
return (stripped, result);
}
drop(guard);
// Fallback: no TrainingInput schema — use predict validator (legacy compat)
self.validate_input(input).await
(Vec::new(), Ok(()))
}
Comment thread
michaeldwan marked this conversation as resolved.

/// Run user-defined healthcheck via orchestrator.
Expand Down
18 changes: 13 additions & 5 deletions crates/coglet/src/transport/http/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,20 +329,28 @@ fn build_webhook_sender(
async fn create_prediction_with_id(
service: Arc<PredictionService>,
prediction_id: String,
input: serde_json::Value,
mut input: serde_json::Value,
context: std::collections::HashMap<String, String>,
webhook: Option<String>,
webhook_events_filter: Vec<WebhookEventType>,
respond_async: bool,
trace_context: TraceContext,
is_training: bool,
) -> (StatusCode, Json<serde_json::Value>) {
// Validate input against the appropriate schema
let validation_result = if is_training {
service.validate_train_input(&input).await
// Strip unknown fields and validate in one pass. Unknown inputs are
// silently dropped to match Replicate's historical API behavior.
let (stripped, validation_result) = if is_training {
service.strip_and_validate_train_input(&mut input).await
} else {
service.validate_input(&input).await
service.strip_and_validate_input(&mut input).await
};
if !stripped.is_empty() {
tracing::warn!(
prediction_id = %prediction_id,
fields = ?stripped,
"Stripped unknown input fields"
);
}
if let Err(errors) = validation_result {
let detail: Vec<serde_json::Value> = errors
.into_iter()
Expand Down
Loading
Loading