diff --git a/.changeset/stream-multipart-uploads.md b/.changeset/stream-multipart-uploads.md new file mode 100644 index 00000000..868f7ace --- /dev/null +++ b/.changeset/stream-multipart-uploads.md @@ -0,0 +1,5 @@ +--- +"@googleworkspace/cli": patch +--- + +Stream file uploads instead of buffering entire file in memory, fixing OOM crashes on large files diff --git a/Cargo.lock b/Cargo.lock index 68931d42..29d54997 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -851,6 +851,7 @@ dependencies = [ "anyhow", "async-trait", "base64", + "bytes", "chrono", "clap", "crossterm", @@ -872,6 +873,7 @@ dependencies = [ "tempfile", "thiserror 2.0.18", "tokio", + "tokio-util", "yup-oauth2", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index 44bf0235..ebd547f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ serde_json = "1" sha2 = "0.10" thiserror = "2" tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0.7", features = ["io"] } yup-oauth2 = "12" futures-util = "0.3" base64 = "0.22.1" @@ -57,6 +58,7 @@ async-trait = "0.1.89" serde_yaml = "0.9.34" percent-encoding = "2.3.2" zeroize = { version = "1.8.2", features = ["derive"] } +bytes = "1.11.1" # The profile that 'cargo dist' will build with diff --git a/src/executor.rs b/src/executor.rs index 49101ece..78982d59 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -22,7 +22,7 @@ use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use anyhow::Context; -use futures_util::StreamExt; +use futures_util::{StreamExt, TryStreamExt}; use serde_json::{json, Map, Value}; use tokio::io::AsyncWriteExt; @@ -182,17 +182,22 @@ async fn build_http_request( if input.is_upload { let upload_path = upload_path.expect("upload_path must be Some when is_upload is true"); - let file_bytes = tokio::fs::read(upload_path).await.map_err(|e| { - GwsError::Validation(format!( - "Failed to read upload file '{}': {}", - upload_path, e - )) - })?; + let file_size = tokio::fs::metadata(upload_path) + .await + 
.map_err(|e| { + GwsError::Validation(format!( + "Failed to read upload file '{}': {}", + upload_path, e + )) + })? + .len(); request = request.query(&[("uploadType", "multipart")]); - let (multipart_body, content_type) = build_multipart_body(&input.body, &file_bytes)?; + let (body, content_type, content_length) = + build_multipart_stream(&input.body, upload_path, file_size); request = request.header("Content-Type", content_type); - request = request.body(multipart_body); + request = request.header("Content-Length", content_length); + request = request.body(body); } else if let Some(ref body_val) = input.body { request = request.header("Content-Type", "application/json"); request = request.json(body_val); @@ -731,6 +736,7 @@ fn handle_error_response( /// Builds a multipart/related body for media upload requests. /// /// Returns the body bytes and the Content-Type header value (with boundary). +#[cfg(test)] fn build_multipart_body( metadata: &Option<Value>, file_bytes: &[u8], @@ -768,6 +774,67 @@ fn build_multipart_body( Ok((body, content_type)) } +/// Build a streaming multipart/related body for file uploads. +/// +/// Instead of reading the entire file into memory, this streams the file +/// contents from disk in small fixed-size chunks, keeping memory usage constant +/// regardless of file size. Returns `(body, content_type, content_length)`. 
+fn build_multipart_stream( + metadata: &Option<Value>, + file_path: &str, + file_size: u64, +) -> (reqwest::Body, String, u64) { + let boundary = format!("gws_boundary_{:016x}", rand::random::<u64>()); + + let media_mime = metadata + .as_ref() + .and_then(|m| m.get("mimeType")) + .and_then(|v| v.as_str()) + .unwrap_or("application/octet-stream") + .to_string(); + + let metadata_json = metadata + .as_ref() + .map(|m| serde_json::to_string(m).unwrap_or_else(|_| "{}".to_string())) + .unwrap_or_else(|| "{}".to_string()); + + let preamble = format!( + "--{boundary}\r\n\ + Content-Type: application/json; charset=UTF-8\r\n\r\n\ + {metadata_json}\r\n\ + --{boundary}\r\n\ + Content-Type: {media_mime}\r\n\r\n" + ); + let postamble = format!("\r\n--{boundary}--\r\n"); + + let content_length = preamble.len() as u64 + file_size + postamble.len() as u64; + let content_type = format!("multipart/related; boundary={boundary}"); + + // Chain: preamble bytes -> file chunks (via ReaderStream) -> postamble bytes + // All parts use bytes::Bytes for zero-copy streaming. + let file_path = file_path.to_owned(); + let preamble_bytes = bytes::Bytes::from(preamble.into_bytes()); + let postamble_bytes = bytes::Bytes::from(postamble.into_bytes()); + + let file_stream = + futures_util::stream::once(async move { tokio::fs::File::open(file_path).await }) + .map_ok(tokio_util::io::ReaderStream::new) + .try_flatten(); + + let stream = + futures_util::stream::once(async { Ok::<_, std::io::Error>(preamble_bytes) }) + .chain(file_stream) + .chain(futures_util::stream::once(async { + Ok::<_, std::io::Error>(postamble_bytes) + })); + + ( + reqwest::Body::wrap_stream(stream), + content_type, + content_length, + ) +} + /// Validates a JSON body against a Discovery Document schema. 
fn validate_body_against_schema( body: &Value, @@ -1218,6 +1285,57 @@ mod tests { assert!(body_str.contains("Binary data")); } + #[tokio::test] + async fn test_build_multipart_stream_content_length() { + use std::io::Write; + let metadata = Some(json!({ "name": "test.txt", "mimeType": "text/plain" })); + let content = b"Hello streaming world"; + + let mut tmp = tempfile::NamedTempFile::new().unwrap(); + tmp.write_all(content).unwrap(); + let path = tmp.path().to_str().unwrap().to_string(); + + let (_, content_type, content_length) = + build_multipart_stream(&metadata, &path, content.len() as u64); + + assert!(content_type.starts_with("multipart/related; boundary=")); + let boundary = content_type.split("boundary=").nth(1).unwrap(); + assert!(boundary.starts_with("gws_boundary_")); + + // Verify content_length matches the expected structure: + // preamble + file_size + postamble + let metadata_json = serde_json::to_string(metadata.as_ref().unwrap()).unwrap(); + let preamble_len = format!( + "--{boundary}\r\nContent-Type: application/json; charset=UTF-8\r\n\r\n{metadata_json}\r\n--{boundary}\r\nContent-Type: text/plain\r\n\r\n" + ).len() as u64; + let postamble_len = format!("\r\n--{boundary}--\r\n").len() as u64; + assert_eq!(content_length, preamble_len + content.len() as u64 + postamble_len); + } + + #[tokio::test] + async fn test_build_multipart_stream_large_file() { + use std::io::Write; + let metadata = Some(json!({ "name": "big.bin", "mimeType": "application/octet-stream" })); + let content = vec![0xABu8; 256 * 1024]; // 256KB — larger than the 64KB chunk size + + let mut tmp = tempfile::NamedTempFile::new().unwrap(); + tmp.write_all(&content).unwrap(); + let path = tmp.path().to_str().unwrap().to_string(); + + let (_, _, content_length) = + build_multipart_stream(&metadata, &path, content.len() as u64); + + // Verify the declared content_length is consistent + let metadata_json = serde_json::to_string(metadata.as_ref().unwrap()).unwrap(); + let 
boundary_example = "gws_boundary_0000000000000000"; + // Structural overhead is: preamble + postamble (boundary length is fixed at 29 chars) + let expected_overhead = format!( + "--{boundary_example}\r\nContent-Type: application/json; charset=UTF-8\r\n\r\n{metadata_json}\r\n--{boundary_example}\r\nContent-Type: application/octet-stream\r\n\r\n" + ).len() as u64 + + format!("\r\n--{boundary_example}--\r\n").len() as u64; + assert_eq!(content_length, expected_overhead + content.len() as u64); + } + #[test] fn test_build_url_basic() { let doc = RestDescription {