Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 70 additions & 72 deletions crates/csharp/src/AsyncSupport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,13 @@ public enum CallbackCode : uint
//#define TEST_CALLBACK_CODE_WAIT(set) (2 | (set << 4))
}

public class WaitableSet(int handle) : IDisposable
// The context that we will create in unmanaged memory and pass to context_set.
// TODO: C has world specific types for these pointers, perhaps C# would benefit from those also.
[StructLayout(LayoutKind.Sequential)]
public struct ContextTask
{
public int Handle { get; } = handle;

void Dispose(bool _disposing)
{
AsyncSupport.WaitableSetDrop(handle);
}

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

~WaitableSet()
{
Dispose(false);
}
public int WaitableSetHandle;
public int FutureHandle;
}

public static class AsyncSupport
Expand All @@ -51,9 +39,6 @@ internal static class PollWasmInterop
internal static extern void wasmImportPoll(nint p0, int p1, nint p2);
}

// TODO: How do we allow multiple waitable sets?
internal static WaitableSet WaitableSet;

private static class Interop
{
[global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[waitable-set-new]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute]
Expand All @@ -69,7 +54,7 @@ private static class Interop
internal static unsafe extern uint WaitableSetPoll(int waitable, uint* waitableHandlePtr);

[global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[waitable-set-drop]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute]
internal static unsafe extern void WaitableSetDrop(int waitable);
internal static extern void WaitableSetDrop(int waitable);

[global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[context-set-0]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute]
internal static unsafe extern void ContextSet(ContextTask* waitable);
Expand All @@ -78,13 +63,14 @@ private static class Interop
internal static unsafe extern ContextTask* ContextGet();
}

public static WaitableSet WaitableSetNew()
public static int WaitableSetNew()
{
var waitableSet = Interop.WaitableSetNew();
Console.WriteLine($"WaitableSet created with number {waitableSet}");
return new WaitableSet(waitableSet);
return waitableSet;
}

// unsafe because we are using pointers.
public static unsafe void WaitableSetPoll(int waitableHandle)
{
var error = Interop.WaitableSetPoll(waitableHandle, null);
Expand All @@ -94,16 +80,16 @@ public static unsafe void WaitableSetPoll(int waitableHandle)
}
}

internal static void Join(SubtaskStatus subtask, WaitableSet set, WaitableInfoState waitableInfoState)
internal static void Join(SubtaskStatus subtask, int waitableSetHandle, WaitableInfoState waitableInfoState)
{
AddTaskToWaitables(set.Handle, subtask.Handle, waitableInfoState);
Interop.WaitableJoin(subtask.Handle, set.Handle);
AddTaskToWaitables(waitableSetHandle, subtask.Handle, waitableInfoState);
Interop.WaitableJoin(subtask.Handle, waitableSetHandle);
}

internal static void Join(int readerWriterHandle, WaitableSet set, WaitableInfoState waitableInfoState)
internal static void Join(int readerWriterHandle, int waitableHandle, WaitableInfoState waitableInfoState)
{
AddTaskToWaitables(set.Handle, readerWriterHandle, waitableInfoState);
Interop.WaitableJoin(readerWriterHandle, set.Handle);
AddTaskToWaitables(waitableHandle, readerWriterHandle, waitableInfoState);
Interop.WaitableJoin(readerWriterHandle, waitableHandle);
}

// TODO: Revisit this to see if we can remove it.
Expand All @@ -120,10 +106,11 @@ private static void AddTaskToWaitables(int waitableSetHandle, int waitableHandle
waitableSetOfTasks[waitableHandle] = waitableInfoState;
}

public unsafe static EventWaitable WaitableSetWait(WaitableSet set)
// unsafe because we use a fixed size buffer.
public static unsafe EventWaitable WaitableSetWait(int waitableSetHandle)
{
uint* buffer = stackalloc uint[2];
var eventCode = (EventCode)Interop.WaitableSetWait(set.Handle, buffer);
var eventCode = (EventCode)Interop.WaitableSetWait(waitableSetHandle, buffer);
return new EventWaitable(eventCode, buffer[0], buffer[1]);
}

Expand All @@ -132,34 +119,25 @@ public static void WaitableSetDrop(int handle)
Interop.WaitableSetDrop(handle);
}

// The context that we will create in unmanaged memory and pass to context_set.
// TODO: C has world specific types for these pointers, perhaps C# would benefit from those also.
[StructLayout(LayoutKind.Sequential)]
public struct ContextTask
{
public int Set;
public int FutureHandle;
}

// unsafe because we are using pointers.
public static unsafe void ContextSet(ContextTask* contextTask)
{
Interop.ContextSet(contextTask);
}

// unsafe because we are using pointers.
public static unsafe ContextTask* ContextGet()
{
ContextTask* contextTaskPtr = Interop.ContextGet();
if(contextTaskPtr == null)
{
throw new Exception("null context returned.");
}
return contextTaskPtr;
return Interop.ContextGet();
}

// unsafe because we are using pointers.
public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Action taskReturn)
{
Console.WriteLine($"Callback Event code {e.EventCode} Code {e.Code} Waitable {e.Waitable} Waitable Status {e.WaitableStatus.State}, Count {e.WaitableCount}");
var waitables = pendingTasks[WaitableSet.Handle];
ContextTask* contextTaskPtr = ContextGet();

var waitables = pendingTasks[contextTaskPtr->WaitableSetHandle];
var waitableInfoState = waitables[e.Waitable];

if (e.IsDropped)
Expand Down Expand Up @@ -195,32 +173,36 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr, Act

if (waitables.Count == 0)
{
Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {WaitableSet.Handle}");
Console.WriteLine($"No more waitables for waitable {e.Waitable} in set {contextTaskPtr->WaitableSetHandle}");
taskReturn();
ContextSet(null);
Marshal.FreeHGlobal((IntPtr)contextTaskPtr);
return (uint)CallbackCode.Exit;
}

Console.WriteLine("More waitables in the set.");
return (uint)CallbackCode.Wait | (uint)(WaitableSet.Handle << 4);
return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4);
}

throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {WaitableSet.Handle}");
throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {contextTaskPtr->WaitableSetHandle}");
}

public static Task TaskFromStatus(uint status)
// This method is unsafe because we are using unmanaged memory to store the context.
internal static unsafe Task TaskFromStatus(uint status)
{
var subtaskStatus = new SubtaskStatus(status);
status = status & 0xF;

if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted)
{
if (WaitableSet == null) {
WaitableSet = WaitableSetNew();
Console.WriteLine($"TaskFromStatus creating WaitableSet {WaitableSet.Handle}");
ContextTask* contextTaskPtr = ContextGet();
if (contextTaskPtr == null) {
contextTaskPtr = AllocateAndSetNewContext();
Console.WriteLine($"TaskFromStatus creating WaitableSet {contextTaskPtr->WaitableSetHandle}");
}

TaskCompletionSource tcs = new TaskCompletionSource();
AsyncSupport.Join(subtaskStatus, WaitableSet, new WaitableInfoState(tcs));
Join(subtaskStatus, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs));
return tcs.Task;
}
else if (subtaskStatus.IsSubtaskReturned)
Expand All @@ -233,7 +215,8 @@ public static Task TaskFromStatus(uint status)
}
}

public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
// unsafe because we are using pointers.
public static unsafe Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What makes this function unsafe to call? Could it be documented?

{
var subtaskStatus = new SubtaskStatus(status);
status = status & 0xF;
Expand All @@ -242,9 +225,12 @@ public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
var tcs = new TaskCompletionSource<T>();
if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted)
{
if (WaitableSet == null) {
ContextTask* contextTaskPtr = ContextGet();
if (contextTaskPtr == null) {
contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());
Console.WriteLine("TaskFromStatus<T> creating WaitableSet");
WaitableSet = AsyncSupport.WaitableSetNew();
contextTaskPtr->WaitableSetHandle = WaitableSetNew();
ContextSet(contextTaskPtr);
}

return tcs.Task;
Expand All @@ -259,6 +245,15 @@ public static Task<T> TaskFromStatus<T>(uint status, Func<T> liftFunc)
throw new Exception($"unexpected subtask status: {status}");
}
}

// unsafe because we are working with native memory.
internal static unsafe ContextTask* AllocateAndSetNewContext()
{
var contextTaskPtr = (ContextTask *)Marshal.AllocHGlobal(Marshal.SizeOf<ContextTask>());
contextTaskPtr->WaitableSetHandle = AsyncSupport.WaitableSetNew();
AsyncSupport.ContextSet(contextTaskPtr);
return contextTaskPtr;
}
}


Expand Down Expand Up @@ -371,6 +366,7 @@ internal int TakeHandle()

internal abstract uint VTableRead(IntPtr bufferPtr, int length);

// unsafe as we are working with pointers.
internal unsafe Task<int> ReadInternal(Func<GCHandle?> liftBuffer, int length)
{
if (Handle == 0)
Expand All @@ -389,14 +385,15 @@ internal unsafe Task<int> ReadInternal(Func<GCHandle?> liftBuffer, int length)
{
Console.WriteLine("Read Blocked");
var tcs = new TaskCompletionSource<int>();
if(AsyncSupport.WaitableSet == null)
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
if(contextTaskPtr == null)
{
Console.WriteLine("FutureReader Read Blocked creating WaitableSet");
AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew();
contextTaskPtr = AsyncSupport.AllocateAndSetNewContext();
}
Console.WriteLine("blocked read before join");

AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this));
AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this));
Console.WriteLine("blocked read after join");
return tcs.Task;
}
Expand Down Expand Up @@ -470,7 +467,7 @@ public class FutureReader<T>(int handle, FutureVTable vTable) : ReaderBase(handl
{
public FutureVTable VTable { get; private set; } = vTable;

private GCHandle LiftBuffer<T>(T buffer)
private GCHandle LiftBuffer(T buffer)
{
if(typeof(T) == typeof(byte))
{
Expand All @@ -483,7 +480,7 @@ private GCHandle LiftBuffer<T>(T buffer)
}
}

public unsafe Task Read<T>(T buffer)
public Task Read(T buffer)
{
return ReadInternal(() => LiftBuffer(buffer), 1);
}
Expand All @@ -508,7 +505,7 @@ public StreamReader(int handle, StreamVTable vTable) : base(handle)

public StreamVTable VTable { get; private set; }

public unsafe Task Read(int length)
public Task Read(int length)
{
return ReadInternal(() => null, length);
}
Expand All @@ -528,7 +525,7 @@ public class StreamReader<T>(int handle, StreamVTable vTable) : ReaderBase(hand
{
public StreamVTable VTable { get; private set; } = vTable;

private GCHandle LiftBuffer<T>(T[] buffer)
private GCHandle LiftBuffer(T[] buffer)
{
if(typeof(T) == typeof(byte))
{
Expand All @@ -541,7 +538,7 @@ private GCHandle LiftBuffer<T>(T[] buffer)
}
}

public unsafe Task<int> Read<T>(T[] buffer)
public Task<int> Read(T[] buffer)
{
return ReadInternal(() => LiftBuffer(buffer), buffer.Length);
}
Expand Down Expand Up @@ -582,6 +579,7 @@ internal int TakeHandle()

internal abstract uint VTableWrite(IntPtr bufferPtr, int length);

// unsafe as we are working with pointers.
internal unsafe Task<int> WriteInternal(Func<GCHandle?> lowerPayload, int length)
{
if (Handle == 0)
Expand All @@ -600,12 +598,13 @@ internal unsafe Task<int> WriteInternal(Func<GCHandle?> lowerPayload, int length
{
Console.WriteLine("blocked write");
var tcs = new TaskCompletionSource<int>();
if(AsyncSupport.WaitableSet == null)
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
if(contextTaskPtr == null)
{
AsyncSupport.WaitableSet = AsyncSupport.WaitableSetNew();
contextTaskPtr = AsyncSupport.AllocateAndSetNewContext();
}
Console.WriteLine("blocked write before join");
AsyncSupport.Join(Handle, AsyncSupport.WaitableSet, new WaitableInfoState(tcs, this));
AsyncSupport.Join(Handle, contextTaskPtr->WaitableSetHandle, new WaitableInfoState(tcs, this));
Console.WriteLine("blocked write after join");
return tcs.Task;
}
Expand Down Expand Up @@ -679,7 +678,6 @@ public class FutureWriter<T>(int handle, FutureVTable vTable) : WriterBase(handl
// TODO: Generate per type for this instrinsic.
public Task Write()
{
// TODO: Lower T
return WriteInternal(() => null, 1);
}

Expand Down Expand Up @@ -719,7 +717,7 @@ public class StreamWriter<T>(int handle, StreamVTable vTable) : WriterBase(handl
private GCHandle bufferHandle;
public StreamVTable VTable { get; private set; } = vTable;

private GCHandle LowerPayload<T>(T[] payload)
private GCHandle LowerPayload(T[] payload)
{
if (VTable.Lower == null)
{
Expand Down
3 changes: 2 additions & 1 deletion crates/csharp/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,8 @@ impl Bindgen for FunctionBindgen<'_, '_> {
}});

// TODO: Defer dropping borrowed resources until a result is returned.
return (uint)CallbackCode.Wait | (uint)(AsyncSupport.WaitableSet.Handle << 4);
ContextTask* contextTaskPtr = AsyncSupport.ContextGet();
return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4);
"#);
}

Expand Down
2 changes: 1 addition & 1 deletion crates/csharp/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ var {async_status_var} = {raw_name}({wasm_params});
uwriteln!(
self.csharp_interop_src,
r#"
return (uint)AsyncSupport.Callback(e, (AsyncSupport.ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn());
return (uint)AsyncSupport.Callback(e, (ContextTask *)IntPtr.Zero, () => {camel_name}TaskReturn());
}}
"#
);
Expand Down
Loading