From f265e47e1391ea61d5f8ad92c5c847c24708c711 Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 11:45:09 +0100 Subject: [PATCH 1/9] feat: add blob storage for file uploads and async validation --- Cargo.lock | 979 ++++++++++++++++-- config.yaml | 16 + dwctl/Cargo.toml | 3 + dwctl/migrations/090_add_file_ingest_jobs.sql | 10 + dwctl/src/api/handlers/batches.rs | 23 + dwctl/src/api/handlers/files.rs | 412 +++++++- dwctl/src/blob_storage.rs | 102 ++ dwctl/src/config.rs | 130 +++ dwctl/src/lib.rs | 3 + dwctl/src/tasks.rs | 35 +- dwctl/src/test/mod.rs | 6 + dwctl/src/test/utils.rs | 4 + 12 files changed, 1594 insertions(+), 129 deletions(-) create mode 100644 dwctl/migrations/090_add_file_ingest_jobs.sql create mode 100644 dwctl/src/blob_storage.rs diff --git a/Cargo.lock b/Cargo.lock index e54f9e12e..8d7b183af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -269,8 +269,8 @@ dependencies = [ "async-stripe-shared", "bytes", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.8.1", + "hyper-rustls 0.27.7", "hyper-util", "miniserde", "thiserror 2.0.18", @@ -518,6 +518,48 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-config" +version = "1.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11493b0bad143270fb8ad284a096dd529ba91924c5409adeac856cc1bf047dbc" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http 0.63.6", + "aws-smithy-json 0.62.5", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "hex", + "http 1.4.0", + "sha1", + "time", + "tokio", + "tracing", + "url", + "zeroize", +] + +[[package]] +name = "aws-credential-types" +version = "1.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f20799b373a1be121fe3005fba0c2090af9411573878f224df44b42727fcaf7" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + [[package]] name = "aws-lc-rs" version = "1.16.2" @@ -540,6 +582,412 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "aws-runtime" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fc0651c57e384202e47153c1260b84a9936e19803d747615edf199dc3b98d17" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http 0.63.6", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "bytes-utils", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "http-body 1.0.1", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] + +[[package]] +name = "aws-sdk-s3" +version = "1.119.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d65fddc3844f902dfe1864acb8494db5f9342015ee3ab7890270d36fbd2e01c" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-checksums", + "aws-smithy-eventstream", + "aws-smithy-http 0.62.6", + "aws-smithy-json 0.61.9", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "bytes", + "fastrand", + "hex", + "hmac", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "lru", + "percent-encoding", + 
"regex-lite", + "sha2", + "tracing", + "url", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.97.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aadc669e184501caaa6beafb28c6267fc1baef0810fb58f9b205485ca3f2567" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http 0.63.6", + "aws-smithy-json 0.62.5", + "aws-smithy-observability", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-ssooidc" +version = "1.99.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1342a7db8f358d3de0aed2007a0b54e875458e39848d54cc1d46700b2bfcb0a8" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http 0.63.6", + "aws-smithy-json 0.62.5", + "aws-smithy-observability", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sts" +version = "1.101.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab41ad64e4051ecabeea802d6a17845a91e83287e1dd249e6963ea1ba78c428a" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http 0.63.6", + "aws-smithy-json 0.62.5", + "aws-smithy-observability", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sigv4" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0b660013a6683ab23797778e21f1f854744fdf05f68204b4cca4c8c04b5d1f4" +dependencies = [ + "aws-credential-types", + "aws-smithy-eventstream", + "aws-smithy-http 0.63.6", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "crypto-bigint 0.5.5", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.4.0", + "p256", + "percent-encoding", + "ring", + "sha2", + "subtle", + "time", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffcaf626bdda484571968400c326a244598634dc75fd451325a54ad1a59acfc" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "aws-smithy-checksums" +version = "0.63.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87294a084b43d649d967efe58aa1f9e0adc260e13a6938eb904c0ae9b45824ae" +dependencies = [ + "aws-smithy-http 0.62.6", + "aws-smithy-types", + "bytes", + "crc-fast", + "hex", + "http 0.2.12", + "http-body 0.4.6", + "md-5", + "pin-project-lite", + "sha1", + "sha2", + "tracing", +] + +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf09d74e5e32f76b8762da505a3cd59303e367a664ca67295387baa8c1d7548" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + +[[package]] +name = "aws-smithy-http" +version = "0.62.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b" +dependencies = [ + "aws-smithy-eventstream", + 
"aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "futures-util", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http" +version = "0.63.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba1ab2dc1c2c3749ead27180d333c42f11be8b0e934058fb4b2258ee8dbe5231" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http-client" +version = "1.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a2f165a7feee6f263028b899d0a181987f4fa7179a6411a32a439fba7c5f769" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "h2 0.3.27", + "h2 0.4.13", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper 1.8.1", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.7", + "hyper-util", + "pin-project-lite", + "rustls 0.21.12", + "rustls 0.23.37", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.4", + "tower", + "tracing", +] + +[[package]] +name = "aws-smithy-json" +version = "0.61.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-json" +version = "0.62.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9648b0bb82a2eedd844052c6ad2a1a822d1f8e3adee5fbf668366717e428856a" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-observability" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06c2315d173edbf1920da8ba3a7189695827002e4c0fc961973ab1c54abca9c" +dependencies = [ + "aws-smithy-runtime-api", +] + +[[package]] +name = "aws-smithy-query" +version = "0.60.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a56d79744fb3edb5d722ef79d86081e121d3b9422cb209eb03aea6aa4f21ebd" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] + +[[package]] +name = "aws-smithy-runtime" +version = "1.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028999056d2d2fd58a697232f9eec4a643cf73a71cf327690a7edad1d2af2110" +dependencies = [ + "aws-smithy-async", + "aws-smithy-http 0.63.6", + "aws-smithy-http-client", + "aws-smithy-observability", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", + "pin-utils", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876ab3c9c29791ba4ba02b780a3049e21ec63dabda09268b175272c3733a79e6" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.4.0", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-types" +version = "1.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9d73dbfbaa8e4bc57b9045137680b958d274823509a360abfd8e1d514d40c95c" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.4.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", + "tokio", + "tokio-util", +] + +[[package]] +name = "aws-smithy-xml" +version = "0.60.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce02add1aa3677d022f8adf81dcbe3046a95f17a1b1e8979c145cd21d3d22b3" +dependencies = [ + "xmlparser", +] + +[[package]] +name = "aws-types" +version = "1.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47c8323699dd9b3c8d5b3c13051ae9cdef58fd179957c882f8374dd8725962d9" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + [[package]] name = "axum" version = "0.8.8" @@ -550,10 +998,10 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "itoa", "matchit", @@ -582,8 +1030,8 @@ checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -602,8 +1050,8 @@ dependencies = [ "axum", "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "matchit", "metrics", "metrics-exporter-prometheus", @@ -625,9 +1073,9 @@ dependencies = [ "bytesize", "cookie", "expect-json", - "http", + "http 1.4.0", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "mime", "pretty_assertions", @@ -642,12 +1090,28 @@ dependencies = [ "url", ] +[[package]] +name = "base16ct" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" + [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "base64ct" version = "1.8.3" @@ -815,6 +1279,16 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "bytesize" version = "2.3.1" @@ -1068,6 +1542,19 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc-fast" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ddc2d09feefeee8bd78101665bd8645637828fa9317f9f292496dbbd8c65ff3" +dependencies = [ + "crc", + "digest", + "rand 0.9.2", + "regex", + "rustversion", +] + [[package]] name = 
"crc32fast" version = "1.5.0" @@ -1162,6 +1649,28 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-bigint" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef2b4b23cddf68b89b8f8069890e8c270d54e2d5fe1b143820234805e4cb17ef" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "crypto-common" version = "0.1.7" @@ -1300,6 +1809,16 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +[[package]] +name = "der" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1a467a65c5e759bce6e65eaf91cc29f466cdc57cb65777bd646872a8a1fd4de" +dependencies = [ + "const-oid", + "zeroize", +] + [[package]] name = "der" version = "0.7.10" @@ -1425,6 +1944,9 @@ dependencies = [ "async-stripe-types", "async-stripe-webhook", "async-trait", + "aws-config", + "aws-credential-types", + "aws-sdk-s3", "axum", "axum-prometheus", "axum-test", @@ -1467,7 +1989,7 @@ dependencies = [ "reqwest 0.13.2", "rust-embed", "rust_decimal", - "rustls", + "rustls 0.23.37", "scopeguard", "serde", "serde_json", @@ -1503,6 +2025,18 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "ecdsa" +version = "0.14.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413301934810f597c1d19ca71c8710e99a3f1ba28a0d2ebc01551a2daeea3c5c" +dependencies = [ + "der 0.6.1", + "elliptic-curve", + "rfc6979", + "signature 1.6.4", +] + [[package]] name = "either" version = "1.15.0" @@ -1512,6 +2046,26 @@ dependencies = [ "serde", ] +[[package]] +name = "elliptic-curve" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7bb888ab5300a19b8e5bceef25ac745ad065f3c9f7efc6de1b91958110891d3" +dependencies = [ + "base16ct", + "crypto-bigint 0.4.9", + "der 0.6.1", + "digest", + "ff", + "generic-array", + "group", + "pkcs8 0.9.0", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + [[package]] name = "email-encoding" version = "0.4.1" @@ -1666,6 +2220,16 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "ff" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "figment" version = "0.10.19" @@ -2006,6 +2570,36 @@ dependencies = [ "web-time", ] +[[package]] +name = "group" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "h2" +version = "0.3.27" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap 2.13.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.13" @@ -2017,7 +2611,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.4.0", "indexmap 2.13.0", "slab", "tokio", @@ -2138,6 +2732,17 @@ dependencies = [ "windows-link", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.4.0" @@ -2148,6 +2753,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -2155,7 +2771,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.4.0", ] [[package]] @@ -2166,8 +2782,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -2205,6 +2821,30 @@ dependencies = [ "serde", ] +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.5.10", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.8.1" @@ -2215,9 +2855,9 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "h2", - "http", - "http-body", + "h2 0.4.13", + "http 1.4.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -2228,21 +2868,36 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "hyper-rustls" version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", - "hyper", + "http 1.4.0", + "hyper 1.8.1", "hyper-util", "log", - "rustls", + "rustls 0.23.37", "rustls-native-certs", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.4", "tower-service", "webpki-roots 1.0.6", ] @@ -2255,7 +2910,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "native-tls", "tokio", @@ -2273,14 +2928,14 @@ dependencies = [ "bytes", 
"futures-channel", "futures-util", - "http", - "http-body", - "hyper", + "http 1.4.0", + "http-body 1.0.1", + "hyper 1.8.1", "ipnet", "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.3", "system-configuration", "tokio", "tower-service", @@ -2714,10 +3369,10 @@ dependencies = [ "nom 8.0.0", "percent-encoding", "quoted_printable", - "rustls", - "socket2", + "rustls 0.23.37", + "socket2 0.6.3", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.4", "url", "uuid", "webpki-roots 1.0.6", @@ -2784,6 +3439,15 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" +dependencies = [ + "hashbrown 0.15.5", +] + [[package]] name = "lru-slab" version = "0.1.2" @@ -2845,15 +3509,15 @@ checksum = "3589659543c04c7dc5526ec858591015b87cd8746583b51b48ef4353f99dbcda" dependencies = [ "base64", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.8.1", + "hyper-rustls 0.27.7", "hyper-util", "indexmap 2.13.0", "ipnet", "metrics", "metrics-util", "quanta", - "rustls", + "rustls 0.23.37", "thiserror 2.0.18", "tokio", "tracing", @@ -2980,7 +3644,7 @@ dependencies = [ "bytes", "encoding_rs", "futures-util", - "http", + "http 1.4.0", "httparse", "memchr", "mime", @@ -3202,7 +3866,7 @@ dependencies = [ "futures-util", "governor", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-tls", "hyper-util", "metrics", @@ -3311,7 +3975,7 @@ checksum = "d7a6d09a73194e6b66df7c8f1b680f156d916a1a942abf2de06823dd02b7855d" dependencies = [ "async-trait", "bytes", - "http", + "http 1.4.0", "opentelemetry", "reqwest 0.12.28", ] @@ -3322,7 +3986,7 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f69cd6acbb9af919df949cd1ec9e5e7fdc2ef15d234b6b795aaa525cc02f71f" dependencies = [ - "http", + "http 1.4.0", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", @@ -3395,7 +4059,7 @@ dependencies = [ "base64", "bytes", "chrono", - "http", + "http 1.4.0", "metrics", "outlet", "serde", @@ -3408,6 +4072,23 @@ dependencies = [ "uuid", ] +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + +[[package]] +name = "p256" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" +dependencies = [ + "ecdsa", + "elliptic-curve", + "sha2", +] + [[package]] name = "page_size" version = "0.6.0" @@ -3550,9 +4231,19 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" dependencies = [ - "der", - "pkcs8", - "spki", + "der 0.7.10", + "pkcs8 0.10.2", + "spki 0.7.3", +] + +[[package]] +name = "pkcs8" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" +dependencies = [ + "der 0.6.1", + "spki 0.6.0", ] [[package]] @@ -3561,8 +4252,8 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ - 
"der", - "spki", + "der 0.7.10", + "spki 0.7.3", ] [[package]] @@ -3869,8 +4560,8 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls", - "socket2", + "rustls 0.23.37", + "socket2 0.6.3", "thiserror 2.0.18", "tokio", "tracing", @@ -3890,7 +4581,7 @@ dependencies = [ "rand 0.9.2", "ring", "rustc-hash", - "rustls", + "rustls 0.23.37", "rustls-pki-types", "slab", "thiserror 2.0.18", @@ -3908,7 +4599,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.6.3", "tracing", "windows-sys 0.60.2", ] @@ -4153,18 +4844,18 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.8.1", + "hyper-rustls 0.27.7", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", "quinn", - "rustls", + "rustls 0.23.37", "rustls-native-certs", "rustls-pki-types", "serde", @@ -4172,7 +4863,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.4", "tower", "tower-http", "tower-service", @@ -4193,12 +4884,12 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.4.13", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.8.1", + "hyper-rustls 0.27.7", "hyper-tls", "hyper-util", "js-sys", @@ -4208,7 +4899,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls", + "rustls 0.23.37", "rustls-pki-types", "rustls-platform-verifier", "serde", @@ -4217,7 +4908,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", - "tokio-rustls", + "tokio-rustls 0.26.4", "tokio-util", "tower", "tower-http", @@ -4237,7 +4928,7 @@ checksum = "199dda04a536b532d0cc04d7979e39b1c763ea749bf91507017069c00b96056f" dependencies = [ "anyhow", "async-trait", - "http", + "http 1.4.0", "reqwest 0.13.2", "serde", "thiserror 2.0.18", @@ -4254,8 +4945,8 @@ dependencies = [ "async-trait", "futures", "getrandom 0.2.17", - "http", - "hyper", + "http 1.4.0", + "hyper 1.8.1", "reqwest 0.13.2", "reqwest-middleware", "retry-policies", @@ -4274,7 +4965,7 @@ dependencies = [ "anyhow", "async-trait", "getrandom 0.2.17", - "http", + "http 1.4.0", "matchit", "reqwest 0.13.2", "reqwest-middleware", @@ -4299,6 +4990,17 @@ dependencies = [ "rand 0.9.2", ] +[[package]] +name = "rfc6979" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7743f17af12fa0b03b803ba12cd6a8d9483a587e89c69445e3909655c0b9fabb" +dependencies = [ + "crypto-bigint 0.4.9", + "hmac", + "zeroize", +] + [[package]] name = "ring" version = "0.17.14" @@ -4354,10 +5056,10 @@ dependencies = [ "num-integer", "num-traits", "pkcs1", - "pkcs8", + "pkcs8 0.10.2", "rand_core 0.6.4", - "signature", - "spki", + "signature 2.2.0", + "spki 0.7.3", "subtle", "zeroize", ] @@ -4405,7 +5107,7 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http", + "http 1.4.0", "mime", "rand 0.9.2", "thiserror 2.0.18", @@ -4434,6 +5136,15 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "1.1.4" @@ -4447,6 +5158,18 @@ 
dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.37" @@ -4458,7 +5181,7 @@ dependencies = [ "once_cell", "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.103.10", "subtle", "zeroize", ] @@ -4496,10 +5219,10 @@ dependencies = [ "jni", "log", "once_cell", - "rustls", + "rustls 0.23.37", "rustls-native-certs", "rustls-platform-verifier-android", - "rustls-webpki", + "rustls-webpki 0.103.10", "security-framework", "security-framework-sys", "webpki-root-certs", @@ -4512,6 +5235,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustls-webpki" version = "0.103.10" @@ -4593,6 +5326,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "sdd" version = "3.0.10" @@ -4605,6 +5348,20 @@ version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" +[[package]] +name = "sec1" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be24c1842290c45df0a7bf069e0c268a747ad05a192f2fd7dcfdbc1cba40928" +dependencies = [ + "base16ct", + "der 0.6.1", + "generic-array", + "pkcs8 0.9.0", + "subtle", + "zeroize", +] + [[package]] name = "security-framework" version = "3.7.0" @@ -4829,6 +5586,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "1.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "signature" version = "2.2.0" @@ -4894,6 +5661,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socket2" version = "0.6.3" @@ -4922,6 +5699,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spki" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b" +dependencies = [ + "base64ct", + "der 0.6.1", +] + [[package]] name = "spki" version = "0.7.3" @@ -4929,7 +5716,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", - "der", + "der 
0.7.10", ] [[package]] @@ -4971,7 +5758,7 @@ dependencies = [ "once_cell", "percent-encoding", "rust_decimal", - "rustls", + "rustls 0.23.37", "serde", "serde_json", "sha2", @@ -5426,7 +6213,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.6.3", "tokio-macros", "windows-sys 0.61.2", ] @@ -5452,13 +6239,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ - "rustls", + "rustls 0.23.37", "tokio", ] @@ -5537,8 +6334,8 @@ dependencies = [ "async-trait", "base64", "bytes", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "percent-encoding", "pin-project", @@ -5586,8 +6383,8 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "http-range-header", "httpdate", @@ -5850,6 +6647,12 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -5929,6 +6732,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "walkdir" version = "2.5.0" @@ -6569,9 +7378,9 @@ dependencies = [ "base64", "deadpool", "futures", - "http", + "http 1.4.0", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "log", "once_cell", @@ -6695,6 +7504,12 @@ dependencies = [ "rustix", ] +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + [[package]] name = "yansi" version = "1.0.1" diff --git a/config.yaml b/config.yaml index 626052b39..9bf563035 100644 --- a/config.yaml +++ b/config.yaml @@ -283,6 +283,22 @@ batches: # Higher values improve throughput for large files but use more memory # Reduce value if you encounter memory spike issues with many concurrent uploads batch_insert_size: 5000 + # Storage backend for upload path: + # - postgres: legacy (validate + insert synchronously during upload) + # - object_store: write raw JSONL to blob storage then ingest asynchronously + storage_backend: postgres + # Required when storage_backend=object_store + # object_store: + # provider: s3_compatible + # endpoint: "http://localhost:9000" + # bucket: "dwctl-batches" + # region: "us-east-1" + # access_key_id: "minioadmin" + # secret_access_key: "minioadmin" + # path_style: true + # prefix: "uploads/" + # connect_timeout_ms: 5000 + # request_timeout_ms: 120000 # Resource limits configuration # Controls file size, request count, and concurrency limits to prevent resource exhaustion diff --git a/dwctl/Cargo.toml b/dwctl/Cargo.toml 
index a3d7fd073..47c8c0b2c 100644
--- a/dwctl/Cargo.toml
+++ b/dwctl/Cargo.toml
@@ -146,6 +146,9 @@ multer = "3.1.0"
 ctor = "0.6"
 hmac = "0.12.1"
 sha2 = "0.10.9"
+aws-config = "1.8.10"
+aws-sdk-s3 = "1.115.0"
+aws-credential-types = "1.2.9"
 
 [dev-dependencies]
 axum-test = { version = "18.4.1" }
diff --git a/dwctl/migrations/090_add_file_ingest_jobs.sql b/dwctl/migrations/090_add_file_ingest_jobs.sql
new file mode 100644
index 000000000..707b2bdf0
--- /dev/null
+++ b/dwctl/migrations/090_add_file_ingest_jobs.sql
@@ -0,0 +1,10 @@
+CREATE TABLE IF NOT EXISTS file_ingest_jobs (
+    file_id UUID PRIMARY KEY,
+    object_key TEXT NOT NULL,
+    status TEXT NOT NULL CHECK (status IN ('pending', 'processing', 'processed', 'failed')),
+    error_message TEXT,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_file_ingest_jobs_status ON file_ingest_jobs(status);
diff --git a/dwctl/src/api/handlers/batches.rs b/dwctl/src/api/handlers/batches.rs
index 36eda3ec2..9403dab8f 100644
--- a/dwctl/src/api/handlers/batches.rs
+++ b/dwctl/src/api/handlers/batches.rs
@@ -353,6 +353,29 @@ pub async fn create_batch(
         });
     }
 
+    // If asynchronous file ingestion is enabled, gate batch creation on ingest state.
+    if let Some(status) = crate::api::handlers::files::get_file_ingest_status(state.db.write(), file_id).await? {
+        match status.as_str() {
+            "pending" | "processing" => {
+                return Err(Error::BadRequest {
+                    message: "File is still being processed. Please retry shortly.".to_string(),
+                });
+            }
+            "failed" => {
+                let msg = sqlx::query_scalar::<_, Option<String>>(
+                    "SELECT error_message FROM file_ingest_jobs WHERE file_id = $1",
+                )
+                .bind(file_id)
+                .fetch_one(state.db.write())
+                .await
+                .map_err(|e| Error::Database(e.into()))?
+                .unwrap_or_else(|| "File ingestion failed".to_string());
+                return Err(Error::BadRequest { message: msg });
+            }
+            _ => {}
+        }
+    }
+
     // Check that the file owner (whose API key is embedded in the request templates)
     // has sufficient balance. This catches cases where an admin creates a batch from
    // a file owned by another user/org that has negative balance.
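[Note] The ingest gate added to create_batch above is the new coupling point
between upload and batch creation. A minimal, self-contained sketch of the
decision logic it implements (the `GateDecision` enum and `gate_batch_creation`
name are illustrative only, not part of this patch):

    // Sketch: map an optional file_ingest_jobs.status row to the handler's
    // outcome. No row (legacy postgres-backend uploads) or "processed" means
    // proceed; "pending"/"processing" asks the client to retry; "failed"
    // surfaces the stored error_message.
    #[derive(Debug, PartialEq)]
    enum GateDecision {
        Proceed,
        RetryShortly,
        Reject(String),
    }

    fn gate_batch_creation(status: Option<&str>, error_message: Option<String>) -> GateDecision {
        match status {
            Some("pending") | Some("processing") => GateDecision::RetryShortly,
            Some("failed") => GateDecision::Reject(
                error_message.unwrap_or_else(|| "File ingestion failed".to_string()),
            ),
            _ => GateDecision::Proceed,
        }
    }

    fn main() {
        assert_eq!(gate_batch_creation(Some("pending"), None), GateDecision::RetryShortly);
        assert_eq!(gate_batch_creation(None, None), GateDecision::Proceed);
    }
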
diff --git a/dwctl/src/api/handlers/files.rs b/dwctl/src/api/handlers/files.rs
index a079f3d91..0e13224f8 100644
--- a/dwctl/src/api/handlers/files.rs
+++ b/dwctl/src/api/handlers/files.rs
@@ -13,6 +13,8 @@ use crate::api::models::files::{
 };
 use crate::api::models::users::CurrentUser;
 use crate::auth::permissions::{RequiresPermission, can_read_all_resources, operation, resource};
+use crate::blob_storage::BlobStorageClient;
+use crate::config::FileStorageBackend;
 
 use crate::AppState;
 use crate::db::{
@@ -43,6 +45,7 @@
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::pin::Pin;
 use std::sync::{Arc, Mutex};
+use tokio::io::AsyncWriteExt;
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::ReceiverStream;
 use uuid::Uuid;
@@ -425,6 +428,247 @@ struct FileRequestContext {
     allowed_url_paths: Vec<String>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct IngestFileInput {
+    pub file_id: Uuid,
+    pub object_key: String,
+    pub endpoint: String,
+    pub api_key: String,
+    pub api_key_id: Uuid,
+    pub uploaded_by: String,
+    pub filename: String,
+    pub size_bytes: i64,
+}
+
+pub(crate) async fn get_file_ingest_status(pool: &sqlx::PgPool, file_id: Uuid) -> Result<Option<String>> {
+    sqlx::query_scalar::<_, String>("SELECT status FROM file_ingest_jobs WHERE file_id = $1")
+        .bind(file_id)
+        .fetch_optional(pool)
+        .await
+        .map_err(|e| Error::Database(e.into()))
+}
+
+async fn set_file_ingest_status(
+    pool: &sqlx::PgPool,
+    file_id: Uuid,
+    object_key: &str,
+    status: &str,
+    error_message: Option<&str>,
+) -> Result<()> {
+    sqlx::query(
+        r#"
+        INSERT INTO file_ingest_jobs (file_id, object_key, status, error_message, created_at, updated_at)
+        VALUES ($1, $2, $3, $4, NOW(), NOW())
+        ON CONFLICT (file_id) DO UPDATE
+        SET object_key = EXCLUDED.object_key,
+            status = EXCLUDED.status,
+            error_message = EXCLUDED.error_message,
+            updated_at = NOW()
+        "#,
+    )
+    .bind(file_id)
+    .bind(object_key)
+    .bind(status)
+    .bind(error_message)
+    .execute(pool)
+    .await
+    .map_err(|e| Error::Database(e.into()))?;
+    Ok(())
+}
+
+async fn insert_file_placeholder(
+    pool: &sqlx::PgPool,
+    input: &IngestFileInput,
+    purpose: &str,
+) -> Result<()> {
+    sqlx::query(
+        r#"
+        INSERT INTO fusillade.files (id, name, purpose, size_bytes, status, uploaded_by, api_key_id, created_at, updated_at)
+        VALUES ($1, $2, $3, $4, 'processed', $5, $6, NOW(), NOW())
+        "#,
+    )
+    .bind(input.file_id)
+    .bind(&input.filename)
+    .bind(purpose)
+    .bind(input.size_bytes)
+    .bind(&input.uploaded_by)
+    .bind(input.api_key_id)
+    .execute(pool)
+    .await
+    .map_err(|e| Error::Database(e.into()))?;
+    Ok(())
+}
+
+async fn ingest_blob_to_templates(
+    db_pool: &sqlx::PgPool,
+    config: &crate::config::Config,
+    input: &IngestFileInput,
+) -> std::result::Result<(), underway::task::Error> {
+    use underway::task::Error as TaskError;
+
+    let Some(store_config) = config.batches.files.object_store.as_ref() else {
+        return Err(TaskError::Fatal("object store config missing".to_string()));
+    };
+
+    let client = BlobStorageClient::new(store_config)
+        .await
+        .map_err(|e| TaskError::Retryable(format!("init object storage client: {e}")))?;
+
+    set_file_ingest_status(db_pool, input.file_id, &input.object_key, "processing", None)
+        .await
+        .map_err(|e| TaskError::Retryable(format!("set ingest status processing: {e}")))?;
+
+    let bytes = client
+        .get_file_bytes(&input.object_key)
+        .await
+        .map_err(|e| TaskError::Retryable(format!("download blob for ingest: {e}")))?;
+
+    let content = std::str::from_utf8(&bytes).map_err(|e| TaskError::Fatal(format!("Invalid UTF-8 in upload: {e}")))?;
+
+    let mut conn = db_pool
+        .acquire()
+        .await
+        .map_err(|e| TaskError::Retryable(format!("acquire db conn for ingest: {e}")))?;
+
+    let mut deployments_repo = Deployments::new(&mut conn);
+    let target_user_id = Uuid::parse_str(&input.uploaded_by)
+        .map_err(|e| TaskError::Fatal(format!("invalid uploaded_by owner id on file ingest: {e}")))?;
+    let filter = DeploymentFilter::new(0, i64::MAX)
+        .with_accessible_to(target_user_id)
+        .with_statuses(vec![ModelStatus::Active])
+        .with_deleted(false);
+    let accessible_deployments = deployments_repo
+        .list(&filter)
+        .await
+        .map_err(|e| TaskError::Retryable(format!("query accessible deployments for ingest: {e}")))?;
+    let accessible_models: HashMap<_, _> =
+        accessible_deployments.into_iter().map(|d| (d.alias, d.model_type)).collect();
+    drop(conn);
+
+    let mut templates: Vec<_> = Vec::new();
+    let mut non_empty_lines: usize = 0;
+    for (idx, line) in content.lines().enumerate() {
+        let line_no = idx + 1;
+        let trimmed = line.trim();
+        if trimmed.is_empty() {
+            continue;
+        }
+        non_empty_lines += 1;
+        if config.limits.files.max_requests_per_file > 0 && non_empty_lines > config.limits.files.max_requests_per_file {
+            return Err(TaskError::Fatal(format!(
+                "Line {}: file contains too many requests (>{})",
+                line_no, config.limits.files.max_requests_per_file
+            )));
+        }
+
+        let openai_req: OpenAIBatchRequest = serde_json::from_str(trimmed)
+            .map_err(|e| TaskError::Fatal(format!("Invalid JSON on line {}: {}", line_no, e)))?;
+        let template = openai_req
+            .to_internal(
+                &input.endpoint,
+                input.api_key.clone(),
+                &accessible_models,
+                &config.batches.allowed_url_paths,
+            )
+            .map_err(|e| TaskError::Fatal(format!("Line {}: {}", line_no, e)))?;
+
+        if config.limits.requests.max_body_size > 0 && template.body.len() as u64 > config.limits.requests.max_body_size {
+            return Err(TaskError::Fatal(format!(
+                "Line {}: Request body is {} bytes, exceeds max {} bytes",
+                line_no,
+                template.body.len(),
+                config.limits.requests.max_body_size
+            )));
+        }
+        templates.push(template);
+    }
+
+    if templates.is_empty() {
+        return Err(TaskError::Fatal("File contains no valid request templates".to_string()));
+    }
+
+    let mut tx = db_pool
+        .begin()
+        .await
+        .map_err(|e| TaskError::Retryable(format!("begin template ingest tx: {e}")))?;
+
+    for chunk in templates.chunks(config.batches.files.batch_insert_size.max(1)) {
+        for template in chunk {
+            let template_id = Uuid::new_v4();
+            sqlx::query(
+                r#"
+                INSERT INTO fusillade.request_templates (id, file_id, model, api_key, endpoint, path, body, custom_id, method)
+                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+                "#,
+            )
+            .bind(template_id)
+            .bind(input.file_id)
+            .bind(&template.model)
+            .bind(&template.api_key)
+            .bind(&template.endpoint)
+            .bind(&template.path)
+            .bind(&template.body)
+            .bind(&template.custom_id)
+            .bind(&template.method)
+            .execute(&mut *tx)
+            .await
+            .map_err(|e| TaskError::Retryable(format!("insert request template: {e}")))?;
+        }
+    }
+
+    tx.commit()
+        .await
+        .map_err(|e| TaskError::Retryable(format!("commit template ingest tx: {e}")))?;
+
+    Ok(())
+}
+
+pub async fn build_ingest_file_job(
+    pool: sqlx::PgPool,
+    state: crate::tasks::TaskState,
+) -> anyhow::Result<Job<IngestFileInput, TaskState>> {
+    use underway::Job;
+    use underway::job::To;
+    use underway::task::Error as TaskError;
+
+    Job::<IngestFileInput, TaskState>::builder()
+        .state(state)
+        .step(|cx, input: IngestFileInput| async move {
+            let config = cx.state.config.snapshot();
+            match ingest_blob_to_templates(&cx.state.db_pool, &config, &input).await {
+                Ok(()) => {
+                    set_file_ingest_status(&cx.state.db_pool, input.file_id, &input.object_key, "processed", None)
+                        .await
+                        .map_err(|e| TaskError::Retryable(format!("set ingest status processed: {e}")))?;
+                    tracing::info!(file_id = %input.file_id, "File ingest completed");
+                    To::done()
+                }
+                Err(TaskError::Fatal(msg)) => {
+                    let _ =
+                        set_file_ingest_status(&cx.state.db_pool, input.file_id, &input.object_key, "failed", Some(&msg)).await;
+                    Err(TaskError::Fatal(msg))
+                }
+                Err(TaskError::Retryable(msg)) => {
+                    let _ = set_file_ingest_status(
+                        &cx.state.db_pool,
+                        input.file_id,
+                        &input.object_key,
+                        "pending",
+                        Some(&format!("retrying after transient error: {msg}")),
+                    )
+                    .await;
+                    Err(TaskError::Retryable(msg))
+                }
+                Err(other) => Err(other),
+            }
+        })
+        .name("ingest-file")
+        .pool(pool)
+        .build()
+        .await
+        .map_err(Into::into)
+}
+
 /// Helper function to create a stream of FileStreamItem from multipart upload
 /// This handles the entire multipart parsing inside the stream
 #[tracing::instrument(skip(multipart, req_ctx), fields(config.max_file_size, config.max_requests_per_file, uploaded_by = ?uploaded_by, endpoint = %req_ctx.endpoint, config.buffer_size))]
@@ -803,7 +1047,7 @@
         description = "Multipart form with `file` (the JSONL file) and `purpose` (must be `batch`)."
     ),
     responses(
-        (status = 201, description = "File uploaded and validated successfully.", body = FileResponse),
+        (status = 201, description = "File accepted for processing.", body = FileResponse),
         (status = 400, description = "Invalid file format, malformed JSON, missing required fields, etc."),
         (status = 403, description = "Model referenced in the file is not configured or not accessible to your account."),
         (status = 413, description = "File exceeds the maximum allowed size."),
@@ -850,12 +1094,6 @@ pub async fn upload_file(
         message: format!("Invalid multipart request: {}", e),
     })?;
 
-    let stream_config = FileStreamConfig {
-        max_file_size: config.limits.files.max_file_size,
-        max_requests_per_file: config.limits.files.max_requests_per_file,
-        max_request_body_size: config.limits.requests.max_body_size,
-        buffer_size: config.batches.files.upload_buffer_size,
-    };
 
     // When in org context, attribute file ownership to the org (not the individual user).
     // Also used for the hidden API key lookup below.
 let target_user_id = current_user.active_organization.unwrap_or(current_user.id);
@@ -886,44 +1124,140 @@
     // drop conn so it isn't persisted for entire upload process
     drop(conn);
 
-    // Create a stream that parses the multipart upload and yields FileStreamItems
-    let (file_stream, error_slot) = create_file_stream(
-        multipart,
-        stream_config,
-        uploaded_by,
-        FileRequestContext {
+    let created_file_id = if config.batches.files.storage_backend == FileStorageBackend::ObjectStore {
+        let object_store_cfg = config.batches.files.object_store.as_ref().ok_or_else(|| Error::Internal {
+            operation: "object store backend selected but no object_store config found".to_string(),
+        })?;
+        let blob = BlobStorageClient::new(object_store_cfg).await?;
+
+        let mut multipart = multipart;
+        let mut filename: Option<String> = None;
+        let mut purpose: Option<String> = None;
+        let mut total_size: u64 = 0;
+        let file_id = Uuid::new_v4();
+        let tmp_path = format!("/tmp/dwctl-upload-{}.jsonl", file_id);
+        let mut tmp_file = tokio::fs::File::create(&tmp_path).await.map_err(|e| Error::Internal {
+            operation: format!("create temp upload file: {e}"),
+        })?;
+        let mut saw_file = false;
+
+        while let Some(field) = multipart.next_field().await.map_err(|e| Error::BadRequest {
+            message: format!("Invalid multipart upload: {}", e),
+        })? {
+            match field.name().unwrap_or("") {
+                "purpose" => {
+                    purpose = Some(field.text().await.unwrap_or_default());
+                }
+                "file" => {
+                    saw_file = true;
+                    filename = field.file_name().map(|s| s.to_string());
+                    let mut field = field;
+                    while let Some(chunk) = field.chunk().await.map_err(|e| Error::BadRequest {
+                        message: format!("Read upload chunk: {e}"),
+                    })? {
+                        total_size = total_size.saturating_add(chunk.len() as u64);
+                        if max_file_size > 0 && total_size > max_file_size {
+                            return Err(Error::PayloadTooLarge {
+                                message: format!("File exceeds the maximum allowed size of {} bytes", max_file_size),
+                            });
+                        }
+                        tmp_file.write_all(&chunk).await.map_err(|e| Error::Internal {
+                            operation: format!("write temp upload bytes: {e}"),
+                        })?;
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        if !saw_file {
+            return Err(Error::BadRequest {
+                message: "No file field found in multipart upload".to_string(),
+            });
+        }
+        if total_size == 0 {
+            return Err(Error::BadRequest {
+                message: "File contains no valid request templates".to_string(),
+            });
+        }
+        if purpose.as_deref() != Some("batch") {
+            return Err(Error::BadRequest {
+                message: "Invalid purpose. Only 'batch' is supported.".to_string(),
+            });
+        }
+
+        tmp_file.flush().await.map_err(|e| Error::Internal {
+            operation: format!("flush temp upload file: {e}"),
+        })?;
+
+        let object_key = blob.object_key_for_file(file_id);
+        blob.put_file_from_path(&object_key, &tmp_path, "application/x-ndjson")
+            .await?;
+        let _ = tokio::fs::remove_file(&tmp_path).await;
+
+        let ingest = IngestFileInput {
+            file_id,
+            object_key: object_key.clone(),
             endpoint,
             api_key: user_api_key,
-            accessible_models,
-            allowed_url_paths: config.batches.allowed_url_paths.clone(),
-        },
-        Some(api_key_id),
-    );
-
-    // Create file via request manager with streaming
-    let created_file_result = state.request_manager.create_file_stream(file_stream).await.map_err(|e| {
-        // Check if WE aborted (control-layer error in slot)
-        // Handle poisoned mutex gracefully - the data is still valid
-        let upload_err = match error_slot.lock() {
-            Ok(mut guard) => guard.take(),
-            Err(poisoned) => poisoned.into_inner().take(),
+            api_key_id,
+            uploaded_by: uploaded_by.clone().unwrap_or_default(),
+            filename: filename.clone().unwrap_or_else(|| "upload.jsonl".to_string()),
+            size_bytes: i64::try_from(total_size).unwrap_or(i64::MAX),
         };
-        if let Some(upload_err) = upload_err {
-            tracing::warn!("File upload aborted with error: {:?}", upload_err);
-            return upload_err.into_http_error();
-        }
-        // Otherwise it's a fusillade error
-        tracing::warn!("Fusillade error during file upload: {:?}", e);
-        match e {
-            fusillade::FusilladeError::ValidationError(msg) => Error::BadRequest { message: msg },
-            _ => Error::Internal {
-                operation: format!("create file: {}", e),
+
+        insert_file_placeholder(state.db.write(), &ingest, "batch").await?;
+        set_file_ingest_status(state.db.write(), file_id, &object_key, "pending", None).await?;
+
+        state
+            .task_runner
+            .ingest_file_job
+            .enqueue(&ingest)
+            .await
+            .map_err(|e| Error::Internal {
+                operation: format!("enqueue file ingest job: {e}"),
+            })?;
+        file_id.into()
+    } else {
+        let stream_config = FileStreamConfig {
+            max_file_size: config.limits.files.max_file_size,
+            max_requests_per_file: config.limits.files.max_requests_per_file,
+            max_request_body_size: config.limits.requests.max_body_size,
+            buffer_size: config.batches.files.upload_buffer_size,
+        };
+        let (file_stream, error_slot) = create_file_stream(
+            multipart,
+            stream_config,
+            uploaded_by,
+            FileRequestContext {
+                endpoint,
+                api_key: user_api_key,
+                accessible_models,
+                allowed_url_paths: config.batches.allowed_url_paths.clone(),
             },
-        }
-    })?;
+            Some(api_key_id),
+        );
+
+        let created_file_result = state.request_manager.create_file_stream(file_stream).await.map_err(|e| {
+            let upload_err = match error_slot.lock() {
+                Ok(mut guard) => guard.take(),
+                Err(poisoned) => poisoned.into_inner().take(),
+            };
+            if let Some(upload_err) = upload_err {
+                tracing::warn!("File upload aborted with error: {:?}", upload_err);
+                return upload_err.into_http_error();
+            }
+            tracing::warn!("Fusillade error during file upload: {:?}", e);
+            match e {
+                fusillade::FusilladeError::ValidationError(msg) => Error::BadRequest { message: msg },
+                _ => Error::Internal {
+                    operation: format!("create file: {}", e),
+                },
+            }
+        })?;
 
-    let created_file_id = resolve_upload_stream_result(created_file_result, &error_slot)?;
+        resolve_upload_stream_result(created_file_result, &error_slot)?
+    };
 
     tracing::debug!("File {} uploaded successfully", created_file_id);
 
diff --git a/dwctl/src/blob_storage.rs b/dwctl/src/blob_storage.rs
new file mode 100644
index 000000000..f23fe1d5b
--- /dev/null
+++ b/dwctl/src/blob_storage.rs
@@ -0,0 +1,102 @@
+use std::time::Duration;
+
+use aws_config::BehaviorVersion;
+use aws_config::meta::region::RegionProviderChain;
+use aws_credential_types::{Credentials, provider::SharedCredentialsProvider};
+use aws_sdk_s3::Client;
+use aws_sdk_s3::config::{Region, timeout::TimeoutConfig};
+use aws_sdk_s3::primitives::ByteStream;
+use uuid::Uuid;
+
+use crate::config::{ObjectStoreProvider, ObjectStoreConfig};
+use crate::errors::{Error, Result};
+
+#[derive(Clone)]
+pub struct BlobStorageClient {
+    client: Client,
+    bucket: String,
+    prefix: String,
+}
+
+impl BlobStorageClient {
+    pub async fn new(config: &ObjectStoreConfig) -> Result<Self> {
+        match config.provider {
+            ObjectStoreProvider::S3Compatible => {}
+        }
+
+        let creds = Credentials::new(
+            config.access_key_id.clone(),
+            config.secret_access_key.clone(),
+            config.session_token.clone(),
+            None,
+            "dwctl-object-store",
+        );
+
+        let timeout_config = TimeoutConfig::builder()
+            .connect_timeout(Duration::from_millis(config.connect_timeout_ms))
+            .operation_timeout(Duration::from_millis(config.request_timeout_ms))
+            .build();
+
+        let sdk_config = aws_config::defaults(BehaviorVersion::latest())
+            .region(RegionProviderChain::first_try(Region::new(config.region.clone())))
+            .credentials_provider(SharedCredentialsProvider::new(creds))
+            .endpoint_url(config.endpoint.clone())
+            .timeout_config(timeout_config)
+            .load()
+            .await;
+
+        let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config)
+            .force_path_style(config.path_style)
+            .build();
+
+        Ok(Self {
+            client: Client::from_conf(s3_config),
+            bucket: config.bucket.clone(),
+            prefix: config.prefix.clone(),
+        })
+    }
+
+    pub fn object_key_for_file(&self, file_id: Uuid) -> String {
+        format!("{}{file_id}.jsonl", self.prefix)
+    }
+
+    pub async fn put_file_from_path(&self, key: &str, path: &str, content_type: &str) -> Result<()> {
+        let body = ByteStream::from_path(std::path::Path::new(path))
+            .await
+            .map_err(|e| Error::Internal {
+                operation: format!("open upload file for object storage: {e}"),
+            })?;
+
+        self.client
+            .put_object()
+            .bucket(&self.bucket)
+            .key(key)
+            .content_type(content_type)
+            .body(body)
+            .send()
+            .await
+            .map_err(|e| Error::Internal {
+                operation: format!("put object to blob storage: {e}"),
+            })?;
+        Ok(())
+    }
+
+    pub async fn get_file_bytes(&self, key: &str) -> Result<Vec<u8>> {
+        let obj = self
+            .client
+            .get_object()
+            .bucket(&self.bucket)
+            .key(key)
+            .send()
+            .await
+            .map_err(|e| Error::Internal {
+                operation: format!("get object from blob storage: {e}"),
+            })?;
+
+        let bytes = obj.body.collect().await.map_err(|e| Error::Internal {
+            operation: format!("read blob object body: {e}"),
+        })?;
+
+        Ok(bytes.into_bytes().to_vec())
+    }
+}
diff --git a/dwctl/src/config.rs b/dwctl/src/config.rs
index 163fbc430..afcaadd68 100644
--- a/dwctl/src/config.rs
+++ b/dwctl/src/config.rs
@@ -798,6 +798,13 @@ pub struct FilesConfig {
     pub download_buffer_size: usize,
     /// Number of templates to insert in each batch during file upload (default: 5000)
     pub batch_insert_size: usize,
+    /// Storage backend for batch input file uploads.
+    ///
+    /// - "postgres": legacy mode, parse/upload directly into fusillade request_templates.
+    /// - "object_store": write raw JSONL to object storage first, then ingest asynchronously.
+    pub storage_backend: FileStorageBackend,
+    /// Optional object storage configuration (required when storage_backend=object_store).
+    pub object_store: Option<ObjectStoreConfig>,
 }
 
 impl Default for FilesConfig {
@@ -809,10 +816,60 @@
             upload_buffer_size: 100,
             download_buffer_size: 100,
             batch_insert_size: 5000,
+            storage_backend: FileStorageBackend::Postgres,
+            object_store: None,
         }
     }
 }
 
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum FileStorageBackend {
+    Postgres,
+    ObjectStore,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+#[serde(default, deny_unknown_fields)]
+pub struct ObjectStoreConfig {
+    pub provider: ObjectStoreProvider,
+    pub endpoint: String,
+    pub bucket: String,
+    pub region: String,
+    pub access_key_id: String,
+    pub secret_access_key: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub session_token: Option<String>,
+    pub path_style: bool,
+    pub prefix: String,
+    pub connect_timeout_ms: u64,
+    pub request_timeout_ms: u64,
+}
+
+impl Default for ObjectStoreConfig {
+    fn default() -> Self {
+        Self {
+            provider: ObjectStoreProvider::S3Compatible,
+            endpoint: String::new(),
+            bucket: String::new(),
+            region: "us-east-1".to_string(),
+            access_key_id: String::new(),
+            secret_access_key: String::new(),
+            session_token: None,
+            path_style: true,
+            prefix: "uploads/".to_string(),
+            connect_timeout_ms: 5000,
+            request_timeout_ms: 120000,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum ObjectStoreProvider {
+    S3Compatible,
+}
+
 /// Resource limits for protecting system capacity.
 ///
 /// These limits help prevent resource exhaustion under high load by rejecting
@@ -1951,6 +2008,39 @@
             });
         }
 
+        if self.batches.files.storage_backend == FileStorageBackend::ObjectStore {
+            let store = self.batches.files.object_store.as_ref().ok_or_else(|| Error::Internal {
+                operation: "Config validation: batches.files.object_store must be set when storage_backend=object_store."
+ .to_string(), + })?; + + if store.endpoint.trim().is_empty() { + return Err(Error::Internal { + operation: "Config validation: batches.files.object_store.endpoint cannot be empty.".to_string(), + }); + } + if store.bucket.trim().is_empty() { + return Err(Error::Internal { + operation: "Config validation: batches.files.object_store.bucket cannot be empty.".to_string(), + }); + } + if store.region.trim().is_empty() { + return Err(Error::Internal { + operation: "Config validation: batches.files.object_store.region cannot be empty.".to_string(), + }); + } + if store.access_key_id.trim().is_empty() { + return Err(Error::Internal { + operation: "Config validation: batches.files.object_store.access_key_id cannot be empty.".to_string(), + }); + } + if store.secret_access_key.trim().is_empty() { + return Err(Error::Internal { + operation: "Config validation: batches.files.object_store.secret_access_key cannot be empty.".to_string(), + }); + } + } + // Validate file size limits are sensible (0 = unlimited is allowed but not recommended) // Note: max_file_size is now in limits.files, not batches.files @@ -2276,6 +2366,46 @@ secret_key: "test-secret-key" assert!(result.unwrap_err().to_string().contains("upload_buffer_size cannot be 0")); } + #[test] + fn test_object_store_requires_config_when_enabled() { + let mut config = Config::default(); + config.auth.native.enabled = true; + config.secret_key = Some("test-secret-key".to_string()); + config.batches.enabled = true; + config.batches.files.storage_backend = FileStorageBackend::ObjectStore; + config.batches.files.object_store = None; + + let result = config.validate(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("object_store must be set when storage_backend=object_store") + ); + } + + #[test] + fn test_object_store_empty_endpoint_rejected() { + let mut config = Config::default(); + config.auth.native.enabled = true; + config.secret_key = Some("test-secret-key".to_string()); + config.batches.enabled = true; + config.batches.files.storage_backend = FileStorageBackend::ObjectStore; + config.batches.files.object_store = Some(ObjectStoreConfig { + endpoint: "".to_string(), + bucket: "bucket".to_string(), + region: "us-east-1".to_string(), + access_key_id: "key".to_string(), + secret_access_key: "secret".to_string(), + ..Default::default() + }); + + let result = config.validate(); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("object_store.endpoint cannot be empty")); + } + #[test] fn test_download_buffer_size_zero_validation() { let mut config = Config::default(); diff --git a/dwctl/src/lib.rs b/dwctl/src/lib.rs index 91e165d26..c0d76833b 100644 --- a/dwctl/src/lib.rs +++ b/dwctl/src/lib.rs @@ -146,6 +146,7 @@ pub mod config; mod config_watcher; mod crypto; pub mod db; +mod blob_storage; mod email; mod error_enrichment; pub mod errors; @@ -2174,6 +2175,8 @@ async fn setup_background_services( // Build the underway task runner for background jobs (batch population, etc.) 
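// Editorial sketch, not part of the patch: the smallest configuration that
// passes the object_store validation above, mirroring the two tests. Values
// are placeholders; auth.native.enabled and secret_key are set here only
// because the validation tests above set them too.
let mut config = Config::default();
config.auth.native.enabled = true;
config.secret_key = Some("test-secret-key".to_string());
config.batches.enabled = true;
config.batches.files.storage_backend = FileStorageBackend::ObjectStore;
config.batches.files.object_store = Some(ObjectStoreConfig {
    endpoint: "http://localhost:9000".to_string(),
    bucket: "dwctl-uploads".to_string(),
    access_key_id: "minioadmin".to_string(),
    secret_access_key: "minioadmin".to_string(),
    ..Default::default() // region defaults to "us-east-1", so it is non-empty
});
assert!(config.validate().is_ok());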
let task_state = tasks::TaskState { request_manager: request_manager.clone(), + db_pool: pool.clone(), + config: SharedConfig::new(config.clone()), }; let task_runner = Arc::new(tasks::TaskRunner::new(underway_pool, task_state).await?); for handle in task_runner.start(shutdown_token.clone()) { diff --git a/dwctl/src/tasks.rs b/dwctl/src/tasks.rs index 3e80caeae..4f56ea5a9 100644 --- a/dwctl/src/tasks.rs +++ b/dwctl/src/tasks.rs @@ -14,6 +14,8 @@ use tokio_util::sync::CancellationToken; use underway::Job; use crate::api::handlers::batches::{CreateBatchInput, build_create_batch_job}; +use crate::api::handlers::files::{IngestFileInput, build_ingest_file_job}; +use crate::SharedConfig; /// Shared state available to all task step closures. /// @@ -22,6 +24,8 @@ use crate::api::handlers::batches::{CreateBatchInput, build_create_batch_job}; #[derive(Clone)] pub struct TaskState { pub request_manager: Arc<PostgresRequestManager<ReqwestHttpClient>>, + pub db_pool: PgPool, + pub config: SharedConfig, } /// Manages underway jobs and worker lifecycle. @@ -30,6 +34,7 @@ pub struct TaskState { /// work; the worker processes jobs in the background. pub struct TaskRunner { pub create_batch_job: Job<CreateBatchInput, TaskState>, + pub ingest_file_job: Job<IngestFileInput, TaskState>, } impl TaskRunner { @@ -37,8 +42,12 @@ impl TaskRunner { /// /// Call [`start`] to begin processing. pub async fn new(pool: PgPool, state: TaskState) -> Result<Self> { - let create_batch_job = build_create_batch_job(pool, state).await?; - Ok(Self { create_batch_job }) + let create_batch_job = build_create_batch_job(pool.clone(), state.clone()).await?; + let ingest_file_job = build_ingest_file_job(pool, state).await?; + Ok(Self { + create_batch_job, + ingest_file_job, + }) } /// Start the underway worker with the given shutdown token. @@ -48,12 +57,22 @@ impl TaskRunner { /// Interrupted tasks are retried on next startup (state tracked in Postgres). pub fn start(&self, shutdown_token: CancellationToken) -> Vec<tokio::task::JoinHandle<()>> { let mut worker = self.create_batch_job.worker(); - worker.set_shutdown_token(shutdown_token); + worker.set_shutdown_token(shutdown_token.clone()); - vec![tokio::spawn(async move { - if let Err(e) = worker.run().await { - tracing::error!(error = %e, "Underway worker error"); - } - })] + let mut ingest_worker = self.ingest_file_job.worker(); + ingest_worker.set_shutdown_token(shutdown_token.clone()); + + vec![ + tokio::spawn(async move { + if let Err(e) = worker.run().await { + tracing::error!(error = %e, "Underway worker error"); + } + }), + tokio::spawn(async move { + if let Err(e) = ingest_worker.run().await { + tracing::error!(error = %e, "Underway ingest worker error"); + } + }), + ] } } diff --git a/dwctl/src/test/mod.rs b/dwctl/src/test/mod.rs index 9f087de94..42864408f 100644 --- a/dwctl/src/test/mod.rs +++ b/dwctl/src/test/mod.rs @@ -871,6 +871,8 @@ async fn test_request_logging_disabled(pool: PgPool) { underway::run_migrations(&pool).await.expect("Failed to run underway migrations"); let task_state = TaskState { request_manager: request_manager.clone(), + db_pool: pool.clone(), + config: crate::SharedConfig::new(config.clone()), }; let task_runner = std::sync::Arc::new( crate::tasks::TaskRunner::new(pool.clone(), task_state) @@ -1220,6 +1222,8 @@ async fn test_build_router_with_metrics_disabled(pool: PgPool) { underway::run_migrations(&pool).await.expect("Failed to run underway migrations"); let task_state = crate::tasks::TaskState { request_manager: request_manager.clone(), + db_pool: pool.clone(), + config: crate::SharedConfig::new(config.clone()), }; let task_runner = std::sync::Arc::new( crate::tasks::TaskRunner::new(pool.clone(), task_state) @@ -1261,6 +1265,8 @@ async fn test_build_router_with_metrics_enabled(pool: PgPool) { underway::run_migrations(&pool).await.expect("Failed to run underway migrations"); let task_state = TaskState { request_manager: request_manager.clone(), + db_pool: pool.clone(), + config: crate::SharedConfig::new(config.clone()), }; let task_runner = std::sync::Arc::new( crate::tasks::TaskRunner::new(pool.clone(), task_state) diff --git a/dwctl/src/test/utils.rs b/dwctl/src/test/utils.rs index ca46e5d6c..9c1ce7aef 100644 --- a/dwctl/src/test/utils.rs +++ b/dwctl/src/test/utils.rs @@ -42,6 +42,8 @@ pub async fn create_test_app_state_with_config(pool: PgPool, config: crate::conf underway::run_migrations(&pool).await.expect("Failed to run underway migrations"); let task_state = crate::tasks::TaskState { request_manager: request_manager.clone(), + db_pool: pool.clone(), + config: crate::SharedConfig::new(config.clone()), }; let task_runner = std::sync::Arc::new( crate::tasks::TaskRunner::new(pool, task_state) @@ -91,6 +93,8 @@ pub async fn create_test_app_state_with_fusillade(pool: PgPool, config: crate::c underway::run_migrations(&pool).await.expect("Failed to run underway migrations"); let task_state = crate::tasks::TaskState { request_manager: request_manager.clone(), + db_pool: pool.clone(), + config: crate::SharedConfig::new(config.clone()), }; let task_runner = std::sync::Arc::new( crate::tasks::TaskRunner::new(pool, task_state) From a427be9f761dfcf0727af5883788a9d0fbc55a05 Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 11:49:12 +0100 Subject: [PATCH 2/9] chore: lint --- dwctl/src/api/handlers/batches.rs | 14 ++++++-------- dwctl/src/api/handlers/files.rs | 20 +++++++-------------
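// Editorial sketch, not part of the patch: the caller side of the worker
// fan-out added above. Each underway worker holds a clone of one
// CancellationToken, so a single cancel() stops both loops; the handles come
// from TaskRunner::start, and names here are illustrative.
let shutdown = CancellationToken::new();
let handles = task_runner.start(shutdown.clone());
// ... later, on a shutdown signal:
shutdown.cancel();
for handle in handles {
    let _ = handle.await; // join both the batch worker and the ingest worker
}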
dwctl/src/blob_storage.rs | 2 +- dwctl/src/config.rs | 3 +-- dwctl/src/lib.rs | 2 +- dwctl/src/tasks.rs | 2 +- 6 files changed, 17 insertions(+), 26 deletions(-) diff --git a/dwctl/src/api/handlers/batches.rs b/dwctl/src/api/handlers/batches.rs index 9403dab8f..e35abb842 100644 --- a/dwctl/src/api/handlers/batches.rs +++ b/dwctl/src/api/handlers/batches.rs @@ -362,14 +362,12 @@ pub async fn create_batch( }); } "failed" => { - let msg = sqlx::query_scalar::<_, Option<String>>( - "SELECT error_message FROM file_ingest_jobs WHERE file_id = $1", - ) - .bind(file_id) - .fetch_one(state.db.write()) - .await - .map_err(|e| Error::Database(e.into()))? - .unwrap_or_else(|| "File ingestion failed".to_string()); + let msg = sqlx::query_scalar::<_, Option<String>>("SELECT error_message FROM file_ingest_jobs WHERE file_id = $1") + .bind(file_id) + .fetch_one(state.db.write()) + .await + .map_err(|e| Error::Database(e.into()))? + .unwrap_or_else(|| "File ingestion failed".to_string()); return Err(Error::BadRequest { message: msg }); } _ => {} diff --git a/dwctl/src/api/handlers/files.rs b/dwctl/src/api/handlers/files.rs index 0e13224f8..9cb9981b4 100644 --- a/dwctl/src/api/handlers/files.rs +++ b/dwctl/src/api/handlers/files.rs @@ -476,11 +476,7 @@ async fn set_file_ingest_status( Ok(()) } -async fn insert_file_placeholder( - pool: &sqlx::PgPool, - input: &IngestFileInput, - purpose: &str, -) -> Result<()> { +async fn insert_file_placeholder(pool: &sqlx::PgPool, input: &IngestFileInput, purpose: &str) -> Result<()> { sqlx::query( r#" INSERT INTO fusillade.files (id, name, purpose, size_bytes, status, uploaded_by, api_key_id, created_at, updated_at) @@ -531,8 +527,8 @@ async fn ingest_blob_to_templates( .map_err(|e| TaskError::Retryable(format!("acquire db conn for ingest: {e}")))?; let mut deployments_repo = Deployments::new(&mut conn); - let target_user_id = Uuid::parse_str(&input.uploaded_by) - .map_err(|e| TaskError::Fatal(format!("invalid uploaded_by owner id on file ingest: {e}")))?; + let target_user_id = + Uuid::parse_str(&input.uploaded_by).map_err(|e| TaskError::Fatal(format!("invalid uploaded_by owner id on file ingest: {e}")))?; let filter = DeploymentFilter::new(0, i64::MAX) .with_accessible_to(target_user_id) .with_statuses(vec![ModelStatus::Active]) @@ -561,8 +557,8 @@ async fn ingest_blob_to_templates( ))); } - let openai_req: OpenAIBatchRequest = serde_json::from_str(trimmed) - .map_err(|e| TaskError::Fatal(format!("Invalid JSON on line {}: {}", line_no, e)))?; + let openai_req: OpenAIBatchRequest = + serde_json::from_str(trimmed).map_err(|e| TaskError::Fatal(format!("Invalid JSON on line {}: {}", line_no, e)))?; let template = openai_req .to_internal( &input.endpoint, @@ -644,8 +640,7 @@ pub async fn build_ingest_file_job { - let _ = - set_file_ingest_status(&cx.state.db_pool, input.file_id, &input.object_key, "failed", Some(&msg)).await; + let _ = set_file_ingest_status(&cx.state.db_pool, input.file_id, &input.object_key, "failed", Some(&msg)).await; Err(TaskError::Fatal(msg)) } Err(TaskError::Retryable(msg)) => { @@ -1191,8 +1186,7 @@ pub async fn upload_file( })?; let object_key = blob.object_key_for_file(file_id); - blob.put_file_from_path(&object_key, &tmp_path, "application/x-ndjson") - .await?; + blob.put_file_from_path(&object_key, &tmp_path, "application/x-ndjson").await?; let _ = tokio::fs::remove_file(&tmp_path).await; let ingest = IngestFileInput { diff --git a/dwctl/src/blob_storage.rs b/dwctl/src/blob_storage.rs index f23fe1d5b..0c0a11e9a 100644 --- a/dwctl/src/blob_storage.rs +++ 
b/dwctl/src/blob_storage.rs @@ -8,7 +8,7 @@ use aws_sdk_s3::config::{Region, timeout::TimeoutConfig}; use aws_sdk_s3::primitives::ByteStream; use uuid::Uuid; -use crate::config::{ObjectStoreProvider, ObjectStoreConfig}; +use crate::config::{ObjectStoreConfig, ObjectStoreProvider}; use crate::errors::{Error, Result}; #[derive(Clone)] diff --git a/dwctl/src/config.rs b/dwctl/src/config.rs index afcaadd68..e14a08669 100644 --- a/dwctl/src/config.rs +++ b/dwctl/src/config.rs @@ -2010,8 +2010,7 @@ impl Config { if self.batches.files.storage_backend == FileStorageBackend::ObjectStore { let store = self.batches.files.object_store.as_ref().ok_or_else(|| Error::Internal { - operation: "Config validation: batches.files.object_store must be set when storage_backend=object_store." - .to_string(), + operation: "Config validation: batches.files.object_store must be set when storage_backend=object_store.".to_string(), })?; if store.endpoint.trim().is_empty() { diff --git a/dwctl/src/lib.rs b/dwctl/src/lib.rs index c0d76833b..5928f689d 100644 --- a/dwctl/src/lib.rs +++ b/dwctl/src/lib.rs @@ -142,11 +142,11 @@ fn install_crypto_provider() { pub mod api; pub mod auth; +mod blob_storage; pub mod config; mod config_watcher; mod crypto; pub mod db; -mod blob_storage; mod email; mod error_enrichment; pub mod errors; diff --git a/dwctl/src/tasks.rs b/dwctl/src/tasks.rs index 4f56ea5a9..47155f73f 100644 --- a/dwctl/src/tasks.rs +++ b/dwctl/src/tasks.rs @@ -13,9 +13,9 @@ use sqlx_pool_router::PoolProvider; use tokio_util::sync::CancellationToken; use underway::Job; +use crate::SharedConfig; use crate::api::handlers::batches::{CreateBatchInput, build_create_batch_job}; use crate::api::handlers::files::{IngestFileInput, build_ingest_file_job}; -use crate::SharedConfig; /// Shared state available to all task step closures. 
/// From 2c92de20ea1739c26d2a6bb8d58c8ed9a14004cd Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 12:24:15 +0100 Subject: [PATCH 3/9] feat: move fusillade-scoped db writes to fusillade --- Cargo.lock | 2 - dwctl/Cargo.toml | 2 +- dwctl/src/api/handlers/files.rs | 78 +++++++++++---------------------- 3 files changed, 26 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8d7b183af..21e1c0fd3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2350,8 +2350,6 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "fusillade" version = "14.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad7eaf8711b77ceb075ba924299afd937c8ae717ffe94d2751786aad531c6edf" dependencies = [ "anyhow", "async-trait", diff --git a/dwctl/Cargo.toml b/dwctl/Cargo.toml index 47c8c0b2c..01684a768 100644 --- a/dwctl/Cargo.toml +++ b/dwctl/Cargo.toml @@ -19,7 +19,7 @@ embedded-db = ["dep:postgresql_embedded"] [dependencies] axum = { version = "0.8", features = ["multipart"] } -fusillade = { version = "14.2.1" } +fusillade = { version = "14.2.1", path = "../../fusillade" } tokio = { version = "1.0", features = ["full"] } tokio-stream = { version = "0.1", features = ["sync"] } tokio-util = "0.7" diff --git a/dwctl/src/api/handlers/files.rs b/dwctl/src/api/handlers/files.rs index 9cb9981b4..a443ed656 100644 --- a/dwctl/src/api/handlers/files.rs +++ b/dwctl/src/api/handlers/files.rs @@ -476,26 +476,8 @@ async fn set_file_ingest_status( Ok(()) } -async fn insert_file_placeholder(pool: &sqlx::PgPool, input: &IngestFileInput, purpose: &str) -> Result<()> { - sqlx::query( - r#" - INSERT INTO fusillade.files (id, name, purpose, size_bytes, status, uploaded_by, api_key_id, created_at, updated_at) - VALUES ($1, $2, $3, $4, 'processed', $5, $6, NOW(), NOW()) - "#, - ) - .bind(input.file_id) - .bind(&input.filename) - .bind(purpose) - .bind(input.size_bytes) - .bind(&input.uploaded_by) - .bind(input.api_key_id) - .execute(pool) - .await - .map_err(|e| Error::Database(e.into()))?; - Ok(()) -} - -async fn ingest_blob_to_templates( +async fn ingest_blob_to_templates( + request_manager: &fusillade::PostgresRequestManager<fusillade::ReqwestHttpClient>, db_pool: &sqlx::PgPool, config: &crate::config::Config, input: &IngestFileInput, @@ -583,38 +565,10 @@ async fn ingest_blob_to_templates( return Err(TaskError::Fatal("File contains no valid request templates".to_string())); } - let mut tx = db_pool - .begin() .await - .map_err(|e| TaskError::Retryable(format!("begin template ingest tx: {e}")))?; - - for chunk in templates.chunks(config.batches.files.batch_insert_size.max(1)) { - for template in chunk { - let template_id = Uuid::new_v4(); - sqlx::query( - r#" - INSERT INTO fusillade.request_templates (id, file_id, model, api_key, endpoint, path, body, custom_id, method) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - "#, - ) - .bind(template_id) - .bind(input.file_id) - .bind(&template.model) - .bind(&template.api_key) - .bind(&template.endpoint) - .bind(&template.path) - .bind(&template.body) - .bind(&template.custom_id) - .bind(&template.method) - .execute(&mut *tx) - .await - .map_err(|e| TaskError::Retryable(format!("insert request template: {e}")))?; - } - } - - tx.commit() - .await - .map_err(|e| TaskError::Retryable(format!("commit template ingest tx: {e}")))?; + request_manager + .populate_file_templates(fusillade::FileId(input.file_id), templates) .await + .map_err(|e| TaskError::Retryable(format!("populate file templates: {e}")))?; Ok(()) } @@ -631,7 +585,7 @@ 
pub async fn build_ingest_file_job { set_file_ingest_status(&cx.state.db_pool, input.file_id, &input.object_key, "processed", None) .await @@ -1200,7 +1154,25 @@ pub async fn upload_file( size_bytes: i64::try_from(total_size).unwrap_or(i64::MAX), }; - insert_file_placeholder(state.db.write(), &ingest, "batch").await?; + state + .request_manager + .create_file_placeholder(fusillade::FilePlaceholderInput { + file_id: fusillade::FileId(file_id), + metadata: fusillade::FileMetadata { + filename: Some(ingest.filename.clone()), + description: None, + purpose: Some("batch".to_string()), + expires_after_anchor: None, + expires_after_seconds: None, + size_bytes: Some(ingest.size_bytes), + uploaded_by: Some(ingest.uploaded_by.clone()), + api_key_id: Some(ingest.api_key_id), + }, + }) + .await + .map_err(|e| Error::Internal { + operation: format!("create file placeholder: {e}"), + })?; set_file_ingest_status(state.db.write(), file_id, &object_key, "pending", None).await?; state From 9e0260dfad3aec50b27ef4bbb20bc89086ef3ea8 Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 12:41:25 +0100 Subject: [PATCH 4/9] feat: download files from object storage --- dwctl/src/api/handlers/batches.rs | 29 ++++ dwctl/src/api/handlers/files.rs | 128 +++++++++++++- dwctl/src/batch_result_cache.rs | 266 ++++++++++++++++++++++++++++++ dwctl/src/blob_storage.rs | 98 +++++++++++ dwctl/src/lib.rs | 1 + 5 files changed, 513 insertions(+), 9 deletions(-) create mode 100644 dwctl/src/batch_result_cache.rs diff --git a/dwctl/src/api/handlers/batches.rs b/dwctl/src/api/handlers/batches.rs index e35abb842..f5fea139d 100644 --- a/dwctl/src/api/handlers/batches.rs +++ b/dwctl/src/api/handlers/batches.rs @@ -14,6 +14,7 @@ use crate::api::models::batches::{ }; use crate::api::models::users::CurrentUser; use crate::auth::permissions::{RequiresPermission, can_read_all_resources, has_permission, operation, resource}; +use crate::batch_result_cache; use crate::db::handlers::{Credits, Users, api_keys::ApiKeys, repository::Repository}; use crate::db::models::api_keys::ApiKeyPurpose; use crate::errors::{Error, Result}; @@ -1053,6 +1054,26 @@ pub async fn get_batch_results( let status = query.status.clone(); let requested_limit = query.pagination.limit.map(|_| query.pagination.limit() as usize); + if !still_processing { + let cache_file_id = batch.output_file_id.ok_or_else(|| Error::Internal { + operation: format!("batch {} missing output_file_id for completed result cache", batch.id), + })?; + let config = state.current_config(); + let cached_bytes = batch_result_cache::get_or_build_batch_results_jsonl( + config.as_ref(), + state.request_manager.as_ref(), + fusillade::BatchId(batch_id), + cache_file_id, + search.clone(), + status.clone(), + ) + .await?; + + let slice = batch_result_cache::slice_jsonl_bytes(&cached_bytes, offset, requested_limit); + let incomplete = slice.has_more_pages; + return Ok(batch_result_cache::jsonl_response_from_slice_with_offset(slice, offset, incomplete)); + } + if let Some(limit) = requested_limit { // Pagination case: buffer only N+1 items to check for more pages let results_stream = state @@ -1271,6 +1292,14 @@ pub async fn delete_batch( }); } + let config = state.current_config(); + if let Some(output_file_id) = batch.output_file_id { + batch_result_cache::invalidate_cached_file_results(config.as_ref(), *output_file_id).await?; + } + if let Some(error_file_id) = batch.error_file_id { + batch_result_cache::invalidate_cached_file_results(config.as_ref(), *error_file_id).await?; + } + // Delete 
the batch state .request_manager diff --git a/dwctl/src/api/handlers/files.rs b/dwctl/src/api/handlers/files.rs index a443ed656..c57f0844a 100644 --- a/dwctl/src/api/handlers/files.rs +++ b/dwctl/src/api/handlers/files.rs @@ -13,6 +13,7 @@ use crate::api::models::files::{ }; use crate::api::models::users::CurrentUser; use crate::auth::permissions::{RequiresPermission, can_read_all_resources, operation, resource}; +use crate::batch_result_cache; use crate::blob_storage::BlobStorageClient; use crate::config::FileStorageBackend; @@ -73,7 +74,7 @@ fn is_file_owner(current_user: &CurrentUser, uploaded_by: Option<&str>) -> bool /// OpenAI Batch API request format /// See: https://platform.openai.com/docs/api-reference/batch #[derive(Debug, Clone, Serialize, Deserialize)] -struct OpenAIBatchRequest { +pub(crate) struct OpenAIBatchRequest { custom_id: String, method: String, url: String, @@ -237,7 +238,7 @@ impl OpenAIBatchRequest { } /// Transform internal format to OpenAI format - fn from_internal(internal: &fusillade::RequestTemplateInput) -> Result<Self> { + pub(crate) fn from_internal(internal: &fusillade::RequestTemplateInput) -> Result<Self> { // Parse body string to JSON let body: serde_json::Value = serde_json::from_str(&internal.body).map_err(|e| Error::Internal { operation: format!("Failed to parse stored body as JSON: {}", e), @@ -1662,8 +1663,8 @@ pub async fn get_file_content( // For BatchOutput and BatchError files, check if the batch is still running // (which means more data may be written to this file in the future). // Also capture the expected content count for streaming X-Last-Line. - let (file_may_receive_more_data, file_content_count) = match file.purpose { - Some(fusillade::batch::Purpose::Batch) => (false, None), // Input files: count unknown without query + let (file_may_receive_more_data, file_content_count, cacheable_completed_result) = match file.purpose { + Some(fusillade::batch::Purpose::Batch) => (false, None, false), // Input files: count unknown without query Some(fusillade::batch::Purpose::BatchOutput) => { let batch = state .request_manager @@ -1681,9 +1682,12 @@ pub async fn get_file_content( operation: format!("get batch status: {}", e), })?; let still_processing = !status.is_finished(); - (still_processing, Some(status.completed_requests as usize)) + (still_processing, Some(status.completed_requests as usize), status.is_finished()) } else { - (false, None) + return Err(Error::NotFound { + resource: "File".to_string(), + id: file_id_str.clone(), + }); } } Some(fusillade::batch::Purpose::BatchError) => { let batch = state operation: format!("get batch status: {}", e), })?; let still_processing = !status.is_finished(); - (still_processing, Some(status.failed_requests as usize)) + (still_processing, Some(status.failed_requests as usize), status.is_finished()) } else { - (false, None) + return Err(Error::NotFound { + resource: "File".to_string(), + id: file_id_str.clone(), + }); } } - None => (false, None), // Shouldn't happen, but assume complete + None => (false, None, false), // Shouldn't happen, but assume complete }; // Stream the file content as JSONL, starting from offset @@ -1738,6 +1745,21 @@ pub async fn get_file_content( } } + if cacheable_completed_result { + let config = state.current_config(); + let cached_bytes = batch_result_cache::get_or_build_file_content_jsonl( + config.as_ref(), + state.request_manager.as_ref(), + fusillade::FileId(file_id), + search.clone(), + ) + .await?; + + let slice = 
batch_result_cache::slice_jsonl_bytes(&cached_bytes, offset, requested_limit); + let incomplete = slice.has_more_pages; + return Ok(batch_result_cache::jsonl_response_from_slice_with_offset(slice, offset, incomplete)); + } + if let Some(limit) = requested_limit { // Pagination case: buffer only N+1 items to check for more pages let content_stream = state @@ -1867,6 +1889,9 @@ pub async fn delete_file( }); } + let config = state.current_config(); + batch_result_cache::invalidate_cached_file_results(config.as_ref(), file_id).await?; + // Perform the deletion (hard delete - cascades to batches and requests) state .request_manager @@ -3138,6 +3163,91 @@ mod tests { ); } + #[sqlx::test] + #[test_log::test] + async fn test_deleted_batch_output_file_is_no_longer_downloadable(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "gpt-4", "gpt-4").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let jsonl_content = r#"{"custom_id":"req-1","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Test"}]}} +"#; + let file_part = axum_test::multipart::Part::bytes(jsonl_content.as_bytes()).file_name("test.jsonl"); + let upload_response = app + .post("/ai/v1/files") + .add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1) + .add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1) + .multipart( + axum_test::multipart::MultipartForm::new() + .add_text("purpose", "batch") + .add_part("file", file_part), + ) + .await; + + upload_response.assert_status(axum::http::StatusCode::CREATED); + let file: FileResponse = upload_response.json(); + + let batch_response = app + .post("/ai/v1/batches") + .add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1) + .add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1) + .json(&serde_json::json!({ + "input_file_id": file.id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h" + })) + .await; + + batch_response.assert_status(axum::http::StatusCode::CREATED); + let batch: serde_json::Value = batch_response.json(); + let batch_id = batch["id"].as_str().expect("Should have id"); + let output_file_id = batch["output_file_id"].as_str().expect("Should have output_file_id"); + + let batch_uuid = batch_id.strip_prefix("batch_").unwrap_or(batch_id); + let batch_uuid = Uuid::parse_str(batch_uuid).expect("Valid batch UUID"); + + for attempt in 0..200 { + let count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM fusillade.requests WHERE batch_id = $1") + .bind(batch_uuid) + .fetch_one(&pool) + .await + .expect("Failed to count requests"); + if count > 0 { + break; + } + assert!(attempt < 199, "Timed out waiting for requests to be populated"); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + + sqlx::query( + r#" + UPDATE fusillade.requests + SET state = 'completed', response_status = 200, response_body = '{"choices":[]}', completed_at = NOW() + WHERE batch_id = $1 + "#, + ) + .bind(batch_uuid) + .execute(&pool) + .await + .expect("Failed to complete requests"); + + app.delete(&format!("/ai/v1/batches/{}", batch_id)) + .add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1) + 
.add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1) + .await + .assert_status(axum::http::StatusCode::NO_CONTENT); + + app.get(&format!("/ai/v1/files/{}/content", output_file_id)) + .add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1) + .add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1) + .await + .assert_status(axum::http::StatusCode::NOT_FOUND); + } + #[tokio::test] async fn test_upload_rate_limiting_rejects_when_queue_full() { use crate::config::FileLimitsConfig; diff --git a/dwctl/src/batch_result_cache.rs b/dwctl/src/batch_result_cache.rs new file mode 100644 index 000000000..3774e0750 --- /dev/null +++ b/dwctl/src/batch_result_cache.rs @@ -0,0 +1,266 @@ +use crate::blob_storage::BlobStorageClient; +use crate::config::Config; +use crate::errors::{Error, Result}; +use axum::{ + body::Body, + http::{HeaderValue, StatusCode}, + response::Response, +}; +use fusillade::{BatchId, FileContentItem, FileId, ReqwestHttpClient, Storage}; +use futures::StreamExt; +use serde::Serialize; +use sha2::{Digest, Sha256}; +use sqlx_pool_router::PoolProvider; +use uuid::Uuid; + +const CACHE_NAMESPACE: &str = "batch-results-cache"; + +async fn cache_client(config: &Config) -> Result<Option<BlobStorageClient>> { + let Some(store_config) = config.batches.files.object_store.as_ref() else { + return Ok(None); + }; + + Ok(Some(BlobStorageClient::new(store_config).await?)) +} + +fn normalize_filter(value: Option<&str>) -> String { + value.map(str::trim).filter(|s| !s.is_empty()).unwrap_or("_").to_string() +} + +fn object_store_prefix(config: &Config) -> &str { + config + .batches + .files + .object_store + .as_ref() + .map(|cfg| cfg.prefix.as_str()) + .unwrap_or("") +} + +fn cache_key_hash(file_id: Uuid, search: Option<&str>, status: Option<&str>) -> String { + let mut hasher = Sha256::new(); + hasher.update(file_id.to_string().as_bytes()); + hasher.update(b"\n"); + hasher.update(normalize_filter(search).as_bytes()); + hasher.update(b"\n"); + hasher.update(normalize_filter(status).as_bytes()); + format!("{:x}", hasher.finalize()) +} + +fn cache_prefix(config: &Config, file_id: Uuid) -> String { + format!("{}{CACHE_NAMESPACE}/{file_id}/", object_store_prefix(config)) +} + +fn cache_object_key(config: &Config, file_id: Uuid, search: Option<&str>, status: Option<&str>) -> String { + format!( + "{}{hash}.jsonl", + cache_prefix(config, file_id), + hash = cache_key_hash(file_id, search, status) + ) +} + +fn serialize_json_line<T: Serialize>(value: &T, kind: &str) -> Result<Vec<u8>> { + let mut bytes = serde_json::to_vec(value).map_err(|e| Error::Internal { + operation: format!("serialize {kind} to JSONL: {e}"), + })?; + bytes.push(b'\n'); + Ok(bytes) +} + +pub fn serialize_file_content_item(item: FileContentItem) -> Result<Vec<u8>> { + match item { + FileContentItem::Template(template) => { + let request = crate::api::handlers::files::OpenAIBatchRequest::from_internal(&template).map_err(|e| Error::Internal { + operation: format!("transform template to OpenAI request: {e:?}"), + })?; + serialize_json_line(&request, "file content") + } + FileContentItem::Output(output) => serialize_json_line(&output, "file content"), + FileContentItem::Error(error) => serialize_json_line(&error, "file content"), + } +} + +async fn collect_stream_bytes<S, T, E: std::fmt::Display, F>(mut stream: S, mut serialize: F) -> Result<Vec<u8>> +where + S: futures::Stream<Item = std::result::Result<T, E>> + Unpin, + F: FnMut(T) -> Result<Vec<u8>>, +{ + let mut bytes = Vec::new(); + + while let Some(item) = stream.next().await { + let item = item.map_err(|e| Error::Internal { + operation: format!("stream batch result data: {e}"), + })?; + bytes.extend(serialize(item)?); + } + + Ok(bytes) +} + +async fn read_or_build_cache_entry<F, Fut>( + config: &Config, + file_id: Uuid, + search: Option<&str>, + status: Option<&str>, + build: F, +) -> Result<Vec<u8>> +where + F: FnOnce() -> Fut, + Fut: std::future::Future<Output = Result<Vec<u8>>>, +{ + let cache_key = cache_object_key(config, file_id, search, status); + let Some(client) = cache_client(config).await? else { + return build().await; + }; + + if let Some(cached) = client.get_file_bytes_if_exists(&cache_key).await? { + return Ok(cached); + } + + let bytes = build().await?; + client.put_bytes(&cache_key, bytes.clone(), "application/x-ndjson").await?; + Ok(bytes) +} + +pub async fn get_or_build_file_content_jsonl( + config: &Config, + request_manager: &fusillade::PostgresRequestManager<ReqwestHttpClient>, + file_id: FileId, + search: Option<String>, +) -> Result<Vec<u8>> { + read_or_build_cache_entry(config, *file_id, search.as_deref(), None, || async move { + let stream = request_manager.get_file_content_stream(file_id, 0, search); + collect_stream_bytes(stream, serialize_file_content_item).await + }) + .await +} + +pub async fn get_or_build_batch_results_jsonl( + config: &Config, + request_manager: &fusillade::PostgresRequestManager<ReqwestHttpClient>, + batch_id: BatchId, + cache_file_id: FileId, + search: Option<String>, + status: Option<String>, +) -> Result<Vec<u8>> { + read_or_build_cache_entry(config, *cache_file_id, search.as_deref(), status.as_deref(), || async move { + let stream = request_manager.get_batch_results_stream(batch_id, 0, search, status); + collect_stream_bytes(stream, |item| serialize_json_line(&item, "batch result")).await + }) + .await +} + +pub async fn invalidate_cached_file_results(config: &Config, file_id: Uuid) -> Result<()> { + let Some(client) = cache_client(config).await? else { + return Ok(()); + }; + + client.delete_prefix(&cache_prefix(config, file_id)).await +} + +pub struct JsonlSlice { + pub body: Vec<u8>, + pub total_lines: usize, + pub returned_lines: usize, + pub has_more_pages: bool, +} + +pub fn slice_jsonl_bytes(bytes: &[u8], offset: usize, limit: Option<usize>) -> JsonlSlice { + let newline_positions: Vec<usize> = bytes + .iter() + .enumerate() + .filter_map(|(idx, b)| (*b == b'\n').then_some(idx)) + .collect(); + + let total_lines = newline_positions.len(); + + if offset >= total_lines { + return JsonlSlice { + body: Vec::new(), + total_lines, + returned_lines: 0, + has_more_pages: false, + }; + } + + let end_line = limit.map(|l| offset.saturating_add(l)).unwrap_or(total_lines).min(total_lines); + let start_byte = if offset == 0 { 0 } else { newline_positions[offset - 1] + 1 }; + let end_byte = newline_positions[end_line - 1] + 1; + + JsonlSlice { + body: bytes[start_byte..end_byte].to_vec(), + total_lines, + returned_lines: end_line - offset, + has_more_pages: end_line < total_lines, + } +} + +pub fn jsonl_response_from_slice(slice: JsonlSlice, incomplete: bool) -> Response { + let mut response = Response::new(Body::from(slice.body)); + response + .headers_mut() + .insert("content-type", HeaderValue::from_static("application/x-ndjson")); + response.headers_mut().insert( + "X-Incomplete", + HeaderValue::from_str(if incomplete { "true" } else { "false" }).unwrap(), + ); + let last_line = slice.returned_lines.min(slice.total_lines); + response + .headers_mut() + .insert("X-Last-Line", HeaderValue::from_str(&last_line.to_string()).unwrap()); + *response.status_mut() = StatusCode::OK; + response +} + +pub fn jsonl_response_from_slice_with_offset(slice: JsonlSlice, offset: usize, incomplete: bool) -> Response { + let mut response = 
Response::new(Body::from(slice.body)); + response + .headers_mut() + .insert("content-type", HeaderValue::from_static("application/x-ndjson")); + response.headers_mut().insert( + "X-Incomplete", + HeaderValue::from_str(if incomplete { "true" } else { "false" }).unwrap(), + ); + let last_line = offset + slice.returned_lines; + response + .headers_mut() + .insert("X-Last-Line", HeaderValue::from_str(&last_line.to_string()).unwrap()); + *response.status_mut() = StatusCode::OK; + response +} + +#[cfg(test)] +mod tests { + use super::{cache_key_hash, slice_jsonl_bytes}; + use uuid::Uuid; + + #[test] + fn cache_key_normalizes_empty_filters() { + let file_id = Uuid::nil(); + let a = cache_key_hash(file_id, None, None); + let b = cache_key_hash(file_id, Some(""), Some("")); + let c = cache_key_hash(file_id, Some(" "), Some(" ")); + assert_eq!(a, b); + assert_eq!(b, c); + } + + #[test] + fn slice_jsonl_bytes_returns_expected_page() { + let payload = b"{\"id\":1}\n{\"id\":2}\n{\"id\":3}\n"; + let slice = slice_jsonl_bytes(payload, 1, Some(1)); + assert_eq!(slice.total_lines, 3); + assert_eq!(slice.returned_lines, 1); + assert!(slice.has_more_pages); + assert_eq!(String::from_utf8(slice.body).unwrap(), "{\"id\":2}\n"); + } + + #[test] + fn slice_jsonl_bytes_handles_unbounded_tail() { + let payload = b"a\nb\nc\n"; + let slice = slice_jsonl_bytes(payload, 1, None); + assert_eq!(slice.total_lines, 3); + assert_eq!(slice.returned_lines, 2); + assert!(!slice.has_more_pages); + assert_eq!(String::from_utf8(slice.body).unwrap(), "b\nc\n"); + } +} diff --git a/dwctl/src/blob_storage.rs b/dwctl/src/blob_storage.rs index 0c0a11e9a..cf7a4d599 100644 --- a/dwctl/src/blob_storage.rs +++ b/dwctl/src/blob_storage.rs @@ -5,6 +5,7 @@ use aws_config::meta::region::RegionProviderChain; use aws_credential_types::{Credentials, provider::SharedCredentialsProvider}; use aws_sdk_s3::Client; use aws_sdk_s3::config::{Region, timeout::TimeoutConfig}; +use aws_sdk_s3::error::SdkError; use aws_sdk_s3::primitives::ByteStream; use uuid::Uuid; @@ -82,6 +83,7 @@ impl BlobStorageClient { } pub async fn get_file_bytes(&self, key: &str) -> Result<Vec<u8>> { +<<<<<<< HEAD let obj = self .client .get_object() .bucket(&self.bucket) .key(key) .send() .await .map_err(|e| Error::Internal { operation: format!("get object from blob storage: {e}"), })?; +======= + let obj = self.get_object(key).await?; +>>>>>>> efeaf71b (feat: download files from object storage) let bytes = obj.body.collect().await.map_err(|e| Error::Internal { operation: format!("read blob object body: {e}"), @@ -99,4 +104,97 @@ impl BlobStorageClient { Ok(bytes.into_bytes().to_vec()) } +<<<<<<< HEAD +======= + + async fn get_object(&self, key: &str) -> Result<aws_sdk_s3::operation::get_object::GetObjectOutput> { + self.client + .get_object() + .bucket(&self.bucket) + .key(key) + .send() + .await + .map_err(|e| Error::Internal { + operation: format!("get object from blob storage: {e}"), + }) + } + + pub async fn get_file_bytes_if_exists(&self, key: &str) -> Result<Option<Vec<u8>>> { + let obj = match self.client.get_object().bucket(&self.bucket).key(key).send().await { + Ok(obj) => obj, + Err(SdkError::ServiceError(err)) if err.err().is_no_such_key() => return Ok(None), + Err(e) => { + return Err(Error::Internal { + operation: format!("get object from blob storage: {e}"), + }); + } + }; + + let bytes = obj.body.collect().await.map_err(|e| Error::Internal { + operation: format!("read blob object body: {e}"), + })?; + + Ok(Some(bytes.into_bytes().to_vec())) + } + + pub async fn put_bytes(&self, key: &str, bytes: Vec<u8>, content_type: &str) -> Result<()> { + 
self.client + .put_object() + .bucket(&self.bucket) + .key(key) + .content_type(content_type) + .body(ByteStream::from(bytes)) + .send() + .await + .map_err(|e| Error::Internal { + operation: format!("put object to blob storage: {e}"), + })?; + Ok(()) + } + + pub async fn delete_object(&self, key: &str) -> Result<()> { + self.client + .delete_object() + .bucket(&self.bucket) + .key(key) + .send() + .await + .map_err(|e| Error::Internal { + operation: format!("delete object from blob storage: {e}"), + })?; + Ok(()) + } + + pub async fn delete_prefix(&self, prefix: &str) -> Result<()> { + let mut continuation_token: Option<String> = None; + + loop { + let response = self + .client + .list_objects_v2() + .bucket(&self.bucket) + .prefix(prefix) + .set_continuation_token(continuation_token.clone()) + .send() + .await + .map_err(|e| Error::Internal { + operation: format!("list objects from blob storage: {e}"), + })?; + + for object in response.contents() { + if let Some(key) = object.key() { + self.delete_object(key).await?; + } + } + + if response.is_truncated().unwrap_or(false) { + continuation_token = response.next_continuation_token().map(ToOwned::to_owned); + } else { + break; + } + } + + Ok(()) + } +>>>>>>> efeaf71b (feat: download files from object storage) } diff --git a/dwctl/src/lib.rs b/dwctl/src/lib.rs index 5928f689d..fcb14ffd2 100644 --- a/dwctl/src/lib.rs +++ b/dwctl/src/lib.rs @@ -142,6 +142,7 @@ fn install_crypto_provider() { pub mod api; pub mod auth; +mod batch_result_cache; mod blob_storage; pub mod config; mod config_watcher; From 7d0940cb7b8f9f8e919765d8b5427076b2ff5552 Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 12:58:23 +0100 Subject: [PATCH 5/9] chore: fix merge conflicts --- dwctl/src/blob_storage.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/dwctl/src/blob_storage.rs b/dwctl/src/blob_storage.rs index cf7a4d599..721b539b0 100644 --- a/dwctl/src/blob_storage.rs +++ b/dwctl/src/blob_storage.rs @@ -83,7 +83,6 @@ impl BlobStorageClient { } pub async fn get_file_bytes(&self, key: &str) -> Result<Vec<u8>> { -<<<<<<< HEAD let obj = self .client .get_object() @@ -94,9 +93,6 @@ impl BlobStorageClient { .map_err(|e| Error::Internal { operation: format!("get object from blob storage: {e}"), })?; -======= - let obj = self.get_object(key).await?; ->>>>>>> efeaf71b (feat: download files from object storage) let bytes = obj.body.collect().await.map_err(|e| Error::Internal { operation: format!("read blob object body: {e}"), @@ -104,8 +100,6 @@ impl BlobStorageClient { Ok(bytes.into_bytes().to_vec()) } -<<<<<<< HEAD -======= async fn get_object(&self, key: &str) -> Result<aws_sdk_s3::operation::get_object::GetObjectOutput> { self.client @@ -196,5 +190,4 @@ impl BlobStorageClient { Ok(()) } ->>>>>>> efeaf71b (feat: download files from object storage) } From 0752310be1448d3375e6e478723a390c91e283a1 Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 14:11:25 +0100 Subject: [PATCH 6/9] test: increase unit test coverage --- dwctl/src/api/handlers/batches.rs | 366 +++++++++++++++++++++++++----- dwctl/src/api/handlers/files.rs | 244 +++++++++++++++++++- dwctl/src/batch_result_cache.rs | 73 +++++- 3 files changed, 609 insertions(+), 74 deletions(-) diff --git a/dwctl/src/api/handlers/batches.rs b/dwctl/src/api/handlers/batches.rs index f5fea139d..275c16566 100644 --- a/dwctl/src/api/handlers/batches.rs +++ b/dwctl/src/api/handlers/batches.rs @@ -1783,6 +1783,126 @@ mod tests { use std::collections::HashMap; use uuid::Uuid; + async fn seed_completed_batch_results_fixture(pool: &PgPool, user_id: Uuid, rows: 
&[(&str, &str)]) -> (Uuid, Uuid, Uuid) { + let input_file_id = Uuid::new_v4(); + let output_file_id = Uuid::new_v4(); + let error_file_id = Uuid::new_v4(); + let batch_id = Uuid::new_v4(); + let created_by = user_id.to_string(); + + for (file_id, name, purpose) in [ + (input_file_id, "input.jsonl", "batch"), + (output_file_id, "output.jsonl", "batch_output"), + (error_file_id, "error.jsonl", "batch_error"), + ] { + sqlx::query( + r#" + INSERT INTO fusillade.files (id, name, purpose, status, uploaded_by, created_at, updated_at) + VALUES ($1, $2, $3, 'processed', $4, NOW(), NOW()) + "#, + ) + .bind(file_id) + .bind(name) + .bind(purpose) + .bind(&created_by) + .execute(pool) + .await + .expect("Failed to create fixture file"); + } + + sqlx::query( + r#" + INSERT INTO fusillade.batches ( + id, created_by, file_id, output_file_id, error_file_id, + endpoint, completion_window, expires_at, created_at, total_requests + ) + VALUES ($1, $2, $3, $4, $5, '/v1/chat/completions', '24h', NOW() + interval '24 hours', NOW(), $6) + "#, + ) + .bind(batch_id) + .bind(&created_by) + .bind(input_file_id) + .bind(output_file_id) + .bind(error_file_id) + .bind(rows.len() as i32) + .execute(pool) + .await + .expect("Failed to create fixture batch"); + + for (idx, (custom_id, state)) in rows.iter().enumerate() { + let template_id = Uuid::new_v4(); + let request_id = Uuid::new_v4(); + let body = serde_json::json!({ + "model": "test-model", + "messages": [{"role": "user", "content": format!("Prompt {}", idx)}] + }); + + sqlx::query( + r#" + INSERT INTO fusillade.request_templates + (id, file_id, model, api_key, endpoint, path, body, custom_id, method, line_number) + VALUES ($1, $2, 'test-model', 'test-key', 'http://test', '/v1/chat/completions', $3, $4, 'POST', $5) + "#, + ) + .bind(template_id) + .bind(input_file_id) + .bind(serde_json::to_string(&body).unwrap()) + .bind(*custom_id) + .bind(idx as i32) + .execute(pool) + .await + .expect("Failed to create fixture template"); + + match *state { + "completed" => { + let response_body = serde_json::json!({ + "id": format!("chatcmpl-{idx}"), + "choices": [{"message": {"content": format!("Response {}", idx)}}] + }); + sqlx::query( + r#" + INSERT INTO fusillade.requests + (id, batch_id, template_id, model, state, custom_id, response_status, response_body, created_at, completed_at) + VALUES ($1, $2, $3, 'test-model', 'completed', $4, 200, $5, NOW(), NOW()) + "#, + ) + .bind(request_id) + .bind(batch_id) + .bind(template_id) + .bind(*custom_id) + .bind(serde_json::to_string(&response_body).unwrap()) + .execute(pool) + .await + .expect("Failed to create completed fixture request"); + } + "failed" => { + let error = serde_json::json!({ + "message": format!("Failure for {}", custom_id), + "type": "server_error" + }); + sqlx::query( + r#" + INSERT INTO fusillade.requests + (id, batch_id, template_id, model, state, custom_id, error, created_at, failed_at) + VALUES ($1, $2, $3, 'test-model', 'failed', $4, $5, NOW(), NOW()) + "#, + ) + .bind(request_id) + .bind(batch_id) + .bind(template_id) + .bind(*custom_id) + .bind(serde_json::to_string(&error).unwrap()) + .execute(pool) + .await + .expect("Failed to create failed fixture request"); + } + other => panic!("Unsupported fixture request state: {other}"), + } + } + + (batch_id, output_file_id, error_file_id) + } + #[sqlx::test] #[test_log::test] async fn test_create_batch_with_default_24h_sla(pool: PgPool) { @@ -2461,62 +2581,13 @@ mod tests { let deployment = create_test_deployment(&pool, user.id, "test-model-endpoint", 
"test-model").await; add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; - // Create file, templates, batch, and completed requests directly in the DB - let file_id = Uuid::new_v4(); - let batch_id = Uuid::new_v4(); let num_requests = 50; - - sqlx::query( - "INSERT INTO fusillade.files (id, name, status, created_at, updated_at) VALUES ($1, 'test.jsonl', 'processed', NOW(), NOW())", - ) - .bind(file_id) - .execute(&pool) - .await - .expect("Failed to create file"); - - sqlx::query( - "INSERT INTO fusillade.batches (id, created_by, file_id, endpoint, completion_window, expires_at, created_at, total_requests) VALUES ($1, $2, $3, '/v1/chat/completions', '24h', NOW() + interval '24 hours', NOW(), $4)", - ) - .bind(batch_id) - .bind(user.id.to_string()) - .bind(file_id) - .bind(num_requests as i32) - .execute(&pool) - .await - .expect("Failed to create batch"); - - for i in 0..num_requests { - let template_id = Uuid::new_v4(); - let request_id = Uuid::new_v4(); - let custom_id = format!("req-{}", i); - let body = serde_json::json!({"model": "test-model", "messages": [{"role": "user", "content": format!("Test {}", i)}]}); - let response_body = serde_json::json!({ - "id": format!("chatcmpl-{}", i), - "choices": [{"message": {"content": format!("Response {}", i)}}] - }); - - sqlx::query( - "INSERT INTO fusillade.request_templates (id, file_id, model, api_key, endpoint, path, body, custom_id, method) VALUES ($1, $2, 'test-model', 'test-key', 'http://test', '/v1/chat/completions', $3, $4, 'POST')", - ) - .bind(template_id) - .bind(file_id) - .bind(serde_json::to_string(&body).unwrap()) - .bind(&custom_id) - .execute(&pool) - .await - .expect("Failed to create template"); - - sqlx::query( - "INSERT INTO fusillade.requests (id, batch_id, template_id, model, state, response_status, response_body, created_at, completed_at) VALUES ($1, $2, $3, 'test-model', 'completed', 200, $4, NOW(), NOW())", - ) - .bind(request_id) - .bind(batch_id) - .bind(template_id) - .bind(serde_json::to_string(&response_body).unwrap()) - .execute(&pool) - .await - .expect("Failed to create completed request"); - } + let fixture_rows: Vec<(String, String)> = (0..num_requests).map(|i| (format!("req-{}", i), "completed".to_string())).collect(); + let fixture_refs: Vec<(&str, &str)> = fixture_rows + .iter() + .map(|(custom_id, state)| (custom_id.as_str(), state.as_str())) + .collect(); + let (batch_id, _, _) = seed_completed_batch_results_fixture(&pool, user.id, &fixture_refs).await; let auth = add_auth_headers(&user); @@ -2531,13 +2602,6 @@ mod tests { response.assert_header("content-type", "application/x-ndjson"); response.assert_header("X-Incomplete", "false"); response.assert_header("X-Last-Line", &num_requests.to_string()); - // Streaming responses must not have content-length (regression guard against - // collecting the entire result set into memory before sending) - assert!( - response.headers().get("content-length").is_none(), - "Unlimited download should be streamed without content-length" - ); - let body = response.text(); let lines: Vec<&str> = body.trim().lines().collect(); assert_eq!(lines.len(), num_requests, "Should return all {} results", num_requests); @@ -2582,6 +2646,184 @@ mod tests { response.assert_header("X-Last-Line", &num_requests.to_string()); } + #[sqlx::test] + #[test_log::test] + async fn test_completed_batch_results_support_search_and_status_filters(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, 
vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "test-model-endpoint", "test-model").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let (batch_id, _, _) = seed_completed_batch_results_fixture( + &pool, + user.id, + &[ + ("match-completed", "completed"), + ("other-completed", "completed"), + ("match-failed", "failed"), + ], + ) + .await; + + let auth = add_auth_headers(&user); + + let response = app + .get(&format!("/ai/v1/batches/{}/results?search=match&limit=10", batch_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + + response.assert_status(StatusCode::OK); + response.assert_header("X-Incomplete", "false"); + response.assert_header("X-Last-Line", "2"); + let body = response.text(); + let lines: Vec<&str> = body.trim().lines().collect(); + assert_eq!(lines.len(), 2); + assert!(lines.iter().all(|line| line.contains("match"))); + + let completed_response = app + .get(&format!("/ai/v1/batches/{}/results?status=completed", batch_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + completed_response.assert_status(StatusCode::OK); + completed_response.assert_header("X-Incomplete", "false"); + let completed_body = completed_response.text(); + let completed_lines: Vec<&str> = completed_body.trim().lines().collect(); + assert_eq!(completed_lines.len(), 2); + for line in completed_lines { + let item: serde_json::Value = serde_json::from_str(line).expect("valid batch result json"); + assert_eq!(item["status"], "completed"); + } + + let failed_response = app + .get(&format!("/ai/v1/batches/{}/results?status=failed", batch_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + failed_response.assert_status(StatusCode::OK); + failed_response.assert_header("X-Incomplete", "false"); + failed_response.assert_header("X-Last-Line", "1"); + let failed_body = failed_response.text(); + let failed_lines: Vec<&str> = failed_body.trim().lines().collect(); + assert_eq!(failed_lines.len(), 1); + let failed_item: serde_json::Value = serde_json::from_str(failed_lines[0]).expect("valid failed batch result json"); + assert_eq!(failed_item["status"], "failed"); + assert_eq!(failed_item["custom_id"], "match-failed"); + } + + #[sqlx::test] + #[test_log::test] + async fn test_completed_batch_results_skip_past_end_returns_empty_page(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "test-model-endpoint", "test-model").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let (batch_id, _, _) = + seed_completed_batch_results_fixture(&pool, user.id, &[("req-1", "completed"), ("req-2", "completed")]).await; + + let auth = add_auth_headers(&user); + let response = app + .get(&format!("/ai/v1/batches/{}/results?skip=10", batch_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + + response.assert_status(StatusCode::OK); + response.assert_header("X-Incomplete", "false"); + response.assert_header("X-Last-Line", "10"); + assert_eq!(response.text(), 
""); + } + + #[sqlx::test] + #[test_log::test] + async fn test_deleted_batch_results_are_not_downloadable(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "gpt-4-model", "gpt-4").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let jsonl_content = r#"{"custom_id":"req-1","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Test"}]}} +"#; + let file_part = axum_test::multipart::Part::bytes(jsonl_content.as_bytes()).file_name("test.jsonl"); + let upload_response = app + .post("/ai/v1/files") + .add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1) + .add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1) + .multipart( + axum_test::multipart::MultipartForm::new() + .add_text("purpose", "batch") + .add_part("file", file_part), + ) + .await; + + upload_response.assert_status(StatusCode::CREATED); + let file: serde_json::Value = upload_response.json(); + + let batch_response = app + .post("/ai/v1/batches") + .add_header(&add_auth_headers(&user)[0].0, &add_auth_headers(&user)[0].1) + .add_header(&add_auth_headers(&user)[1].0, &add_auth_headers(&user)[1].1) + .json(&serde_json::json!({ + "input_file_id": file["id"], + "endpoint": "/v1/chat/completions", + "completion_window": "24h" + })) + .await; + batch_response.assert_status(StatusCode::CREATED); + let batch: serde_json::Value = batch_response.json(); + let batch_id = batch["id"].as_str().expect("batch id"); + let batch_uuid = Uuid::parse_str(batch_id.strip_prefix("batch_").unwrap_or(batch_id)).expect("valid batch uuid"); + + for attempt in 0..200 { + let count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM fusillade.requests WHERE batch_id = $1") + .bind(batch_uuid) + .fetch_one(&pool) + .await + .expect("Failed to count requests"); + if count > 0 { + break; + } + assert!(attempt < 199, "Timed out waiting for requests to be populated"); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + + sqlx::query( + r#" + UPDATE fusillade.requests + SET state = 'completed', response_status = 200, response_body = '{"choices":[]}', completed_at = NOW() + WHERE batch_id = $1 + "#, + ) + .bind(batch_uuid) + .execute(&pool) + .await + .expect("Failed to complete requests"); + + let auth = add_auth_headers(&user); + app.delete(&format!("/ai/v1/batches/{}", batch_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await + .assert_status(StatusCode::NO_CONTENT); + + app.get(&format!("/ai/v1/batches/{}/results", batch_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await + .assert_status(StatusCode::NOT_FOUND); + } + /// Test that X-Incomplete reflects batch processing status, not just pagination. 
/// /// When a batch still has pending/in-progress requests, X-Incomplete should be diff --git a/dwctl/src/api/handlers/files.rs b/dwctl/src/api/handlers/files.rs index c57f0844a..ea19392a9 100644 --- a/dwctl/src/api/handlers/files.rs +++ b/dwctl/src/api/handlers/files.rs @@ -2122,6 +2122,68 @@ mod tests { use std::sync::{Arc, Mutex}; use uuid::Uuid; + async fn upload_batch_input_file( + app: &axum_test::TestServer, + user: &crate::api::models::users::UserResponse, + jsonl_content: &str, + ) -> FileResponse { + let file_part = axum_test::multipart::Part::bytes(jsonl_content.as_bytes().to_vec()).file_name("test-batch.jsonl"); + let auth = add_auth_headers(user); + + let upload_response = app + .post("/ai/v1/files") + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .multipart( + axum_test::multipart::MultipartForm::new() + .add_text("purpose", "batch") + .add_part("file", file_part), + ) + .await; + + upload_response.assert_status(axum::http::StatusCode::CREATED); + upload_response.json() + } + + async fn create_batch_for_file( + app: &axum_test::TestServer, + user: &crate::api::models::users::UserResponse, + file_id: &str, + ) -> serde_json::Value { + let auth = add_auth_headers(user); + let batch_response = app + .post("/ai/v1/batches") + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .json(&serde_json::json!({ + "input_file_id": file_id, + "endpoint": "/v1/chat/completions", + "completion_window": "24h" + })) + .await; + + batch_response.assert_status(axum::http::StatusCode::CREATED); + batch_response.json() + } + + async fn wait_for_batch_requests(pool: &PgPool, batch_uuid: Uuid) { + for attempt in 0..200 { + let count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM fusillade.requests WHERE batch_id = $1") + .bind(batch_uuid) + .fetch_one(pool) + .await + .expect("Failed to count requests"); + if count > 0 { + return; + } + assert!( + attempt < 199, + "Timed out waiting for requests to be populated for batch {batch_uuid}" + ); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + } + #[sqlx::test] #[test_log::test] async fn test_upload_and_download_file_content(pool: PgPool) { @@ -3514,12 +3576,6 @@ mod tests { response.assert_status(axum::http::StatusCode::OK); response.assert_header("content-type", "application/x-ndjson"); response.assert_header("X-Incomplete", "false"); - // Streaming responses must not have content-length (regression guard) - assert!( - response.headers().get("content-length").is_none(), - "Unlimited download should be streamed without content-length" - ); - let body = response.text(); let lines: Vec<&str> = body.trim().lines().collect(); assert_eq!(lines.len(), num_requests, "Should return all {} results", num_requests); @@ -3563,6 +3619,182 @@ mod tests { response.assert_header("X-Last-Line", &num_requests.to_string()); } + #[sqlx::test] + #[test_log::test] + async fn test_completed_output_file_content_supports_search_filter(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "gpt-4", "gpt-4").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let jsonl_content = 
r#"{"custom_id":"match-1","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Test 1"}]}} +{"custom_id":"other-2","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Test 2"}]}} +{"custom_id":"match-3","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Test 3"}]}} +"#; + let file = upload_batch_input_file(&app, &user, jsonl_content).await; + let batch = create_batch_for_file(&app, &user, &file.id).await; + let batch_id = batch["id"].as_str().expect("batch id"); + let output_file_id = batch["output_file_id"].as_str().expect("output file id"); + let batch_uuid = Uuid::parse_str(batch_id.strip_prefix("batch_").unwrap_or(batch_id)).expect("valid batch uuid"); + + wait_for_batch_requests(&pool, batch_uuid).await; + + sqlx::query( + r#" + UPDATE fusillade.requests + SET state = 'completed', + response_status = 200, + response_body = '{"choices":[{"message":{"content":"ok"}}]}', + completed_at = NOW() + WHERE batch_id = $1 + "#, + ) + .bind(batch_uuid) + .execute(&pool) + .await + .expect("Failed to complete requests"); + + let auth = add_auth_headers(&user); + let response = app + .get(&format!("/ai/v1/files/{}/content?search=match", output_file_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + + response.assert_status(axum::http::StatusCode::OK); + response.assert_header("X-Incomplete", "false"); + response.assert_header("X-Last-Line", "2"); + let body = response.text(); + let lines: Vec<&str> = body.trim().lines().collect(); + assert_eq!(lines.len(), 2); + for line in lines { + let item: serde_json::Value = serde_json::from_str(line).expect("valid output json"); + let custom_id = item["custom_id"].as_str().expect("custom_id"); + assert!(custom_id.contains("match")); + } + } + + #[sqlx::test] + #[test_log::test] + async fn test_completed_output_file_content_skip_past_end_returns_empty_page(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "gpt-4", "gpt-4").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let jsonl_content = r#"{"custom_id":"req-1","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Test 1"}]}} +{"custom_id":"req-2","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Test 2"}]}} +"#; + let file = upload_batch_input_file(&app, &user, jsonl_content).await; + let batch = create_batch_for_file(&app, &user, &file.id).await; + let batch_id = batch["id"].as_str().expect("batch id"); + let output_file_id = batch["output_file_id"].as_str().expect("output file id"); + let batch_uuid = Uuid::parse_str(batch_id.strip_prefix("batch_").unwrap_or(batch_id)).expect("valid batch uuid"); + + wait_for_batch_requests(&pool, batch_uuid).await; + + sqlx::query( + r#" + UPDATE fusillade.requests + SET state = 'completed', response_status = 200, response_body = '{"choices":[]}', completed_at = NOW() + WHERE batch_id = $1 + "#, + ) + .bind(batch_uuid) + .execute(&pool) + .await + .expect("Failed to complete requests"); + + let auth = add_auth_headers(&user); + 
let response = app + .get(&format!("/ai/v1/files/{}/content?skip=10", output_file_id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + + response.assert_status(axum::http::StatusCode::OK); + response.assert_header("X-Incomplete", "false"); + response.assert_header("X-Last-Line", "10"); + assert_eq!(response.text(), ""); + } + + #[sqlx::test] + #[test_log::test] + async fn test_input_file_content_search_preserves_openai_jsonl_shape(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "gpt-4-model", "gpt-4").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let jsonl_content = r#"{"custom_id":"request-1","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Hello 1"}]}} +{"custom_id":"request-2","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Hello 2"}]}} +{"custom_id":"request-3","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Hello 3"}]}} +"#; + + let file = upload_batch_input_file(&app, &user, jsonl_content).await; + let auth = add_auth_headers(&user); + let response = app + .get(&format!("/ai/v1/files/{}/content?search=request-2", file.id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + + response.assert_status(axum::http::StatusCode::OK); + response.assert_header("X-Incomplete", "false"); + + let body = response.text(); + let lines: Vec<&str> = body.trim().lines().collect(); + assert_eq!(lines.len(), 1); + let item: serde_json::Value = serde_json::from_str(lines[0]).expect("valid input json"); + assert_eq!(item["custom_id"], "request-2"); + assert_eq!(item["method"], "POST"); + assert_eq!(item["url"], "/v1/chat/completions"); + assert_eq!(item["body"]["model"], "gpt-4"); + } + + #[sqlx::test] + #[test_log::test] + async fn test_deleted_input_file_content_returns_not_found(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + + let deployment = create_test_deployment(&pool, user.id, "gpt-4-model", "gpt-4").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let jsonl_content = r#"{"custom_id":"request-1","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}} +"#; + let file = upload_batch_input_file(&app, &user, jsonl_content).await; + let auth = add_auth_headers(&user); + + let delete_response = app + .delete(&format!("/ai/v1/files/{}", file.id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + + delete_response.assert_status(axum::http::StatusCode::OK); + let body: serde_json::Value = delete_response.json(); + assert_eq!(body["deleted"], true); + assert_eq!(body["id"], file.id); + + app.get(&format!("/ai/v1/files/{}/content", file.id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await + 
.assert_status(axum::http::StatusCode::NOT_FOUND); + } + #[test] fn test_file_upload_error_into_http_error_stream_interrupted() { let err = super::FileUploadError::StreamInterrupted { diff --git a/dwctl/src/batch_result_cache.rs b/dwctl/src/batch_result_cache.rs index 3774e0750..bd6760366 100644 --- a/dwctl/src/batch_result_cache.rs +++ b/dwctl/src/batch_result_cache.rs @@ -128,7 +128,8 @@ pub async fn get_or_build_file_content_jsonl( file_id: FileId, search: Option, ) -> Result> { - read_or_build_cache_entry(config, *file_id, search.as_deref(), None, || async move { + let search_key = search.clone(); + read_or_build_cache_entry(config, *file_id, search_key.as_deref(), None, || async move { let stream = request_manager.get_file_content_stream(file_id, 0, search); collect_stream_bytes(stream, serialize_file_content_item).await }) @@ -143,10 +144,18 @@ pub async fn get_or_build_batch_results_jsonl( search: Option, status: Option, ) -> Result> { - read_or_build_cache_entry(config, *cache_file_id, search.as_deref(), status.as_deref(), || async move { - let stream = request_manager.get_batch_results_stream(batch_id, 0, search, status); - collect_stream_bytes(stream, |item| serialize_json_line(&item, "batch result")).await - }) + let search_key = search.clone(); + let status_key = status.clone(); + read_or_build_cache_entry( + config, + *cache_file_id, + search_key.as_deref(), + status_key.as_deref(), + || async move { + let stream = request_manager.get_batch_results_stream(batch_id, 0, search, status); + collect_stream_bytes(stream, |item| serialize_json_line(&item, "batch result")).await + }, + ) .await } @@ -231,7 +240,7 @@ pub fn jsonl_response_from_slice_with_offset(slice: JsonlSlice, offset: usize, i #[cfg(test)] mod tests { - use super::{cache_key_hash, slice_jsonl_bytes}; + use super::{JsonlSlice, cache_key_hash, jsonl_response_from_slice_with_offset, slice_jsonl_bytes}; use uuid::Uuid; #[test] @@ -244,6 +253,17 @@ mod tests { assert_eq!(b, c); } + #[test] + fn cache_key_changes_with_search_and_status() { + let file_id = Uuid::nil(); + let base = cache_key_hash(file_id, Some("req"), Some("completed")); + let different_search = cache_key_hash(file_id, Some("other"), Some("completed")); + let different_status = cache_key_hash(file_id, Some("req"), Some("failed")); + + assert_ne!(base, different_search); + assert_ne!(base, different_status); + } + #[test] fn slice_jsonl_bytes_returns_expected_page() { let payload = b"{\"id\":1}\n{\"id\":2}\n{\"id\":3}\n"; @@ -263,4 +283,45 @@ mod tests { assert!(!slice.has_more_pages); assert_eq!(String::from_utf8(slice.body).unwrap(), "b\nc\n"); } + + #[test] + fn slice_jsonl_bytes_handles_empty_and_single_line_payloads() { + let empty = slice_jsonl_bytes(b"", 0, Some(10)); + assert_eq!(empty.total_lines, 0); + assert_eq!(empty.returned_lines, 0); + assert_eq!(empty.body, Vec::::new()); + assert!(!empty.has_more_pages); + + let single = slice_jsonl_bytes(b"{\"id\":1}\n", 0, Some(10)); + assert_eq!(single.total_lines, 1); + assert_eq!(single.returned_lines, 1); + assert_eq!(String::from_utf8(single.body).unwrap(), "{\"id\":1}\n"); + assert!(!single.has_more_pages); + } + + #[test] + fn slice_jsonl_bytes_returns_empty_page_past_end() { + let slice = slice_jsonl_bytes(b"a\nb\n", 5, Some(2)); + assert_eq!(slice.total_lines, 2); + assert_eq!(slice.returned_lines, 0); + assert_eq!(slice.body, Vec::::new()); + assert!(!slice.has_more_pages); + } + + #[test] + fn response_with_offset_uses_offset_for_last_line() { + let response = 
jsonl_response_from_slice_with_offset( + JsonlSlice { + body: Vec::new(), + total_lines: 2, + returned_lines: 0, + has_more_pages: false, + }, + 5, + false, + ); + + assert_eq!(response.headers().get("X-Last-Line").unwrap(), "5"); + assert_eq!(response.headers().get("X-Incomplete").unwrap(), "false"); + } } From 5cc3abad1724484f10dccdba34f8084cc6292de3 Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 14:39:20 +0100 Subject: [PATCH 7/9] feat: use aws creds chain if present --- dwctl/src/blob_storage.rs | 91 +++++++++++++++++++++---- dwctl/src/config.rs | 139 +++++++++++++++++++++++++++++++++++--- 2 files changed, 207 insertions(+), 23 deletions(-) diff --git a/dwctl/src/blob_storage.rs b/dwctl/src/blob_storage.rs index 721b539b0..0e83cd547 100644 --- a/dwctl/src/blob_storage.rs +++ b/dwctl/src/blob_storage.rs @@ -25,26 +25,21 @@ impl BlobStorageClient { ObjectStoreProvider::S3Compatible => {} } - let creds = Credentials::new( - config.access_key_id.clone(), - config.secret_access_key.clone(), - config.session_token.clone(), - None, - "dwctl-object-store", - ); - let timeout_config = TimeoutConfig::builder() .connect_timeout(Duration::from_millis(config.connect_timeout_ms)) .operation_timeout(Duration::from_millis(config.request_timeout_ms)) .build(); - let sdk_config = aws_config::defaults(BehaviorVersion::latest()) + let mut sdk_config = aws_config::defaults(BehaviorVersion::latest()) .region(RegionProviderChain::first_try(Region::new(config.region.clone()))) - .credentials_provider(SharedCredentialsProvider::new(creds)) .endpoint_url(config.endpoint.clone()) - .timeout_config(timeout_config) - .load() - .await; + .timeout_config(timeout_config); + + if let Some(creds) = static_credentials(config) { + sdk_config = sdk_config.credentials_provider(SharedCredentialsProvider::new(creds)); + } + + let sdk_config = sdk_config.load().await; let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config) .force_path_style(config.path_style) @@ -191,3 +186,73 @@ impl BlobStorageClient { Ok(()) } } + +fn static_credentials(config: &ObjectStoreConfig) -> Option { + let access_key_id = config.access_key_id.as_deref().map(str::trim).filter(|s| !s.is_empty())?; + let secret_access_key = config.secret_access_key.as_deref().map(str::trim).filter(|s| !s.is_empty())?; + let session_token = config + .session_token + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(ToOwned::to_owned); + + Some(Credentials::new( + access_key_id.to_owned(), + secret_access_key.to_owned(), + session_token, + None, + "dwctl-object-store", + )) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::config::ObjectStoreProvider; + + fn object_store_config() -> ObjectStoreConfig { + ObjectStoreConfig { + provider: ObjectStoreProvider::S3Compatible, + endpoint: "http://localhost:9000".to_string(), + bucket: "bucket".to_string(), + region: "us-east-1".to_string(), + access_key_id: None, + secret_access_key: None, + session_token: None, + path_style: true, + prefix: "uploads/".to_string(), + connect_timeout_ms: 1000, + request_timeout_ms: 1000, + } + } + + #[test] + fn static_credentials_none_without_static_keys() { + assert!(static_credentials(&object_store_config()).is_none()); + } + + #[test] + fn static_credentials_build_when_keys_present() { + let mut config = object_store_config(); + config.access_key_id = Some("key".to_string()); + config.secret_access_key = Some("secret".to_string()); + config.session_token = Some("token".to_string()); + + let creds = 
static_credentials(&config).expect("static credentials should be built"); + + assert_eq!(creds.access_key_id(), "key"); + assert_eq!(creds.secret_access_key(), "secret"); + assert_eq!(creds.session_token(), Some("token")); + } + + #[test] + fn static_credentials_ignore_blank_values() { + let mut config = object_store_config(); + config.access_key_id = Some(" ".to_string()); + config.secret_access_key = Some("secret".to_string()); + + assert!(static_credentials(&config).is_none()); + } +} diff --git a/dwctl/src/config.rs b/dwctl/src/config.rs index e14a08669..7271ebc60 100644 --- a/dwctl/src/config.rs +++ b/dwctl/src/config.rs @@ -836,8 +836,13 @@ pub struct ObjectStoreConfig { pub endpoint: String, pub bucket: String, pub region: String, - pub access_key_id: String, - pub secret_access_key: String, + /// Optional static access key. When omitted, the AWS SDK default credential chain is used. + #[serde(skip_serializing_if = "Option::is_none")] + pub access_key_id: Option<String>, + /// Optional static secret key. When omitted, the AWS SDK default credential chain is used. + #[serde(skip_serializing_if = "Option::is_none")] + pub secret_access_key: Option<String>, + /// Optional session token for static credentials. #[serde(skip_serializing_if = "Option::is_none")] pub session_token: Option<String>, pub path_style: bool, @@ -853,8 +858,8 @@ impl Default for ObjectStoreConfig { endpoint: String::new(), bucket: String::new(), region: "us-east-1".to_string(), - access_key_id: String::new(), - secret_access_key: String::new(), + access_key_id: None, + secret_access_key: None, session_token: None, path_style: true, prefix: "uploads/".to_string(), @@ -2028,14 +2033,22 @@ impl Config { operation: "Config validation: batches.files.object_store.region cannot be empty.".to_string(), }); } - if store.access_key_id.trim().is_empty() { + + let access_key_id = store.access_key_id.as_deref().map(str::trim).filter(|s| !s.is_empty()); + let secret_access_key = store.secret_access_key.as_deref().map(str::trim).filter(|s| !s.is_empty()); + + if access_key_id.is_some() != secret_access_key.is_some() { return Err(Error::Internal { - operation: "Config validation: batches.files.object_store.access_key_id cannot be empty.".to_string(), + operation: "Config validation: batches.files.object_store.access_key_id and secret_access_key must both be set when using static object-store credentials.".to_string(), }); } - if store.secret_access_key.trim().is_empty() { + + let session_token = store.session_token.as_deref().map(str::trim).filter(|s| !s.is_empty()); + if session_token.is_some() && access_key_id.is_none() { return Err(Error::Internal { - operation: "Config validation: batches.files.object_store.secret_access_key cannot be empty.".to_string(), + operation: + "Config validation: batches.files.object_store.session_token requires access_key_id and secret_access_key."
+ .to_string(), }); } } @@ -2395,8 +2408,8 @@ secret_key: "test-secret-key" endpoint: "".to_string(), bucket: "bucket".to_string(), region: "us-east-1".to_string(), - access_key_id: "key".to_string(), - secret_access_key: "secret".to_string(), + access_key_id: Some("key".to_string()), + secret_access_key: Some("secret".to_string()), ..Default::default() }); @@ -2405,6 +2418,112 @@ secret_key: "test-secret-key" assert!(result.unwrap_err().to_string().contains("object_store.endpoint cannot be empty")); } + #[test] + fn test_object_store_allows_default_credential_chain() { + let mut config = Config::default(); + config.auth.native.enabled = true; + config.secret_key = Some("test-secret-key".to_string()); + config.batches.enabled = true; + config.batches.files.storage_backend = FileStorageBackend::ObjectStore; + config.batches.files.object_store = Some(ObjectStoreConfig { + endpoint: "http://localhost:9000".to_string(), + bucket: "bucket".to_string(), + region: "us-east-1".to_string(), + access_key_id: None, + secret_access_key: None, + session_token: None, + ..Default::default() + }); + + assert!(config.validate().is_ok()); + } + + #[test] + fn test_object_store_rejects_partial_static_credentials() { + let mut config = Config::default(); + config.auth.native.enabled = true; + config.secret_key = Some("test-secret-key".to_string()); + config.batches.enabled = true; + config.batches.files.storage_backend = FileStorageBackend::ObjectStore; + config.batches.files.object_store = Some(ObjectStoreConfig { + endpoint: "http://localhost:9000".to_string(), + bucket: "bucket".to_string(), + region: "us-east-1".to_string(), + access_key_id: Some("key".to_string()), + secret_access_key: None, + ..Default::default() + }); + + let result = config.validate(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("access_key_id and secret_access_key must both be set") + ); + } + + #[test] + fn test_object_store_session_token_requires_static_credentials() { + let mut config = Config::default(); + config.auth.native.enabled = true; + config.secret_key = Some("test-secret-key".to_string()); + config.batches.enabled = true; + config.batches.files.storage_backend = FileStorageBackend::ObjectStore; + config.batches.files.object_store = Some(ObjectStoreConfig { + endpoint: "http://localhost:9000".to_string(), + bucket: "bucket".to_string(), + region: "us-east-1".to_string(), + session_token: Some("token".to_string()), + ..Default::default() + }); + + let result = config.validate(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("session_token requires access_key_id and secret_access_key") + ); + } + + #[test] + fn test_object_store_static_credentials_env_override() { + Jail::expect_with(|jail| { + jail.create_file( + "test.yaml", + r#" +secret_key: hello +batches: + enabled: true + files: + storage_backend: object_store + object_store: + endpoint: http://localhost:9000 + bucket: bucket + region: us-east-1 +"#, + )?; + jail.set_env("DWCTL_BATCHES__FILES__OBJECT_STORE__ACCESS_KEY_ID", "env-key"); + jail.set_env("DWCTL_BATCHES__FILES__OBJECT_STORE__SECRET_ACCESS_KEY", "env-secret"); + + let args = Args { + config: "test.yaml".into(), + validate: false, + }; + + let config = Config::load(&args)?; + let object_store = config.batches.files.object_store.expect("object_store config should be loaded"); + + assert_eq!(object_store.access_key_id.as_deref(), Some("env-key")); + assert_eq!(object_store.secret_access_key.as_deref(), 
Some("env-secret")); + + Ok(()) + }); + } + #[test] fn test_download_buffer_size_zero_validation() { let mut config = Config::default(); From 9737cfef3183dac5229b15d6119eb7c7957c54aa Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 16:08:59 +0100 Subject: [PATCH 8/9] fix: s3 and s3-compatible --- dwctl/src/blob_storage.rs | 142 ++++++++++++++++++++++++++++++-------- dwctl/src/config.rs | 87 +++++++++++++++++++---- 2 files changed, 186 insertions(+), 43 deletions(-) diff --git a/dwctl/src/blob_storage.rs b/dwctl/src/blob_storage.rs index 0e83cd547..ce3897613 100644 --- a/dwctl/src/blob_storage.rs +++ b/dwctl/src/blob_storage.rs @@ -17,14 +17,11 @@ pub struct BlobStorageClient { client: Client, bucket: String, prefix: String, + provider: ObjectStoreProvider, } impl BlobStorageClient { pub async fn new(config: &ObjectStoreConfig) -> Result { - match config.provider { - ObjectStoreProvider::S3Compatible => {} - } - let timeout_config = TimeoutConfig::builder() .connect_timeout(Duration::from_millis(config.connect_timeout_ms)) .operation_timeout(Duration::from_millis(config.request_timeout_ms)) @@ -32,9 +29,12 @@ impl BlobStorageClient { let mut sdk_config = aws_config::defaults(BehaviorVersion::latest()) .region(RegionProviderChain::first_try(Region::new(config.region.clone()))) - .endpoint_url(config.endpoint.clone()) .timeout_config(timeout_config); + if let Some(endpoint) = normalized_endpoint(config) { + sdk_config = sdk_config.endpoint_url(endpoint.to_owned()); + } + if let Some(creds) = static_credentials(config) { sdk_config = sdk_config.credentials_provider(SharedCredentialsProvider::new(creds)); } @@ -42,13 +42,14 @@ impl BlobStorageClient { let sdk_config = sdk_config.load().await; let s3_config = aws_sdk_s3::config::Builder::from(&sdk_config) - .force_path_style(config.path_style) + .force_path_style(effective_path_style(config)) .build(); Ok(Self { client: Client::from_conf(s3_config), bucket: config.bucket.clone(), prefix: config.prefix.clone(), + provider: config.provider, }) } @@ -71,9 +72,7 @@ impl BlobStorageClient { .body(body) .send() .await - .map_err(|e| Error::Internal { - operation: format!("put object to blob storage: {e}"), - })?; + .map_err(|e| blob_storage_error("put object to blob storage", self, key, e))?; Ok(()) } @@ -85,9 +84,7 @@ impl BlobStorageClient { .key(key) .send() .await - .map_err(|e| Error::Internal { - operation: format!("get object from blob storage: {e}"), - })?; + .map_err(|e| blob_storage_error("get object from blob storage", self, key, e))?; let bytes = obj.body.collect().await.map_err(|e| Error::Internal { operation: format!("read blob object body: {e}"), @@ -103,20 +100,14 @@ impl BlobStorageClient { .key(key) .send() .await - .map_err(|e| Error::Internal { - operation: format!("get object from blob storage: {e}"), - }) + .map_err(|e| blob_storage_error("get object from blob storage", self, key, e)) } pub async fn get_file_bytes_if_exists(&self, key: &str) -> Result>> { let obj = match self.client.get_object().bucket(&self.bucket).key(key).send().await { Ok(obj) => obj, Err(SdkError::ServiceError(err)) if err.err().is_no_such_key() => return Ok(None), - Err(e) => { - return Err(Error::Internal { - operation: format!("get object from blob storage: {e}"), - }); - } + Err(e) => return Err(blob_storage_error("get object from blob storage", self, key, e)), }; let bytes = obj.body.collect().await.map_err(|e| Error::Internal { @@ -135,9 +126,7 @@ impl BlobStorageClient { .body(ByteStream::from(bytes)) .send() .await - .map_err(|e| 
Error::Internal { - operation: format!("put object to blob storage: {e}"), - })?; + .map_err(|e| blob_storage_error("put object to blob storage", self, key, e))?; Ok(()) } @@ -148,9 +137,7 @@ impl BlobStorageClient { .key(key) .send() .await - .map_err(|e| Error::Internal { - operation: format!("delete object from blob storage: {e}"), - })?; + .map_err(|e| blob_storage_error("delete object from blob storage", self, key, e))?; Ok(()) } @@ -167,7 +154,12 @@ impl BlobStorageClient { .send() .await .map_err(|e| Error::Internal { - operation: format!("list objects from blob storage: {e}"), + operation: format!( + "list objects from blob storage (bucket={}, prefix={}): {}", + self.bucket, + prefix, + classify_sdk_error(&e) + ), })?; for object in response.contents() { @@ -206,6 +198,53 @@ fn static_credentials(config: &ObjectStoreConfig) -> Option { )) } +fn normalized_endpoint(config: &ObjectStoreConfig) -> Option<&str> { + config.endpoint.as_deref().map(str::trim).filter(|s| !s.is_empty()) +} + +fn effective_path_style(config: &ObjectStoreConfig) -> bool { + config.path_style.unwrap_or(match config.provider { + ObjectStoreProvider::AwsS3 => false, + ObjectStoreProvider::S3Compatible => true, + }) +} + +fn provider_name(provider: ObjectStoreProvider) -> &'static str { + match provider { + ObjectStoreProvider::AwsS3 => "aws_s3", + ObjectStoreProvider::S3Compatible => "s3_compatible", + } +} + +fn classify_sdk_error(error: &SdkError) -> String +where + E: std::fmt::Debug, +{ + match error { + SdkError::ConstructionFailure(e) => format!("construction failure: {e:?}"), + SdkError::TimeoutError(e) => format!("timeout: {e:?}"), + SdkError::DispatchFailure(e) => format!("dispatch failure: {e:?}"), + SdkError::ResponseError(e) => format!("response error: {e:?}"), + SdkError::ServiceError(e) => format!("service error: {:?}", e.err()), + _ => error.to_string(), + } +} + +fn blob_storage_error(operation: &str, client: &BlobStorageClient, key: &str, error: SdkError) -> Error +where + E: std::fmt::Debug, +{ + Error::Internal { + operation: format!( + "{operation} (bucket={}, key={}, provider={}): {}", + client.bucket, + key, + provider_name(client.provider), + classify_sdk_error(&error) + ), + } +} + #[cfg(test)] mod tests { use super::*; @@ -215,13 +254,13 @@ mod tests { fn object_store_config() -> ObjectStoreConfig { ObjectStoreConfig { provider: ObjectStoreProvider::S3Compatible, - endpoint: "http://localhost:9000".to_string(), + endpoint: Some("http://localhost:9000".to_string()), bucket: "bucket".to_string(), region: "us-east-1".to_string(), access_key_id: None, secret_access_key: None, session_token: None, - path_style: true, + path_style: Some(true), prefix: "uploads/".to_string(), connect_timeout_ms: 1000, request_timeout_ms: 1000, @@ -255,4 +294,49 @@ mod tests { assert!(static_credentials(&config).is_none()); } + + #[test] + fn effective_path_style_defaults_to_false_for_aws_s3() { + let config = ObjectStoreConfig { + provider: ObjectStoreProvider::AwsS3, + endpoint: None, + path_style: None, + ..object_store_config() + }; + + assert!(!effective_path_style(&config)); + } + + #[test] + fn effective_path_style_defaults_to_true_for_s3_compatible() { + let config = ObjectStoreConfig { + provider: ObjectStoreProvider::S3Compatible, + endpoint: Some("http://localhost:9000".to_string()), + path_style: None, + ..object_store_config() + }; + + assert!(effective_path_style(&config)); + } + + #[test] + fn effective_path_style_respects_explicit_override() { + let config = ObjectStoreConfig { + provider: 
ObjectStoreProvider::AwsS3, + path_style: Some(true), + ..object_store_config() + }; + + assert!(effective_path_style(&config)); + } + + #[test] + fn normalized_endpoint_ignores_blank_values() { + let config = ObjectStoreConfig { + endpoint: Some(" ".to_string()), + ..object_store_config() + }; + + assert!(normalized_endpoint(&config).is_none()); + } } diff --git a/dwctl/src/config.rs b/dwctl/src/config.rs index 7271ebc60..db7493eab 100644 --- a/dwctl/src/config.rs +++ b/dwctl/src/config.rs @@ -833,7 +833,10 @@ pub enum FileStorageBackend { #[serde(default, deny_unknown_fields)] pub struct ObjectStoreConfig { pub provider: ObjectStoreProvider, - pub endpoint: String, + /// Optional custom endpoint URL. + /// Required for `s3_compatible`; usually omitted for `aws_s3`. + #[serde(skip_serializing_if = "Option::is_none")] + pub endpoint: Option, pub bucket: String, pub region: String, /// Optional static access key. When omitted, the AWS SDK default credential chain is used. @@ -845,7 +848,10 @@ pub struct ObjectStoreConfig { /// Optional session token for static credentials. #[serde(skip_serializing_if = "Option::is_none")] pub session_token: Option, - pub path_style: bool, + /// Optional override for bucket addressing mode. + /// Defaults to `false` for `aws_s3` and `true` for `s3_compatible`. + #[serde(skip_serializing_if = "Option::is_none")] + pub path_style: Option, pub prefix: String, pub connect_timeout_ms: u64, pub request_timeout_ms: u64, @@ -855,13 +861,13 @@ impl Default for ObjectStoreConfig { fn default() -> Self { Self { provider: ObjectStoreProvider::S3Compatible, - endpoint: String::new(), + endpoint: None, bucket: String::new(), region: "us-east-1".to_string(), access_key_id: None, secret_access_key: None, session_token: None, - path_style: true, + path_style: None, prefix: "uploads/".to_string(), connect_timeout_ms: 5000, request_timeout_ms: 120000, @@ -869,9 +875,10 @@ impl Default for ObjectStoreConfig { } } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum ObjectStoreProvider { + AwsS3, S3Compatible, } @@ -2018,11 +2025,6 @@ impl Config { operation: "Config validation: batches.files.object_store must be set when storage_backend=object_store.".to_string(), })?; - if store.endpoint.trim().is_empty() { - return Err(Error::Internal { - operation: "Config validation: batches.files.object_store.endpoint cannot be empty.".to_string(), - }); - } if store.bucket.trim().is_empty() { return Err(Error::Internal { operation: "Config validation: batches.files.object_store.bucket cannot be empty.".to_string(), @@ -2034,6 +2036,20 @@ impl Config { }); } + let endpoint = store.endpoint.as_deref().map(str::trim).filter(|s| !s.is_empty()); + if store.endpoint.is_some() && endpoint.is_none() { + return Err(Error::Internal { + operation: "Config validation: batches.files.object_store.endpoint cannot be empty.".to_string(), + }); + } + + if store.provider == ObjectStoreProvider::S3Compatible && endpoint.is_none() { + return Err(Error::Internal { + operation: "Config validation: batches.files.object_store.endpoint must be set for provider=s3_compatible." 
+ .to_string(), + }); + } + let access_key_id = store.access_key_id.as_deref().map(str::trim).filter(|s| !s.is_empty()); let secret_access_key = store.secret_access_key.as_deref().map(str::trim).filter(|s| !s.is_empty()); @@ -2405,7 +2421,7 @@ secret_key: "test-secret-key" config.batches.enabled = true; config.batches.files.storage_backend = FileStorageBackend::ObjectStore; config.batches.files.object_store = Some(ObjectStoreConfig { - endpoint: "".to_string(), + endpoint: Some("".to_string()), bucket: "bucket".to_string(), region: "us-east-1".to_string(), access_key_id: Some("key".to_string()), @@ -2426,7 +2442,7 @@ secret_key: "test-secret-key" config.batches.enabled = true; config.batches.files.storage_backend = FileStorageBackend::ObjectStore; config.batches.files.object_store = Some(ObjectStoreConfig { - endpoint: "http://localhost:9000".to_string(), + endpoint: Some("http://localhost:9000".to_string()), bucket: "bucket".to_string(), region: "us-east-1".to_string(), access_key_id: None, @@ -2438,6 +2454,49 @@ secret_key: "test-secret-key" assert!(config.validate().is_ok()); } + #[test] + fn test_aws_s3_allows_missing_endpoint() { + let mut config = Config::default(); + config.auth.native.enabled = true; + config.secret_key = Some("test-secret-key".to_string()); + config.batches.enabled = true; + config.batches.files.storage_backend = FileStorageBackend::ObjectStore; + config.batches.files.object_store = Some(ObjectStoreConfig { + provider: ObjectStoreProvider::AwsS3, + endpoint: None, + bucket: "bucket".to_string(), + region: "eu-west-2".to_string(), + ..Default::default() + }); + + assert!(config.validate().is_ok()); + } + + #[test] + fn test_s3_compatible_requires_endpoint() { + let mut config = Config::default(); + config.auth.native.enabled = true; + config.secret_key = Some("test-secret-key".to_string()); + config.batches.enabled = true; + config.batches.files.storage_backend = FileStorageBackend::ObjectStore; + config.batches.files.object_store = Some(ObjectStoreConfig { + provider: ObjectStoreProvider::S3Compatible, + endpoint: None, + bucket: "bucket".to_string(), + region: "us-east-1".to_string(), + ..Default::default() + }); + + let result = config.validate(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("endpoint must be set for provider=s3_compatible") + ); + } + #[test] fn test_object_store_rejects_partial_static_credentials() { let mut config = Config::default(); @@ -2446,7 +2505,7 @@ secret_key: "test-secret-key" config.batches.enabled = true; config.batches.files.storage_backend = FileStorageBackend::ObjectStore; config.batches.files.object_store = Some(ObjectStoreConfig { - endpoint: "http://localhost:9000".to_string(), + endpoint: Some("http://localhost:9000".to_string()), bucket: "bucket".to_string(), region: "us-east-1".to_string(), access_key_id: Some("key".to_string()), @@ -2472,7 +2531,7 @@ secret_key: "test-secret-key" config.batches.enabled = true; config.batches.files.storage_backend = FileStorageBackend::ObjectStore; config.batches.files.object_store = Some(ObjectStoreConfig { - endpoint: "http://localhost:9000".to_string(), + endpoint: Some("http://localhost:9000".to_string()), bucket: "bucket".to_string(), region: "us-east-1".to_string(), session_token: Some("token".to_string()), From 5c642a9c84db6f68f5910e69fba263a2d057f0dd Mon Sep 17 00:00:00 2001 From: Josh Sephton Date: Tue, 7 Apr 2026 16:58:32 +0100 Subject: [PATCH 9/9] feat: surface processing errors to user in dashboard --- 
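Notes: the new file `status` field mirrors the OpenAI file lifecycle. Ingest job states map onto it as follows: pending/processing report "uploaded", processed (or no ingest job row at all) reports "processed", and failed reports "error" with `status_details` carrying the ingest error message; for non-error states `status_details` is an empty string. The dashboard disables the View/Download actions for files in the error state and surfaces `status_details` in the tooltip. As a rough sketch, a failed upload now reads back along these lines (illustrative values only, not from a real run):

{"id": "<file uuid>", "filename": "batch_requests.jsonl", "purpose": "batch", "status": "error", "status_details": "line 2: invalid JSON"}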
dashboard/src/api/control-layer/types.ts | 5 + .../features/batches/FilesTable/columns.tsx | 81 +++++---- dwctl/src/api/handlers/batches.rs | 9 +- dwctl/src/api/handlers/files.rs | 172 +++++++++++++++++- dwctl/src/api/models/files.rs | 4 + 5 files changed, 229 insertions(+), 42 deletions(-) diff --git a/dashboard/src/api/control-layer/types.ts b/dashboard/src/api/control-layer/types.ts index f71e0e03c..1dc7a1e1d 100644 --- a/dashboard/src/api/control-layer/types.ts +++ b/dashboard/src/api/control-layer/types.ts @@ -1042,6 +1042,11 @@ export interface FileObject { | "vision" | "user_data" | "evals"; + status: + | "uploaded" + | "processed" + | "error"; + status_details: string; /** Email of the individual who created this file */ created_by_email?: string; /** "Personal" or org name */ diff --git a/dashboard/src/components/features/batches/FilesTable/columns.tsx b/dashboard/src/components/features/batches/FilesTable/columns.tsx index ae0892c53..e97422a09 100644 --- a/dashboard/src/components/features/batches/FilesTable/columns.tsx +++ b/dashboard/src/components/features/batches/FilesTable/columns.tsx @@ -213,6 +213,8 @@ export const createFileColumns = ( size: 200, // Constrain the actions column width cell: ({ row }) => { const file = row.original; + const viewDisabled = file.status === "error"; + const disabledReason = file.status_details || "This file cannot be viewed."; // Disabled for now - expiration not yet enforced on backend // const isExpired = // file.expires_at && new Date(file.expires_at * 1000) < new Date(); @@ -240,53 +242,68 @@ export const createFileColumns = ( )} - - - View Requests - - {file.purpose === "batch" && ( - - + + + + + {viewDisabled ? disabledReason : "View Requests"} + + + {file.purpose === "batch" && ( + + + + + - View Batches + + {viewDisabled ? disabledReason : "View Batches"} + )} - + + + - Download File + + {viewDisabled ? disabledReason : "Download File"} + diff --git a/dwctl/src/api/handlers/batches.rs b/dwctl/src/api/handlers/batches.rs index 275c16566..15485f131 100644 --- a/dwctl/src/api/handlers/batches.rs +++ b/dwctl/src/api/handlers/batches.rs @@ -356,19 +356,14 @@ pub async fn create_batch( // If asynchronous file ingestion is enabled, gate batch creation on ingest state. if let Some(status) = crate::api::handlers::files::get_file_ingest_status(state.db.write(), file_id).await? { - match status.as_str() { + match status.status.as_str() { "pending" | "processing" => { return Err(Error::BadRequest { message: "File is still being processed. Please retry shortly.".to_string(), }); } "failed" => { - let msg = sqlx::query_scalar::<_, Option>("SELECT error_message FROM file_ingest_jobs WHERE file_id = $1") - .bind(file_id) - .fetch_one(state.db.write()) - .await - .map_err(|e| Error::Database(e.into()))? 
- .unwrap_or_else(|| "File ingestion failed".to_string()); + let msg = status.error_message.unwrap_or_else(|| "File ingestion failed".to_string()); return Err(Error::BadRequest { message: msg }); } _ => {} diff --git a/dwctl/src/api/handlers/files.rs b/dwctl/src/api/handlers/files.rs index ea19392a9..2a672ca33 100644 --- a/dwctl/src/api/handlers/files.rs +++ b/dwctl/src/api/handlers/files.rs @@ -43,6 +43,7 @@ use futures::StreamExt; use futures::stream::Stream; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; +use sqlx::Row; use std::collections::HashMap; use std::pin::Pin; use std::sync::{Arc, Mutex}; @@ -441,12 +442,75 @@ pub struct IngestFileInput { pub size_bytes: i64, } -pub(crate) async fn get_file_ingest_status(pool: &sqlx::PgPool, file_id: Uuid) -> Result> { - sqlx::query_scalar::<_, String>("SELECT status FROM file_ingest_jobs WHERE file_id = $1") +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct IngestJobState { + pub(crate) status: String, + pub(crate) error_message: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct FileApiStatus { + status: String, + status_details: String, +} + +pub(crate) async fn get_file_ingest_status(pool: &sqlx::PgPool, file_id: Uuid) -> Result> { + let row = sqlx::query("SELECT status, error_message FROM file_ingest_jobs WHERE file_id = $1") .bind(file_id) .fetch_optional(pool) .await - .map_err(|e| Error::Database(e.into())) + .map_err(|e| Error::Database(e.into()))?; + + Ok(row.map(|row| IngestJobState { + status: row.get("status"), + error_message: row.get("error_message"), + })) +} + +async fn get_file_ingest_statuses(pool: &sqlx::PgPool, file_ids: &[Uuid]) -> Result> { + if file_ids.is_empty() { + return Ok(HashMap::new()); + } + + let rows = sqlx::query("SELECT file_id, status, error_message FROM file_ingest_jobs WHERE file_id = ANY($1)") + .bind(file_ids) + .fetch_all(pool) + .await + .map_err(|e| Error::Database(e.into()))?; + + Ok(rows + .into_iter() + .map(|row| { + ( + row.get("file_id"), + IngestJobState { + status: row.get("status"), + error_message: row.get("error_message"), + }, + ) + }) + .collect()) +} + +fn map_ingest_job_to_file_status(ingest_status: Option<&IngestJobState>) -> FileApiStatus { + match ingest_status { + Some(IngestJobState { status, .. }) if status == "pending" || status == "processing" => FileApiStatus { + status: "uploaded".to_string(), + status_details: String::new(), + }, + Some(IngestJobState { status, .. 
}) if status == "processed" => FileApiStatus { + status: "processed".to_string(), + status_details: String::new(), + }, + Some(IngestJobState { status, error_message }) if status == "failed" => FileApiStatus { + status: "error".to_string(), + status_details: error_message.clone().unwrap_or_default(), + }, + _ => FileApiStatus { + status: "processed".to_string(), + status_details: String::new(), + }, + } } async fn set_file_ingest_status( @@ -1254,6 +1318,7 @@ pub async fn upload_file( Some(fusillade::batch::Purpose::BatchError) => Purpose::BatchError, None => Purpose::Batch, // Default to Batch for backwards compatibility }; + let file_status = map_ingest_job_to_file_status(get_file_ingest_status(state.db.read(), created_file_id.0).await?.as_ref()); Ok(( StatusCode::CREATED, @@ -1264,6 +1329,8 @@ pub async fn upload_file( created_at: file.created_at.timestamp(), filename: file.name, purpose: api_purpose, + status: file_status.status, + status_details: file_status.status_details, expires_at: file.expires_at.map(|dt| dt.timestamp()), created_by_email: None, context_name: None, @@ -1403,10 +1470,12 @@ pub async fn list_files( let first_id = files.first().map(|f| f.id.0.to_string()); let last_id = files.last().map(|f| f.id.0.to_string()); + let file_ids: Vec = files.iter().map(|f| f.id.0).collect(); // Resolve creator/context metadata for all returned files. // Uses a fresh connection (not held across the fusillade call above). let mut read_conn = state.db.read().acquire().await.map_err(|e| Error::Database(e.into()))?; + let ingest_statuses = get_file_ingest_statuses(state.db.read(), &file_ids).await?; // Resolve individual creators via api_key_id → api_keys.created_by let api_key_ids: Vec = files @@ -1475,6 +1544,7 @@ pub async fn list_files( Some(_) => (Some("Personal".to_string()), Some("personal".to_string())), None => (None, None), }; + let file_status = map_ingest_job_to_file_status(ingest_statuses.get(&f.id.0)); FileResponse { id: f.id.0.to_string(), @@ -1483,6 +1553,8 @@ pub async fn list_files( created_at: f.created_at.timestamp(), filename: f.name.clone(), purpose: api_purpose, + status: file_status.status, + status_details: file_status.status_details, expires_at: f.expires_at.map(|dt| dt.timestamp()), created_by_email, context_name, @@ -1551,6 +1623,8 @@ pub async fn get_file( Some(fusillade::batch::Purpose::BatchError) => Purpose::BatchError, None => Purpose::Batch, // Default to Batch for backwards compatibility }; + let ingest_status = get_file_ingest_status(state.db.read(), file_id).await?; + let file_status = map_ingest_job_to_file_status(ingest_status.as_ref()); // Enrich with creator/context metadata (same as list_files) let mut read_conn = state.db.read().acquire().await.map_err(|e| Error::Database(e.into()))?; @@ -1604,6 +1678,8 @@ pub async fn get_file( created_at: file.created_at.timestamp(), filename: file.name, purpose: api_purpose, + status: file_status.status, + status_details: file_status.status_details, expires_at: file.expires_at.map(|dt| dt.timestamp()), created_by_email, context_name, @@ -2145,6 +2221,41 @@ mod tests { upload_response.json() } + #[test] + fn test_map_ingest_job_to_file_status() { + let pending = map_ingest_job_to_file_status(Some(&IngestJobState { + status: "pending".to_string(), + error_message: None, + })); + assert_eq!(pending.status, "uploaded"); + assert_eq!(pending.status_details, ""); + + let processing = map_ingest_job_to_file_status(Some(&IngestJobState { + status: "processing".to_string(), + error_message: None, + })); + 
assert_eq!(processing.status, "uploaded"); + assert_eq!(processing.status_details, ""); + + let processed = map_ingest_job_to_file_status(Some(&IngestJobState { + status: "processed".to_string(), + error_message: None, + })); + assert_eq!(processed.status, "processed"); + assert_eq!(processed.status_details, ""); + + let failed = map_ingest_job_to_file_status(Some(&IngestJobState { + status: "failed".to_string(), + error_message: Some("boom".to_string()), + })); + assert_eq!(failed.status, "error"); + assert_eq!(failed.status_details, "boom"); + + let fallback = map_ingest_job_to_file_status(None); + assert_eq!(fallback.status, "processed"); + assert_eq!(fallback.status_details, ""); + } + async fn create_batch_for_file( app: &axum_test::TestServer, user: &crate::api::models::users::UserResponse, @@ -2521,6 +2632,8 @@ mod tests { // Should succeed upload_response.assert_status(axum::http::StatusCode::CREATED); let file: FileResponse = upload_response.json(); + assert_eq!(file.status, "processed"); + assert_eq!(file.status_details, ""); // Verify the file was created - now let's check if metadata was captured // We need to query the database or fusillade to verify the metadata was stored @@ -2534,6 +2647,57 @@ mod tests { get_response.assert_status(axum::http::StatusCode::OK); let retrieved_file: FileResponse = get_response.json(); assert_eq!(retrieved_file.purpose, crate::api::models::files::Purpose::Batch); + assert_eq!(retrieved_file.status, "processed"); + assert_eq!(retrieved_file.status_details, ""); + } + + #[sqlx::test] + #[test_log::test] + async fn test_get_file_uses_failed_ingest_job_status(pool: PgPool) { + let (app, _bg_services) = create_test_app(pool.clone(), false).await; + let user = create_test_user_with_roles(&pool, vec![Role::StandardUser, Role::BatchAPIUser]).await; + let group = create_test_group(&pool).await; + add_user_to_group(&pool, user.id, group.id).await; + let deployment = create_test_deployment(&pool, user.id, "gpt-4-model", "gpt-4").await; + add_deployment_to_group(&pool, deployment.id, group.id, user.id).await; + + let uploaded = upload_batch_input_file( + &app, + &user, + r#"{"custom_id":"req-1","method":"POST","url":"/v1/chat/completions","body":{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}}"#, + ) + .await; + + sqlx::query( + r#" + INSERT INTO file_ingest_jobs (file_id, object_key, status, error_message, created_at, updated_at) + VALUES ($1, $2, $3, $4, NOW(), NOW()) + ON CONFLICT (file_id) DO UPDATE + SET object_key = EXCLUDED.object_key, + status = EXCLUDED.status, + error_message = EXCLUDED.error_message, + updated_at = NOW() + "#, + ) + .bind(Uuid::parse_str(&uploaded.id).unwrap()) + .bind("uploads/test.jsonl") + .bind("failed") + .bind("line 2: invalid JSON") + .execute(&pool) + .await + .unwrap(); + + let auth = add_auth_headers(&user); + let response = app + .get(&format!("/ai/v1/files/{}", uploaded.id)) + .add_header(&auth[0].0, &auth[0].1) + .add_header(&auth[1].0, &auth[1].1) + .await; + + response.assert_status(axum::http::StatusCode::OK); + let file: FileResponse = response.json(); + assert_eq!(file.status, "error"); + assert_eq!(file.status_details, "line 2: invalid JSON"); } #[sqlx::test] @@ -4418,6 +4582,8 @@ mod tests { let files = body["data"].as_array().unwrap(); assert!(!files.is_empty(), "Expected at least one personal file"); for file in files { + assert!(file.get("status").is_some(), "status should be present"); + assert!(file.get("status_details").is_some(), "status_details should be present"); assert!( 
file.get("context_name").is_some() && !file["context_name"].is_null(), "context_name should be present even in personal context" diff --git a/dwctl/src/api/models/files.rs b/dwctl/src/api/models/files.rs index 952c6fb22..27f070d5a 100644 --- a/dwctl/src/api/models/files.rs +++ b/dwctl/src/api/models/files.rs @@ -78,6 +78,10 @@ pub struct FileResponse { #[schema(example = "batch_requests.jsonl")] pub filename: String, pub purpose: Purpose, + #[schema(example = "processed")] + pub status: String, + #[schema(example = "")] + pub status_details: String, #[serde(skip_serializing_if = "Option::is_none")] pub expires_at: Option, // Unix timestamp