diff --git a/Microsoft.Toolkit.Uwp/IncrementalLoadingCollection/IncrementalLoadingCollection.cs b/Microsoft.Toolkit.Uwp/IncrementalLoadingCollection/IncrementalLoadingCollection.cs index 44871dde607..4058b0bf142 100644 --- a/Microsoft.Toolkit.Uwp/IncrementalLoadingCollection/IncrementalLoadingCollection.cs +++ b/Microsoft.Toolkit.Uwp/IncrementalLoadingCollection/IncrementalLoadingCollection.cs @@ -30,6 +30,8 @@ public class IncrementalLoadingCollection : ObservableCollection ISupportIncrementalLoading where TSource : Collections.IIncrementalSource { + private readonly SemaphoreSlim _mutex = new SemaphoreSlim(1); + /// /// Gets or sets an that is called when a retrieval operation begins. /// @@ -226,7 +228,23 @@ public Task RefreshAsync() /// protected virtual async Task> LoadDataAsync(CancellationToken cancellationToken) { - var result = await Source.GetPagedItemsAsync(CurrentPageIndex++, ItemsPerPage, cancellationToken); + var result = await Source.GetPagedItemsAsync(CurrentPageIndex, ItemsPerPage, cancellationToken) + .ContinueWith( + t => + { + if(t.IsFaulted) + { + throw t.Exception; + } + + if (t.IsCompletedSuccessfully) + { + CurrentPageIndex += 1; + } + + return t.Result; + }, cancellationToken); + return result; } @@ -235,6 +253,9 @@ private async Task LoadMoreItemsAsync(uint count, Cancellat uint resultCount = 0; _cancellationToken = cancellationToken; + // TODO (2021.05.05): Make use common AsyncMutex class. + // AsyncMutex is located at Microsoft.Toolkit.Uwp.UI.Media/Extensions/System.Threading.Tasks/AsyncMutex.cs at the time of this note. + await _mutex.WaitAsync(); try { if (!_cancellationToken.IsCancellationRequested) @@ -278,6 +299,8 @@ private async Task LoadMoreItemsAsync(uint count, Cancellat _refreshOnLoad = false; await RefreshAsync(); } + + _mutex.Release(); } return new LoadMoreItemsResult { Count = resultCount }; diff --git a/UnitTests/UnitTests.UWP/UI/Collection/DataSource.cs b/UnitTests/UnitTests.UWP/UI/Collection/DataSource.cs new file mode 100644 index 00000000000..b6c76ed83d4 --- /dev/null +++ b/UnitTests/UnitTests.UWP/UI/Collection/DataSource.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Toolkit.Collections; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace UnitTests.UI +{ + public class DataSource : IIncrementalSource + { + private readonly IEnumerable items; + private readonly Queue pageRequestOperations; + + public delegate IEnumerable PageOperation(IEnumerable page); + + public DataSource(IEnumerable items, IEnumerable pageOps) + : this(items, new Queue(pageOps)) + { + } + + public DataSource(IEnumerable items, params PageOperation[] pageOps) + : this(items, new Queue(pageOps)) + { + } + + public DataSource(IEnumerable items, Queue pageOps = default) + { + this.items = items ?? throw new ArgumentNullException(nameof(items)); + this.pageRequestOperations = pageOps ?? new Queue(); + } + + public static PageOperation MakeDelayOp(int delay) + => new (page => + { + Thread.Sleep(delay); + return page; + }); + + public static IEnumerable ThrowException(IEnumerable page) => throw new Exception(); + + public static IEnumerable PassThrough(IEnumerable page) => page; + + public async Task> GetPagedItemsAsync(int pageIndex, int pageSize, CancellationToken cancellationToken = default) + { + // Gets items from the collection according to pageIndex and pageSize parameters. + var result = (from p in items + select p).Skip(pageIndex * pageSize).Take(pageSize); + + return this.pageRequestOperations.TryDequeue(out var op) + ? await Task.Factory.StartNew(new Func>(o => op(o as IEnumerable)), state: result) + : result; + } + } +} diff --git a/UnitTests/UnitTests.UWP/UI/Collection/Test_IncrementalLoadingCollection.cs b/UnitTests/UnitTests.UWP/UI/Collection/Test_IncrementalLoadingCollection.cs new file mode 100644 index 00000000000..45a35d17b85 --- /dev/null +++ b/UnitTests/UnitTests.UWP/UI/Collection/Test_IncrementalLoadingCollection.cs @@ -0,0 +1,145 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Toolkit.Uwp; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.UI +{ + [TestClass] + public class Test_IncrementalLoadingCollection + { + private const int PageSize = 20; + private const int Pages = 5; + + private static readonly DataSource.PageOperation[] FailPassSequence + = new DataSource.PageOperation[] + { + DataSource.ThrowException, DataSource.PassThrough, + DataSource.ThrowException, DataSource.PassThrough, + DataSource.ThrowException, DataSource.PassThrough, + DataSource.ThrowException, DataSource.PassThrough, + DataSource.ThrowException, DataSource.PassThrough, + }; + + private static readonly int[] AllData + = Enumerable.Range(0, Pages * PageSize).ToArray(); + + [DataRow] + [DataRow(2500, 1000, 1000, 1000, 1000)] + [TestMethod] + public async Task Requests(params int[] pageDelays) + { + var source = new DataSource(AllData, pageDelays.Select(DataSource.MakeDelayOp)); + var collection = new IncrementalLoadingCollection, int>(source, PageSize); + + for (int pageNum = 1; pageNum <= Pages; pageNum++) + { + await collection.LoadMoreItemsAsync(0); + CollectionAssert.AreEqual(Enumerable.Range(0, PageSize * pageNum).ToArray(), collection); + } + } + + [DataRow] + [DataRow(2500, 1000, 1000, 1000, 1000)] + [TestMethod] + public async Task RequestsAsync(params int[] pageDelays) + { + var source = new DataSource(AllData, pageDelays.Select(DataSource.MakeDelayOp)); + var collection = new IncrementalLoadingCollection, int>(source, PageSize); + + var requests = new List(); + + for (int pageNum = 1; pageNum <= Pages; pageNum++) + { + requests.Add(collection.LoadMoreItemsAsync(0).AsTask() + .ContinueWith(t => Assert.IsTrue(t.IsCompletedSuccessfully))); + } + + await Task.WhenAll(requests); + + CollectionAssert.AreEqual(AllData, collection); + } + + [TestMethod] + public async Task FirstRequestFails() + { + var source = new DataSource(AllData, DataSource.ThrowException); + var collection = new IncrementalLoadingCollection, int>(source, PageSize); + + await Assert.ThrowsExceptionAsync(collection.LoadMoreItemsAsync(0).AsTask); + + Assert.IsTrue(!collection.Any()); + + var requests = new List(); + + for (int pageNum = 1; pageNum <= Pages; pageNum++) + { + requests.Add(collection.LoadMoreItemsAsync(0).AsTask() + .ContinueWith(t => Assert.IsTrue(t.IsCompletedSuccessfully))); + } + + await Task.WhenAll(requests); + + CollectionAssert.AreEqual(AllData, collection); + } + + [TestMethod] + public async Task EveryOtherRequestFails() + { + var source = new DataSource(AllData, FailPassSequence); + var collection = new IncrementalLoadingCollection, int>(source, PageSize); + + var willFail = true; + for (int submitedRequests = 0; submitedRequests < Pages * 2; submitedRequests++) + { + if (willFail) + { + await Assert.ThrowsExceptionAsync(collection.LoadMoreItemsAsync(0).AsTask); + } + else + { + await collection.LoadMoreItemsAsync(0); + } + + willFail = !willFail; + } + + CollectionAssert.AreEqual(AllData, collection); + } + + [TestMethod] + public async Task EveryOtherRequestFailsAsync() + { + var source = new DataSource(AllData, FailPassSequence); + var collection = new IncrementalLoadingCollection, int>(source, PageSize); + + var requests = new List(); + + var willFail = true; + for (int submitedRequests = 0; submitedRequests < Pages * 2; submitedRequests++) + { + if (willFail) + { + requests.Add(Assert.ThrowsExceptionAsync(collection.LoadMoreItemsAsync(0).AsTask)); + } + else + { + requests.Add(collection.LoadMoreItemsAsync(0).AsTask().ContinueWith(t => Assert.IsTrue(t.IsCompletedSuccessfully))); + } + + willFail = !willFail; + } + + await Task.WhenAll(requests); + + CollectionAssert.AreEqual(AllData, collection); + } + } +} \ No newline at end of file diff --git a/UnitTests/UnitTests.UWP/UnitTests.UWP.csproj b/UnitTests/UnitTests.UWP/UnitTests.UWP.csproj index b9631e7cff8..59f18e8b106 100644 --- a/UnitTests/UnitTests.UWP/UnitTests.UWP.csproj +++ b/UnitTests/UnitTests.UWP/UnitTests.UWP.csproj @@ -197,6 +197,8 @@ + + @@ -548,4 +550,4 @@ --> - + \ No newline at end of file