From f644892b7d67aaf7897af1c2bdae069a33f44afe Mon Sep 17 00:00:00 2001 From: ordinaryorange Date: Tue, 23 Nov 2021 23:00:36 +1030 Subject: [PATCH 1/2] Implementation and tests - non seekable multipart form files --- src/Net/Http.fs | 34 ++++++++++++------- tests/FSharp.Data.Tests/Http.fs | 59 +++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/src/Net/Http.fs b/src/Net/Http.fs index a1f70bf28..a9b814946 100644 --- a/src/Net/Http.fs +++ b/src/Net/Http.fs @@ -1313,11 +1313,11 @@ module internal HttpHelpers = let mutable streams = streams |> Seq.cache let rec readFromStream buffer offset count = - if Seq.isEmpty streams - then 0 - else - let stream = Seq.head streams - let read = stream.Read(buffer, offset, min count (int stream.Length)) + match streams |> Seq.tryHead with + | None -> 0 + | Some stream -> + let qty = if stream.CanSeek then min count (int stream.Length) else count + let read = stream.Read(buffer, offset, qty) if read < count then stream.Dispose() @@ -1327,9 +1327,9 @@ module internal HttpHelpers = else read override x.CanRead = true - override x.CanSeek = false + override x.CanSeek = match length with | None -> false | Some _ -> true override x.CanWrite = false - override x.Length with get () = length + override x.Length with get () = length |> Option.defaultWith (fun () -> NotSupportedException() |> raise) override x.Position with get () = v and set(_) = failwith "no position setting" override x.Flush() = () override x.CanTimeout = false @@ -1351,7 +1351,16 @@ module internal HttpHelpers = let writeMultipart (boundary: string) (parts: seq) (e : Encoding) = let newlineStream () = new MemoryStream(e.GetBytes "\r\n") :> Stream let prefixedBoundary = sprintf "--%s" boundary - let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, fileStream)) -> + let trySumLength streams = //allows seq to be blocking & non seekable + let mutable seekable = true + let mutable length = 0L + let takeIfSeekable (str: Stream) = + seekable <- str.CanSeek + if str.CanSeek then length <- length + str.Length + str.CanSeek + streams |> Seq.takeWhile takeIfSeekable |> List.ofSeq |> ignore + if seekable then Some length else None + let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, contentStream)) -> let fileExt = Path.GetExtension fileName let contentType = defaultArg (MimeTypes.tryFind fileExt) "application/octet-stream" let printHeader (header, value) = sprintf "%s: %s" header value @@ -1367,9 +1376,9 @@ module internal HttpHelpers = [ headerStream newlineStream() newlineStream() - fileStream + contentStream newlineStream()] - let partLength = partSubstreams |> Seq.sumBy (fun s -> s.Length) + let partLength = partSubstreams |> trySumLength new CombinedStream(partLength, partSubstreams) :> Stream ) @@ -1380,7 +1389,7 @@ module internal HttpHelpers = new MemoryStream(bytes) :> Stream let wholePayload = Seq.append segments [newlineStream(); endBoundaryStream; ] - let wholePayloadLength = wholePayload |> Seq.sumBy (fun s -> s.Length) + let wholePayloadLength = wholePayload |> trySumLength new CombinedStream(wholePayloadLength, wholePayload) :> Stream let asyncCopy (source: Stream) (dest: Stream) = @@ -1393,7 +1402,8 @@ module internal HttpHelpers = let writeBody (req:HttpWebRequest) (data: Stream) = async { - req.ContentLength <- data.Length + if data.CanSeek then + req.ContentLength <- data.Length use! output = req.GetRequestStreamAsync () |> Async.AwaitTask do! asyncCopy data output output.Flush() diff --git a/tests/FSharp.Data.Tests/Http.fs b/tests/FSharp.Data.Tests/Http.fs index 3d79ee8c8..54339fe4c 100644 --- a/tests/FSharp.Data.Tests/Http.fs +++ b/tests/FSharp.Data.Tests/Http.fs @@ -207,3 +207,62 @@ let ``correct multipart content format`` () = let singleMultipartFormat file = sprintf "--%s\r\nContent-Disposition: form-data; name=\"%i\"; filename=\"%i\"\r\nContent-Type: application/octet-stream\r\n\r\n%s\r\n" boundary file file content let finalFormat = [sprintf "\r\n--%s--" boundary] |> Seq.append (seq {for i in [0..numFiles] -> singleMultipartFormat i }) |> String.concat "" str |> should equal finalFormat + +[] +let ``CombinedStream has length with Some length`` () = + use combinedStream = new HttpHelpers.CombinedStream(Some 10L, []) + combinedStream.Length |> should equal 10L + +[] +let ``CombinedStream can seek with Some length`` () = + use combinedStream = new HttpHelpers.CombinedStream(Some 10L, []) + combinedStream.CanSeek |> should equal true + +[] +let ``CombinedStream length throws with None length`` () = + use combinedStream = new HttpHelpers.CombinedStream(None, []) + (fun () -> combinedStream.Length |> ignore) |> should throw typeof + +[] +let ``CombinedStream cannot seek with None length`` () = + use combinedStream = new HttpHelpers.CombinedStream(None, []) + combinedStream.CanSeek |> should equal false + +type nonSeekableStream (b: byte[]) = + inherit IO.MemoryStream(b) + override _.Length with get():Int64 = failwith "Im not seekable" + override _.CanSeek with get() = false + +[] +let ``Non-seekable streams create non-seekable CombinedStream`` () = + use nonSeekms = new nonSeekableStream(Array.zeroCreate 10) + let multiparts = [MultipartItem("","", nonSeekms)] + let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8 + (fun () -> combinedStream.Length |> ignore) |> should throw typeof + combinedStream.CanSeek |> should equal false + +[] +let ``Seekable streams create Seekable CombinedStream`` () = + let byteLen = 10L + let result = byteLen + 110L //110 is headers + use ms = new IO.MemoryStream(Array.zeroCreate (int byteLen)) + let multiparts = [MultipartItem("","", ms)] + let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8 + combinedStream.Length |> should equal result + combinedStream.CanSeek |> should equal true + +[] +let ``HttpWebRequest length is set with seekable streams`` () = + use ms = new IO.MemoryStream(Array.zeroCreate 10) + let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest + wr.Method <- "POST" + HttpHelpers.writeBody wr ms |> Async.RunSynchronously + wr.ContentLength |> should equal 10 + +[] +let ``HttpWebRequest length is not set with non-seekable streams`` () = + use nonSeekms = new nonSeekableStream(Array.zeroCreate 10) + let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest + wr.Method <- "POST" + HttpHelpers.writeBody wr nonSeekms |> Async.RunSynchronously + wr.ContentLength |> should equal 0 From b66562d128f41d4ef212b4220c0838ec9c5ea52a Mon Sep 17 00:00:00 2001 From: ordinaryorange Date: Tue, 23 Nov 2021 23:00:36 +1030 Subject: [PATCH 2/2] changed exceImplementation and tests - non seekable multipart form files --- src/Net/Http.fs | 34 ++++++++++++------- tests/FSharp.Data.Tests/Http.fs | 59 +++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/src/Net/Http.fs b/src/Net/Http.fs index a1f70bf28..0c5aeaa5f 100644 --- a/src/Net/Http.fs +++ b/src/Net/Http.fs @@ -1313,11 +1313,11 @@ module internal HttpHelpers = let mutable streams = streams |> Seq.cache let rec readFromStream buffer offset count = - if Seq.isEmpty streams - then 0 - else - let stream = Seq.head streams - let read = stream.Read(buffer, offset, min count (int stream.Length)) + match streams |> Seq.tryHead with + | None -> 0 + | Some stream -> + let qty = if stream.CanSeek then min count (int stream.Length) else count + let read = stream.Read(buffer, offset, qty) if read < count then stream.Dispose() @@ -1327,9 +1327,9 @@ module internal HttpHelpers = else read override x.CanRead = true - override x.CanSeek = false + override x.CanSeek = match length with | None -> false | Some _ -> true override x.CanWrite = false - override x.Length with get () = length + override x.Length with get () = length |> Option.defaultWith (fun () -> failwith "One or more of the encompassed streams are not seekable and the length cannot be determine") override x.Position with get () = v and set(_) = failwith "no position setting" override x.Flush() = () override x.CanTimeout = false @@ -1351,7 +1351,16 @@ module internal HttpHelpers = let writeMultipart (boundary: string) (parts: seq) (e : Encoding) = let newlineStream () = new MemoryStream(e.GetBytes "\r\n") :> Stream let prefixedBoundary = sprintf "--%s" boundary - let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, fileStream)) -> + let trySumLength streams = //allows seq to be blocking & non seekable + let mutable seekable = true + let mutable length = 0L + let takeIfSeekable (str: Stream) = + seekable <- str.CanSeek + if str.CanSeek then length <- length + str.Length + str.CanSeek + streams |> Seq.takeWhile takeIfSeekable |> List.ofSeq |> ignore + if seekable then Some length else None + let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, contentStream)) -> let fileExt = Path.GetExtension fileName let contentType = defaultArg (MimeTypes.tryFind fileExt) "application/octet-stream" let printHeader (header, value) = sprintf "%s: %s" header value @@ -1367,9 +1376,9 @@ module internal HttpHelpers = [ headerStream newlineStream() newlineStream() - fileStream + contentStream newlineStream()] - let partLength = partSubstreams |> Seq.sumBy (fun s -> s.Length) + let partLength = partSubstreams |> trySumLength new CombinedStream(partLength, partSubstreams) :> Stream ) @@ -1380,7 +1389,7 @@ module internal HttpHelpers = new MemoryStream(bytes) :> Stream let wholePayload = Seq.append segments [newlineStream(); endBoundaryStream; ] - let wholePayloadLength = wholePayload |> Seq.sumBy (fun s -> s.Length) + let wholePayloadLength = wholePayload |> trySumLength new CombinedStream(wholePayloadLength, wholePayload) :> Stream let asyncCopy (source: Stream) (dest: Stream) = @@ -1393,7 +1402,8 @@ module internal HttpHelpers = let writeBody (req:HttpWebRequest) (data: Stream) = async { - req.ContentLength <- data.Length + if data.CanSeek then + req.ContentLength <- data.Length use! output = req.GetRequestStreamAsync () |> Async.AwaitTask do! asyncCopy data output output.Flush() diff --git a/tests/FSharp.Data.Tests/Http.fs b/tests/FSharp.Data.Tests/Http.fs index 3d79ee8c8..58682730b 100644 --- a/tests/FSharp.Data.Tests/Http.fs +++ b/tests/FSharp.Data.Tests/Http.fs @@ -207,3 +207,62 @@ let ``correct multipart content format`` () = let singleMultipartFormat file = sprintf "--%s\r\nContent-Disposition: form-data; name=\"%i\"; filename=\"%i\"\r\nContent-Type: application/octet-stream\r\n\r\n%s\r\n" boundary file file content let finalFormat = [sprintf "\r\n--%s--" boundary] |> Seq.append (seq {for i in [0..numFiles] -> singleMultipartFormat i }) |> String.concat "" str |> should equal finalFormat + +[] +let ``CombinedStream has length with Some length`` () = + use combinedStream = new HttpHelpers.CombinedStream(Some 10L, []) + combinedStream.Length |> should equal 10L + +[] +let ``CombinedStream can seek with Some length`` () = + use combinedStream = new HttpHelpers.CombinedStream(Some 10L, []) + combinedStream.CanSeek |> should equal true + +[] +let ``CombinedStream length throws with None length`` () = + use combinedStream = new HttpHelpers.CombinedStream(None, []) + (fun () -> combinedStream.Length |> ignore) |> should throw typeof + +[] +let ``CombinedStream cannot seek with None length`` () = + use combinedStream = new HttpHelpers.CombinedStream(None, []) + combinedStream.CanSeek |> should equal false + +type nonSeekableStream (b: byte[]) = + inherit IO.MemoryStream(b) + override _.Length with get():Int64 = failwith "Im not seekable" + override _.CanSeek with get() = false + +[] +let ``Non-seekable streams create non-seekable CombinedStream`` () = + use nonSeekms = new nonSeekableStream(Array.zeroCreate 10) + let multiparts = [MultipartItem("","", nonSeekms)] + let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8 + (fun () -> combinedStream.Length |> ignore) |> should throw typeof + combinedStream.CanSeek |> should equal false + +[] +let ``Seekable streams create Seekable CombinedStream`` () = + let byteLen = 10L + let result = byteLen + 110L //110 is headers + use ms = new IO.MemoryStream(Array.zeroCreate (int byteLen)) + let multiparts = [MultipartItem("","", ms)] + let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8 + combinedStream.Length |> should equal result + combinedStream.CanSeek |> should equal true + +[] +let ``HttpWebRequest length is set with seekable streams`` () = + use ms = new IO.MemoryStream(Array.zeroCreate 10) + let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest + wr.Method <- "POST" + HttpHelpers.writeBody wr ms |> Async.RunSynchronously + wr.ContentLength |> should equal 10 + +[] +let ``HttpWebRequest length is not set with non-seekable streams`` () = + use nonSeekms = new nonSeekableStream(Array.zeroCreate 10) + let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest + wr.Method <- "POST" + HttpHelpers.writeBody wr nonSeekms |> Async.RunSynchronously + wr.ContentLength |> should equal 0