Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/preview2-shim/lib/io/calls.js
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export const POLL_POLL_LIST = ++call_id << CALL_SHIFT;

// Futures
export const FUTURE_DISPOSE = ++call_id << CALL_SHIFT;
export const FUTURE_GET_VALUE_AND_DISPOSE = ++call_id << CALL_SHIFT;
export const FUTURE_TAKE_VALUE = ++call_id << CALL_SHIFT;
export const FUTURE_SUBSCRIBE = ++call_id << CALL_SHIFT;

// Http
Expand Down Expand Up @@ -102,7 +102,7 @@ export const SOCKET_UDP_SET_SEND_BUFFER_SIZE = ++call_id << CALL_SHIFT;
export const SOCKET_UDP_SET_UNICAST_HOP_LIMIT = ++call_id << CALL_SHIFT;
// Name lookup
export const SOCKET_RESOLVE_ADDRESS_CREATE_REQUEST = ++call_id << CALL_SHIFT;
export const SOCKET_RESOLVE_ADDRESS_GET_AND_DISPOSE_REQUEST = ++call_id << CALL_SHIFT;
export const SOCKET_RESOLVE_ADDRESS_TAKE_REQUEST = ++call_id << CALL_SHIFT;
export const SOCKET_RESOLVE_ADDRESS_SUBSCRIBE_REQUEST = ++call_id << CALL_SHIFT;
export const SOCKET_RESOLVE_ADDRESS_DISPOSE_REQUEST = ++call_id << CALL_SHIFT;

Expand Down
259 changes: 136 additions & 123 deletions packages/preview2-shim/lib/io/worker-socket-tcp.js
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import {
createFuture,
createPoll,
createReadableStream,
createReadableStreamPollState,
createWritableStream,
futureDispose,
futureTakeValue,
pollStateReady,
pollStateWait,
verifyPollsDroppedForDrop,
} from "./worker-thread.js";
// See: https://github.com/nodejs/node/blob/main/src/tcp_wrap.cc
const { TCP, constants: TCPConstants } = process.binding("tcp_wrap");
import {
deserializeIpAddress,
Expand All @@ -33,14 +34,14 @@ const globalBoundAddresses = new Set();
const isWindows = platform() === "win32";

let stateCnt = 0;
const SOCKET_STATE_INIT = ++stateCnt;
const SOCKET_STATE_BIND = ++stateCnt;
const SOCKET_STATE_BOUND = ++stateCnt;
const SOCKET_STATE_LISTEN = ++stateCnt;
const SOCKET_STATE_LISTENER = ++stateCnt;
const SOCKET_STATE_CONNECT = ++stateCnt;
const SOCKET_STATE_CONNECTION = ++stateCnt;
const SOCKET_STATE_ERROR = ++stateCnt;
export const SOCKET_STATE_INIT = ++stateCnt;
export const SOCKET_STATE_BIND = ++stateCnt;
export const SOCKET_STATE_BOUND = ++stateCnt;
export const SOCKET_STATE_LISTEN = ++stateCnt;
export const SOCKET_STATE_LISTENER = ++stateCnt;
export const SOCKET_STATE_CONNECT = ++stateCnt;
export const SOCKET_STATE_CONNECTION = ++stateCnt;
export const SOCKET_STATE_CLOSED = ++stateCnt;

/**
* @typedef {import("../../types/interfaces/wasi-sockets-network.js").IpSocketAddress} IpSocketAddress
Expand All @@ -54,7 +55,7 @@ const SOCKET_STATE_ERROR = ++stateCnt;
*
* @typedef {{
* state: number,
* bindOrConnectAddress: IpSocketAddress | null,
* future: number | null,
* serializedLocalAddress: string | null,
* listenBacklogSize: number,
* handle: TCP,
Expand All @@ -77,58 +78,73 @@ export function createTcpSocket() {
const handle = new TCP(TCPConstants.SOCKET);
tcpSockets.set(++tcpSocketCnt, {
state: SOCKET_STATE_INIT,
bindOrConnectAddress: null,
future: null,
serializedLocalAddress: null,
listenBacklogSize: 128,
handle,
pendingAccepts: [],
pollState: { ready: false, listener: null, polls: [] },
pollState: { ready: true, listener: null, polls: [], parentStream: null },
});
return tcpSocketCnt;
}

export function socketTcpSubscribe(id) {
const socket = tcpSockets.get(id);
return createPoll(socket.pollState);
return createPoll(tcpSockets.get(id).pollState);
}

export function socketTcpBindStart(id, localAddress) {
export function socketTcpFinish(id, fromState, toState) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_INIT) throw "invalid-state";
socket.state = SOCKET_STATE_BIND;
socket.bindOrConnectAddress = localAddress;
pollStateWait(socket.pollState);
if (socket.state !== fromState) throw "not-in-progress";
if (!socket.pollState.ready) throw "would-block";
const { tag, val } = futureTakeValue(socket.future).val;
futureDispose(socket.future, false);
socket.future = null;
if (tag === "err") {
socket.state = SOCKET_STATE_CLOSED;
throw val;
} else {
socket.state = toState;
// for the listener, we must immediately transition back to unresolved
if (toState === SOCKET_STATE_LISTENER)
socket.pollState.ready = false;
return val;
}
}

export function socketTcpBindFinish(id) {
export function socketTcpBindStart(id, localAddress, family) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_BIND) throw "not-in-progress";
if (socket.state !== SOCKET_STATE_INIT) throw "invalid-state";
if (family !== localAddress.tag || !isUnicastIpAddress(localAddress))
throw "invalid-argument";
if (isIPv4MappedAddress(localAddress)) throw "invalid-argument";
socket.state = SOCKET_STATE_BIND;
const { handle } = socket;
const address = serializeIpAddress(socket.bindOrConnectAddress);
if (isIPv4MappedAddress(socket.bindOrConnectAddress))
throw 'invalid-argument';
const port = socket.bindOrConnectAddress.val.port;
if (globalBoundAddresses.has(`${address}:${port}`)) throw "address-in-use";
const code =
socket.bindOrConnectAddress.tag === "ipv6"
? handle.bind6(address, port, TCPConstants.UV_TCP_IPV6ONLY)
: handle.bind(address, port);
if (code !== 0) {
socket.state = SOCKET_STATE_ERROR;
throw convertSocketErrorCode(-code);
}
const localAddress = socketTcpGetLocalAddress(id);
const serializedLocalAddress = `${serializeIpAddress(localAddress)}:${
localAddress.val.port
}`;
globalBoundAddresses.add(
(socket.serializedLocalAddress = serializedLocalAddress)
socket.future = createFuture(
(async () => {
const address = serializeIpAddress(localAddress);
const port = localAddress.val.port;
if (globalBoundAddresses.has(`${address}:${port}`))
throw "address-in-use";
const code =
localAddress.tag === "ipv6"
? handle.bind6(address, port, TCPConstants.UV_TCP_IPV6ONLY)
: handle.bind(address, port);
if (code !== 0) throw convertSocketErrorCode(-code);
{
const localAddress = socketTcpGetLocalAddress(id);
const serializedLocalAddress = `${serializeIpAddress(localAddress)}:${
localAddress.val.port
}`;
globalBoundAddresses.add(
(socket.serializedLocalAddress = serializedLocalAddress)
);
}
})(),
socket.pollState
);
socket.state = SOCKET_STATE_BOUND;
pollStateReady(socket.pollState, false);
}

export function socketTcpConnectStart(id, { remoteAddress, family }) {
export function socketTcpConnectStart(id, remoteAddress, family) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_INIT && socket.state !== SOCKET_STATE_BOUND)
throw "invalid-state";
Expand All @@ -142,64 +158,96 @@ export function socketTcpConnectStart(id, { remoteAddress, family }) {
) {
throw "invalid-argument";
}
if (isIPv4MappedAddress(remoteAddress)) throw "invalid-argument";
socket.state = SOCKET_STATE_CONNECT;
socket.bindOrConnectAddress = remoteAddress;
pollStateWait(socket.pollState);
socket.future = createFuture(
new Promise((resolve, reject) => {
const tcpSocket = new Socket({
handle: socket.handle,
pauseOnCreate: true,
allowHalfOpen: true,
});
function handleErr(err) {
tcpSocket.off("connect", handleConnect);
reject(err);
}
function handleConnect() {
tcpSocket.off("error", handleErr);
if (!tcpSocket.serializedLocalAddress) {
const localAddress = socketTcpGetLocalAddress(id);
const serializedLocalAddress = `${serializeIpAddress(localAddress)}:${
localAddress.val.port
}`;
globalBoundAddresses.add(
(tcpSocket.serializedLocalAddress = serializedLocalAddress)
);
}
resolve([
createReadableStream(tcpSocket),
createWritableStream(tcpSocket),
]);
}
tcpSocket.once("connect", handleConnect);
tcpSocket.once("error", handleErr);
tcpSocket.connect({
port: remoteAddress.val.port,
host: serializeIpAddress(remoteAddress),
lookup: () => {
throw "invalid-argument";
},
});
}),
socket.pollState
);
}

export function socketTcpConnectFinish(id) {
export function socketTcpListenStart(id) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_CONNECT) throw "not-in-progress";
const tcpSocket = new Socket({ handle: socket.handle, pauseOnCreate: true, allowHalfOpen: true });
const remoteAddress = socket.bindOrConnectAddress;
if (isIPv4MappedAddress(remoteAddress))
throw "invalid-argument";
return new Promise((resolve, reject) => {
function handleErr(err) {
tcpSocket.off("connect", handleConnect);
socket.state = SOCKET_STATE_ERROR;
pollStateReady(socket.pollState, false);
reject(err);
}
function handleConnect() {
tcpSocket.off("error", handleErr);
if (!tcpSocket.serializedLocalAddress) {
const localAddress = socketTcpGetLocalAddress(id);
const serializedLocalAddress = `${serializeIpAddress(localAddress)}:${
localAddress.val.port
}`;
globalBoundAddresses.add(
(tcpSocket.serializedLocalAddress = serializedLocalAddress)
);
if (socket.state !== SOCKET_STATE_BOUND) throw "invalid-state";
const { handle } = socket;
socket.state = SOCKET_STATE_LISTEN;
socket.future = createFuture(
new Promise((resolve, reject) => {
const server = new Server({ pauseOnConnect: true, allowHalfOpen: true });
function handleErr(err) {
server.off("listening", handleListen);
reject(err);
}
socket.state = SOCKET_STATE_CONNECTION;
pollStateReady(socket.pollState, false);
resolve([
createReadableStream(tcpSocket),
createWritableStream(tcpSocket),
]);
}
tcpSocket.once("connect", handleConnect);
tcpSocket.once("error", handleErr);
tcpSocket.connect({
port: remoteAddress.val.port,
host: serializeIpAddress(remoteAddress),
lookup: () => {
throw "invalid-argument";
},
});
});
function handleListen() {
server.off("error", handleErr);
server.on("connection", (tcpSocket) => {
pollStateReady(socket.pollState);
const pollState = createReadableStreamPollState(tcpSocket);
socket.pendingAccepts.push({ tcpSocket, err: null, pollState });
});
server.on("error", (err) => {
pollStateReady(socket.pollState);
socket.pendingAccepts.push({ tcpSocket: null, err, pollState: null });
});
resolve();
}
server.once("listening", handleListen);
server.once("error", handleErr);
server.listen(handle, socket.listenBacklogSize);
}),
socket.pollState
);
}

export function socketTcpAccept(id) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_LISTENER) throw "invalid-state";
if (socket.pendingAccepts.length === 0) throw "would-block";
const accept = socket.pendingAccepts.shift();
if (accept.err) throw convertSocketError(accept.err);
if (accept.err) {
socket.state = SOCKET_STATE_CLOSED;
throw convertSocketError(accept.err);
}
if (socket.pendingAccepts.length === 0)
socket.pollState.ready = false;
tcpSockets.set(++tcpSocketCnt, {
state: SOCKET_STATE_CONNECTION,
bindOrConnectAddress: null,
future: null,
serializedLocalAddress: null,
listenBacklogSize: 128,
handle: accept.tcpSocket._handle,
Expand All @@ -213,41 +261,6 @@ export function socketTcpAccept(id) {
];
}

export function socketTcpListenStart(id) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_BOUND) throw "invalid-state";
socket.state = SOCKET_STATE_LISTEN;
}

export function socketTcpListenFinish(id, backlogSize) {
const socket = tcpSockets.get(id);
if (socket.state !== SOCKET_STATE_LISTEN) throw "not-in-progress";
const { handle } = socket;
const server = new Server({ pauseOnConnect: true, allowHalfOpen: true });
return new Promise((resolve, reject) => {
function handleErr(err) {
server.off("listening", handleListen);
socket.state = SOCKET_STATE_ERROR;
reject(err);
}
function handleListen() {
server.off("error", handleErr);
server.on("connection", (tcpSocket) => {
const pollState = createReadableStreamPollState(tcpSocket);
socket.pendingAccepts.push({ tcpSocket, err: null, pollState });
});
server.on("error", (err) => {
socket.pendingAccepts.push({ tcpSocket: null, err, pollState: null });
});
socket.state = SOCKET_STATE_LISTENER;
resolve();
}
server.once("listening", handleListen);
server.once("error", handleErr);
server.listen(handle, backlogSize);
});
}

export function socketTcpIsListening(id) {
return tcpSockets.get(id).state === SOCKET_STATE_LISTENER;
}
Expand Down
Loading