diff --git a/rust/lance-namespace-impls/Cargo.toml b/rust/lance-namespace-impls/Cargo.toml index a91ae937671..9ce32692ffc 100644 --- a/rust/lance-namespace-impls/Cargo.toml +++ b/rust/lance-namespace-impls/Cargo.toml @@ -48,7 +48,7 @@ arrow-schema = { workspace = true } # REST adapter implementation dependencies (optional, enabled by "rest-adapter" feature) axum = { workspace = true, optional = true } tower = { workspace = true, optional = true } -tower-http = { workspace = true, optional = true, features = ["trace", "cors"] } +tower-http = { workspace = true, optional = true, features = ["trace", "cors", "normalize-path"] } serde = { workspace = true, optional = true } # Common dependencies diff --git a/rust/lance-namespace-impls/src/rest_adapter.rs b/rust/lance-namespace-impls/src/rest_adapter.rs index dd94b15e7c4..284b0d42fa9 100644 --- a/rust/lance-namespace-impls/src/rest_adapter.rs +++ b/rust/lance-namespace-impls/src/rest_adapter.rs @@ -11,14 +11,16 @@ use std::sync::Arc; use axum::{ body::Bytes, - extract::{Path, Query, State}, + extract::{Path, Query, Request, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, - Json, Router, + Json, Router, ServiceExt, }; use serde::Deserialize; use tokio::sync::watch; +use tower::Layer; +use tower_http::normalize_path::NormalizePathLayer; use tower_http::trace::TraceLayer; use lance_core::{Error, Result}; @@ -154,9 +156,10 @@ impl RestAdapter { let (shutdown_tx, mut shutdown_rx) = watch::channel(false); let (done_tx, done_rx) = tokio::sync::oneshot::channel::<()>(); let router = self.router(); + let app = NormalizePathLayer::trim_trailing_slash().layer(router); tokio::spawn(async move { - let result = axum::serve(listener, router) + let result = axum::serve(listener, ServiceExt::::into_make_service(app)) .with_graceful_shutdown(async move { let _ = shutdown_rx.changed().await; }) @@ -1168,6 +1171,60 @@ mod tests { Bytes::from(buffer) } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_trailing_slash_handling() { + let fixture = RestServerFixture::new().await; + let port = fixture.server_handle.port(); + + // Create a namespace using the normal API (without trailing slash) + let create_req = CreateNamespaceRequest { + id: Some(vec!["test_namespace".to_string()]), + properties: None, + mode: None, + }; + fixture + .namespace + .create_namespace(create_req) + .await + .unwrap(); + + // Test that a request with trailing slash works (using direct HTTP) + let client = reqwest::Client::new(); + + // Test POST endpoint with trailing slash + let response = client + .post(format!( + "http://127.0.0.1:{}/v1/namespace/test_namespace/exists/", + port + )) + .json(&serde_json::json!({})) + .send() + .await + .unwrap(); + + assert_eq!( + response.status(), + 204, + "POST request with trailing slash should succeed with 204 No Content" + ); + + // Test GET endpoint with trailing slash + let response = client + .get(format!( + "http://127.0.0.1:{}/v1/namespace/test_namespace/list/", + port + )) + .send() + .await + .unwrap(); + + assert!( + response.status().is_success(), + "GET request with trailing slash should succeed, got status: {}", + response.status() + ); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_create_and_list_child_namespaces() { let fixture = RestServerFixture::new().await;