Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/Net/Http.fs
Original file line number Diff line number Diff line change
Expand Up @@ -1313,11 +1313,11 @@ module internal HttpHelpers =
let mutable streams = streams |> Seq.cache

let rec readFromStream buffer offset count =
if Seq.isEmpty streams
then 0
else
let stream = Seq.head streams
let read = stream.Read(buffer, offset, min count (int stream.Length))
match streams |> Seq.tryHead with
| None -> 0
| Some stream ->
let qty = if stream.CanSeek then min count (int stream.Length) else count
let read = stream.Read(buffer, offset, qty)
if read < count
then
stream.Dispose()
Expand All @@ -1327,9 +1327,9 @@ module internal HttpHelpers =
else read

override x.CanRead = true
override x.CanSeek = false
override x.CanSeek = match length with | None -> false | Some _ -> true
override x.CanWrite = false
override x.Length with get () = length
override x.Length with get () = length |> Option.defaultWith (fun () -> failwith "One or more of the encompassed streams are not seekable and the length cannot be determine")
override x.Position with get () = v and set(_) = failwith "no position setting"
override x.Flush() = ()
override x.CanTimeout = false
Expand All @@ -1351,7 +1351,16 @@ module internal HttpHelpers =
let writeMultipart (boundary: string) (parts: seq<MultipartItem>) (e : Encoding) =
let newlineStream () = new MemoryStream(e.GetBytes "\r\n") :> Stream
let prefixedBoundary = sprintf "--%s" boundary
let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, fileStream)) ->
let trySumLength streams = //allows seq to be blocking & non seekable
let mutable seekable = true
let mutable length = 0L
let takeIfSeekable (str: Stream) =
seekable <- str.CanSeek
if str.CanSeek then length <- length + str.Length
str.CanSeek
streams |> Seq.takeWhile takeIfSeekable |> List.ofSeq |> ignore
if seekable then Some length else None
let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, contentStream)) ->
let fileExt = Path.GetExtension fileName
let contentType = defaultArg (MimeTypes.tryFind fileExt) "application/octet-stream"
let printHeader (header, value) = sprintf "%s: %s" header value
Expand All @@ -1367,9 +1376,9 @@ module internal HttpHelpers =
[ headerStream
newlineStream()
newlineStream()
fileStream
contentStream
newlineStream()]
let partLength = partSubstreams |> Seq.sumBy (fun s -> s.Length)
let partLength = partSubstreams |> trySumLength
new CombinedStream(partLength, partSubstreams) :> Stream
)

Expand All @@ -1380,7 +1389,7 @@ module internal HttpHelpers =
new MemoryStream(bytes) :> Stream

let wholePayload = Seq.append segments [newlineStream(); endBoundaryStream; ]
let wholePayloadLength = wholePayload |> Seq.sumBy (fun s -> s.Length)
let wholePayloadLength = wholePayload |> trySumLength
new CombinedStream(wholePayloadLength, wholePayload) :> Stream

let asyncCopy (source: Stream) (dest: Stream) =
Expand All @@ -1393,7 +1402,8 @@ module internal HttpHelpers =

let writeBody (req:HttpWebRequest) (data: Stream) =
async {
req.ContentLength <- data.Length
if data.CanSeek then
req.ContentLength <- data.Length
use! output = req.GetRequestStreamAsync () |> Async.AwaitTask
do! asyncCopy data output
output.Flush()
Expand Down
59 changes: 59 additions & 0 deletions tests/FSharp.Data.Tests/Http.fs
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,62 @@ let ``correct multipart content format`` () =
let singleMultipartFormat file = sprintf "--%s\r\nContent-Disposition: form-data; name=\"%i\"; filename=\"%i\"\r\nContent-Type: application/octet-stream\r\n\r\n%s\r\n" boundary file file content
let finalFormat = [sprintf "\r\n--%s--" boundary] |> Seq.append (seq {for i in [0..numFiles] -> singleMultipartFormat i }) |> String.concat ""
str |> should equal finalFormat

[<Test>]
let ``CombinedStream has length with Some length`` () =
use combinedStream = new HttpHelpers.CombinedStream(Some 10L, [])
combinedStream.Length |> should equal 10L

[<Test>]
let ``CombinedStream can seek with Some length`` () =
use combinedStream = new HttpHelpers.CombinedStream(Some 10L, [])
combinedStream.CanSeek |> should equal true

[<Test>]
let ``CombinedStream length throws with None length`` () =
use combinedStream = new HttpHelpers.CombinedStream(None, [])
(fun () -> combinedStream.Length |> ignore) |> should throw typeof<Exception>

[<Test>]
let ``CombinedStream cannot seek with None length`` () =
use combinedStream = new HttpHelpers.CombinedStream(None, [])
combinedStream.CanSeek |> should equal false

type nonSeekableStream (b: byte[]) =
inherit IO.MemoryStream(b)
override _.Length with get():Int64 = failwith "Im not seekable"
override _.CanSeek with get() = false

[<Test>]
let ``Non-seekable streams create non-seekable CombinedStream`` () =
use nonSeekms = new nonSeekableStream(Array.zeroCreate 10)
let multiparts = [MultipartItem("","", nonSeekms)]
let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8
(fun () -> combinedStream.Length |> ignore) |> should throw typeof<Exception>
combinedStream.CanSeek |> should equal false

[<Test>]
let ``Seekable streams create Seekable CombinedStream`` () =
let byteLen = 10L
let result = byteLen + 110L //110 is headers
use ms = new IO.MemoryStream(Array.zeroCreate (int byteLen))
let multiparts = [MultipartItem("","", ms)]
let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8
combinedStream.Length |> should equal result
combinedStream.CanSeek |> should equal true

[<Test>]
let ``HttpWebRequest length is set with seekable streams`` () =
use ms = new IO.MemoryStream(Array.zeroCreate 10)
let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest
wr.Method <- "POST"
HttpHelpers.writeBody wr ms |> Async.RunSynchronously
wr.ContentLength |> should equal 10

[<Test>]
let ``HttpWebRequest length is not set with non-seekable streams`` () =
use nonSeekms = new nonSeekableStream(Array.zeroCreate 10)
let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest
wr.Method <- "POST"
HttpHelpers.writeBody wr nonSeekms |> Async.RunSynchronously
wr.ContentLength |> should equal 0