diff --git a/src/libs/Pusher/index.native.ts b/src/libs/Pusher/index.native.ts index 72cd06774c383..c9194b8c5bf36 100644 --- a/src/libs/Pusher/index.native.ts +++ b/src/libs/Pusher/index.native.ts @@ -10,7 +10,7 @@ import {authenticatePusher} from '@userActions/Session'; import CONST from '@src/CONST'; import ONYXKEYS from '@src/ONYXKEYS'; import TYPE from './EventType'; -import type {Args, ChunkedDataEvents, EventCallbackError, EventData, PusherEventName, SocketEventCallback, SocketEventName, States} from './types'; +import type {Args, ChunkedDataEvents, EventCallbackError, EventData, PusherEventName, PusherSubscription, SocketEventCallback, SocketEventName, States} from './types'; import type PusherModule from './types'; let shouldForceOffline = false; @@ -34,7 +34,9 @@ let initPromise = new Promise((resolve) => { resolveInitPromise = resolve; }); -const eventsBoundToChannels = new Map) => void>>(); +type BoundCallback = (eventData: EventData) => void; + +const eventsBoundToChannels = new Map>>(); let channels: Record> = {}; /** @@ -126,11 +128,16 @@ function parseEventData(eventData: EventData< } /** - * Binds an event callback to a channel + eventName + * Binds an event callback to a channel + eventName. + * Returns the wrapped callback so it can be individually unbound later. */ -function bindEventToChannel(channel: string, eventName?: EventName, eventCallback: (data: EventData) => void = () => {}) { +function bindEventToChannel( + channel: string, + eventName?: EventName, + eventCallback: (data: EventData) => void = () => {}, +): BoundCallback | undefined { if (!eventName) { - return; + return undefined; } const chunkedDataEvents: Record = {}; @@ -192,24 +199,40 @@ function bindEventToChannel(channel: string, if (!eventsBoundToChannels.has(channel)) { eventsBoundToChannels.set(channel, new Map()); } + const eventMap = eventsBoundToChannels.get(channel); + if (!eventMap?.has(eventName)) { + eventMap?.set(eventName, new Set()); + } + const boundCb = callback as BoundCallback; + eventMap?.get(eventName)?.add(boundCb); - eventsBoundToChannels.get(channel)?.set(eventName, callback as (eventData: EventData) => void); + return boundCb; } /** - * Subscribe to a channel and an event + * Subscribe to a channel and an event. + * Returns a PusherSubscription — a Promise (for backward-compatible .catch()/.then()) + * with an .unsubscribe() method that removes only this specific callback. */ function subscribe( channelName: string, eventName?: EventName, eventCallback: (data: EventData) => void = () => {}, onResubscribe = () => {}, -): Promise { - return initPromise.then( +): PusherSubscription { + let wrappedCb: BoundCallback | undefined; + let disposed = false; + + const promise = initPromise.then( () => - new Promise((resolve, reject) => { + new Promise((resolve, reject) => { // eslint-disable-next-line @typescript-eslint/no-deprecated InteractionManager.runAfterInteractions(() => { + if (disposed) { + resolve(); + return; + } + // We cannot call subscribe() before init(). Prevent any attempt to do this on dev. if (!socket) { const error = new Error('[Pusher] instance not found. Pusher.subscribe() most likely has been called before Pusher.init()'); @@ -237,12 +260,27 @@ function subscribe( socket.subscribe({ channelName, onEvent: (event) => { - const callback = eventsBoundToChannels.get(event.channelName)?.get(event.eventName); - callback?.(event.data as EventData); + const callbacks = eventsBoundToChannels.get(event.channelName)?.get(event.eventName); + if (callbacks) { + for (const cb of callbacks) { + cb(event.data as EventData); + } + } }, onSubscriptionSucceeded: () => { channels[channelName] = CONST.PUSHER.CHANNEL_STATUS.SUBSCRIBED; - bindEventToChannel(channelName, eventName, eventCallback); + if (!disposed) { + wrappedCb = bindEventToChannel(channelName, eventName, eventCallback); + } else { + // Handle was disposed mid-handshake — clean up the channel + // if no other subscribers have bound callbacks to it + const eventMap = eventsBoundToChannels.get(channelName); + if (!eventMap || eventMap.size === 0) { + eventsBoundToChannels.delete(channelName); + delete channels[channelName]; + socket?.unsubscribe({channelName}); + } + } resolve(); // When subscribing for the first time we register a success callback that can be // called multiple times when the subscription succeeds again in the future @@ -260,16 +298,48 @@ function subscribe( }, }); } else { - bindEventToChannel(channelName, eventName, eventCallback); + if (!disposed) { + wrappedCb = bindEventToChannel(channelName, eventName, eventCallback); + } resolve(); } }); }), ); + + return Object.assign(promise, { + unsubscribe: () => { + disposed = true; + if (!wrappedCb || !eventName) { + return; + } + + // 1. Remove this specific callback from tracking + const eventMap = eventsBoundToChannels.get(channelName); + const callbacks = eventMap?.get(eventName); + callbacks?.delete(wrappedCb); + + // 2. If last callback for this event, remove the event + if (callbacks?.size === 0) { + eventMap?.delete(eventName); + } + + // 3. If last event on this channel, unsubscribe entirely + if (eventMap?.size === 0) { + eventsBoundToChannels.delete(channelName); + delete channels[channelName]; + socket?.unsubscribe({channelName}); + } + + wrappedCb = undefined; + }, + }); } /** - * Unsubscribe from a channel and optionally a specific event + * Unsubscribe from a channel and optionally a specific event. + * This removes ALL callbacks for the given event (or all events on the channel). + * For per-callback removal, use the .unsubscribe() method on the PusherSubscription handle. */ function unsubscribe(channelName: string, eventName: PusherEventName = '') { const channel = getChannel(channelName); diff --git a/src/libs/Pusher/index.ts b/src/libs/Pusher/index.ts index 827a3885a136f..173a74e777240 100644 --- a/src/libs/Pusher/index.ts +++ b/src/libs/Pusher/index.ts @@ -13,6 +13,7 @@ import type { EventCallbackError, EventData, PusherEventName, + PusherSubscription, PusherSubscriptionErrorData, PusherWithAuthParams, SocketEventCallback, @@ -45,7 +46,10 @@ let initPromise = new Promise((resolve) => { resolveInitPromise = resolve; }); -const eventsBoundToChannels = new Map>(); +// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Pusher callbacks have varying signatures due to chunking wrapper +type BoundCallback = (eventData: any) => void; + +const eventsBoundToChannels = new Map>>(); /** * Trigger each of the socket event callbacks with the event information @@ -118,11 +122,16 @@ function getChannel(channelName: string): Channel | undefined { } /** - * Binds an event callback to a channel + eventName + * Binds an event callback to a channel + eventName. + * Returns the wrapped callback so it can be individually unbound later. */ -function bindEventToChannel(channel: Channel | undefined, eventName?: EventName, eventCallback: (data: EventData) => void = () => {}) { +function bindEventToChannel( + channel: Channel | undefined, + eventName?: EventName, + eventCallback: (data: EventData) => void = () => {}, +): BoundCallback | undefined { if (!eventName || !channel) { - return; + return undefined; } const chunkedDataEvents: Record = {}; @@ -184,14 +193,23 @@ function bindEventToChannel(channel: Channel }; channel.bind(eventName, callback); + if (!eventsBoundToChannels.has(channel)) { - eventsBoundToChannels.set(channel, new Set()); + eventsBoundToChannels.set(channel, new Map()); + } + const eventMap = eventsBoundToChannels.get(channel); + if (!eventMap?.has(eventName)) { + eventMap?.set(eventName, new Set()); } - eventsBoundToChannels.get(channel)?.add(eventName); + eventMap?.get(eventName)?.add(callback); + + return callback; } /** - * Subscribe to a channel and an event + * Subscribe to a channel and an event. + * Returns a PusherSubscription — a Promise (for backward-compatible .catch()/.then()) + * with an .unsubscribe() method that removes only this specific callback. * @param [onResubscribe] Callback to be called when reconnection happen */ function subscribe( @@ -199,12 +217,21 @@ function subscribe( eventName?: EventName, eventCallback: (data: EventData) => void = () => {}, onResubscribe = () => {}, -): Promise { - return initPromise.then( +): PusherSubscription { + let wrappedCb: BoundCallback | undefined; + let resolvedChannel: Channel | undefined; + let disposed = false; + + const promise = initPromise.then( () => - new Promise((resolve, reject) => { + new Promise((resolve, reject) => { // eslint-disable-next-line @typescript-eslint/no-deprecated InteractionManager.runAfterInteractions(() => { + if (disposed) { + resolve(); + return; + } + // We cannot call subscribe() before init(). Prevent any attempt to do this on dev. if (!socket) { const error = new Error('[Pusher] instance not found. Pusher.subscribe() most likely has been called before Pusher.init()'); @@ -234,7 +261,18 @@ function subscribe( channel.bind('pusher:subscription_succeeded', () => { // Check so that we do not bind another event with each reconnect attempt if (!isBound) { - bindEventToChannel(channel, eventName, eventCallback); + if (!disposed) { + wrappedCb = bindEventToChannel(channel, eventName, eventCallback); + resolvedChannel = channel ?? undefined; + } else if (channel) { + // Handle was disposed mid-handshake — clean up the channel + // if no other subscribers have bound callbacks to it + const eventMap = eventsBoundToChannels.get(channel); + if (!eventMap || eventMap.size === 0) { + eventsBoundToChannels.delete(channel); + socket?.unsubscribe(channelName); + } + } resolve(); isBound = true; return; @@ -258,16 +296,52 @@ function subscribe( reject(error); }); } else { - bindEventToChannel(channel, eventName, eventCallback); + if (!disposed) { + wrappedCb = bindEventToChannel(channel, eventName, eventCallback); + resolvedChannel = channel; + } resolve(); } }); }), ); + + return Object.assign(promise, { + unsubscribe: () => { + disposed = true; + if (!wrappedCb || !resolvedChannel || !eventName) { + return; + } + + // 1. Unbind this specific callback from pusher-js + resolvedChannel.unbind(eventName, wrappedCb); + + // 2. Remove from tracking + const eventMap = eventsBoundToChannels.get(resolvedChannel); + const callbacks = eventMap?.get(eventName); + callbacks?.delete(wrappedCb); + + // 3. If last callback for this event, remove the event + if (callbacks?.size === 0) { + eventMap?.delete(eventName); + } + + // 4. If last event on this channel, unsubscribe entirely + if (eventMap?.size === 0) { + eventsBoundToChannels.delete(resolvedChannel); + socket?.unsubscribe(channelName); + } + + wrappedCb = undefined; + resolvedChannel = undefined; + }, + }); } /** - * Unsubscribe from a channel and optionally a specific event + * Unsubscribe from a channel and optionally a specific event. + * This removes ALL callbacks for the given event (or all events on the channel). + * For per-callback removal, use the .unsubscribe() method on the PusherSubscription handle. */ function unsubscribe(channelName: string, eventName: PusherEventName = '') { const channel = getChannel(channelName); @@ -294,6 +368,7 @@ function unsubscribe(channelName: string, eventName: PusherEventName = '') { Log.info('[Pusher] Unsubscribing from channel', false, {channelName}); channel.unbind(); + eventsBoundToChannels.delete(channel); socket?.unsubscribe(channelName); } } @@ -369,6 +444,7 @@ function disconnect() { socket.disconnect(); socket = null; pusherSocketID = ''; + eventsBoundToChannels.clear(); initPromise = new Promise((resolve) => { resolveInitPromise = resolve; }); diff --git a/src/libs/Pusher/types.ts b/src/libs/Pusher/types.ts index c136e8e27b049..ad9fa1d4f154e 100644 --- a/src/libs/Pusher/types.ts +++ b/src/libs/Pusher/types.ts @@ -68,6 +68,10 @@ type PusherEventName = LiteralUnion, string>; type PusherSubscriptionErrorData = {type?: string; error?: string; status?: string}; +type PusherSubscription = Promise & { + unsubscribe: () => void; +}; + type PusherModule = { init: (args: Args) => Promise; subscribe: ( @@ -75,7 +79,7 @@ type PusherModule = { eventName?: EventName, eventCallback?: (data: EventData) => void, onResubscribe?: () => void, - ) => Promise; + ) => PusherSubscription; unsubscribe: (channelName: string, eventName?: PusherEventName) => void; getChannel: (channelName: string) => Channel | PusherChannel | undefined; isSubscribed: (channelName: string) => boolean; @@ -105,5 +109,6 @@ export type { SocketEventCallback, PusherWithAuthParams, PusherEventName, + PusherSubscription, PusherSubscriptionErrorData, }; diff --git a/tests/unit/PusherSubscribeTest.ts b/tests/unit/PusherSubscribeTest.ts index a63114ce67caa..5409658637bc8 100644 --- a/tests/unit/PusherSubscribeTest.ts +++ b/tests/unit/PusherSubscribeTest.ts @@ -2,6 +2,8 @@ import Log from '@libs/Log'; import Pusher from '@libs/Pusher'; import CONFIG from '@src/CONFIG'; import PusherConnectionManager from '@src/libs/PusherConnectionManager'; +// eslint-disable-next-line import/no-relative-packages -- Import mock class directly for proper typing +import {Pusher as MockedPusher} from '../../__mocks__/@pusher/pusher-websocket-react-native/index'; /** * Tests for Pusher.subscribe() graceful handling when socket is disconnected @@ -119,3 +121,217 @@ describe('Pusher.subscribe', () => { await expect(subscribePromise).resolves.toBeUndefined(); }); }); + +describe('Per-callback subscription handles', () => { + const CHANNEL = 'private-user-callback'; + const EVENT = 'testEvent'; + + beforeEach(async () => { + jest.spyOn(Pusher, 'isSubscribed').mockReturnValue(false); + jest.spyOn(Pusher, 'isAlreadySubscribing').mockReturnValue(false); + await initPusher(); + }); + + afterEach(() => { + Pusher.disconnect(); + jest.restoreAllMocks(); + }); + + function triggerEvent(channelName: string, eventName: string, data: Record = {value: 1}) { + // Fire events through the mock socket's trigger, which invokes the Pusher module's + // onEvent dispatcher that iterates over eventsBoundToChannels. + MockedPusher.getInstance().trigger({channelName, eventName, data}); + } + + it('should return a PusherSubscription with an unsubscribe method', async () => { + const handle = Pusher.subscribe(CHANNEL, EVENT, () => {}); + await jest.runAllTimersAsync(); + await handle; + + expect(typeof handle.unsubscribe).toBe('function'); + }); + + it('should deliver events to both subscribers on the same channel+event', async () => { + const callbackA = jest.fn(); + const callbackB = jest.fn(); + + const handleA = Pusher.subscribe(CHANNEL, EVENT, callbackA); + await jest.runAllTimersAsync(); + await handleA; + + // Second subscribe to same channel — goes through the "already subscribed" branch + jest.spyOn(Pusher, 'isSubscribed').mockReturnValue(true); + const handleB = Pusher.subscribe(CHANNEL, EVENT, callbackB); + await jest.runAllTimersAsync(); + await handleB; + + triggerEvent(CHANNEL, EVENT, {msg: 'hello'}); + + expect(callbackA).toHaveBeenCalledTimes(1); + expect(callbackB).toHaveBeenCalledTimes(1); + expect(callbackA).toHaveBeenCalledWith({msg: 'hello'}); + + handleA.unsubscribe(); + handleB.unsubscribe(); + }); + + it('should stop delivering events to an unsubscribed callback while others continue', async () => { + const callbackA = jest.fn(); + const callbackB = jest.fn(); + + const handleA = Pusher.subscribe(CHANNEL, EVENT, callbackA); + await jest.runAllTimersAsync(); + await handleA; + + jest.spyOn(Pusher, 'isSubscribed').mockReturnValue(true); + const handleB = Pusher.subscribe(CHANNEL, EVENT, callbackB); + await jest.runAllTimersAsync(); + await handleB; + + // Unsubscribe A only + handleA.unsubscribe(); + + triggerEvent(CHANNEL, EVENT, {msg: 'after-removal'}); + + expect(callbackA).not.toHaveBeenCalled(); + expect(callbackB).toHaveBeenCalledTimes(1); + expect(callbackB).toHaveBeenCalledWith({msg: 'after-removal'}); + + handleB.unsubscribe(); + }); + + it('should keep the channel subscribed until the last callback unsubscribes', async () => { + const handleA = Pusher.subscribe(CHANNEL, EVENT, jest.fn()); + await jest.runAllTimersAsync(); + await handleA; + + jest.spyOn(Pusher, 'isSubscribed').mockReturnValue(true); + const handleB = Pusher.subscribe(CHANNEL, EVENT, jest.fn()); + await jest.runAllTimersAsync(); + await handleB; + + // After unsubscribing A, mock socket should still have the channel + handleA.unsubscribe(); + expect(MockedPusher.getInstance().getChannel(CHANNEL)).toBeTruthy(); + + // After unsubscribing B (last callback), channel should be cleaned up + handleB.unsubscribe(); + expect(MockedPusher.getInstance().getChannel(CHANNEL)).toBeFalsy(); + }); + + it('should handle unsubscribe before subscription completes without errors', async () => { + const callback = jest.fn(); + + // Subscribe but do NOT flush timers yet — subscription is pending + const handle = Pusher.subscribe(CHANNEL, EVENT, callback); + + // Unsubscribe immediately (sets disposed = true) + handle.unsubscribe(); + + // Now flush — the InteractionManager callback should see disposed=true and skip binding + await jest.runAllTimersAsync(); + await expect(handle).resolves.toBeUndefined(); + + // Event should not reach the callback since it was never bound + triggerEvent(CHANNEL, EVENT); + expect(callback).not.toHaveBeenCalled(); + }); + + it('should handle multiple events on the same channel with independent cleanup', async () => { + const callbackX = jest.fn(); + const callbackY = jest.fn(); + + const handleX = Pusher.subscribe(CHANNEL, 'eventX', callbackX); + await jest.runAllTimersAsync(); + await handleX; + + jest.spyOn(Pusher, 'isSubscribed').mockReturnValue(true); + const handleY = Pusher.subscribe(CHANNEL, 'eventY', callbackY); + await jest.runAllTimersAsync(); + await handleY; + + // Both events should work + triggerEvent(CHANNEL, 'eventX', {type: 'x'}); + triggerEvent(CHANNEL, 'eventY', {type: 'y'}); + expect(callbackX).toHaveBeenCalledWith({type: 'x'}); + expect(callbackY).toHaveBeenCalledWith({type: 'y'}); + + // Unsubscribe eventX — eventY should still work + handleX.unsubscribe(); + callbackX.mockClear(); + callbackY.mockClear(); + + triggerEvent(CHANNEL, 'eventX', {type: 'x2'}); + triggerEvent(CHANNEL, 'eventY', {type: 'y2'}); + expect(callbackX).not.toHaveBeenCalled(); + expect(callbackY).toHaveBeenCalledWith({type: 'y2'}); + + // Unsubscribe eventY — channel should be fully cleaned up + handleY.unsubscribe(); + expect(MockedPusher.getInstance().getChannel(CHANNEL)).toBeFalsy(); + }); + + it('should clear all callbacks on disconnect so they do not fire after re-init', async () => { + const oldCallback = jest.fn(); + + const handle = Pusher.subscribe(CHANNEL, EVENT, oldCallback); + await jest.runAllTimersAsync(); + await handle; + + // Disconnect clears all callback tracking + Pusher.disconnect(); + jest.restoreAllMocks(); + + // Re-init and subscribe a new callback + jest.spyOn(Pusher, 'isSubscribed').mockReturnValue(false); + jest.spyOn(Pusher, 'isAlreadySubscribing').mockReturnValue(false); + await initPusher(); + + const newCallback = jest.fn(); + const newHandle = Pusher.subscribe(CHANNEL, EVENT, newCallback); + await jest.runAllTimersAsync(); + await newHandle; + + // Fire event — only new callback should receive it + triggerEvent(CHANNEL, EVENT, {session: 'new'}); + expect(oldCallback).not.toHaveBeenCalled(); + expect(newCallback).toHaveBeenCalledWith({session: 'new'}); + + newHandle.unsubscribe(); + }); + + it('should clean up channel when disposed mid-handshake before onSubscriptionSucceeded', async () => { + // Capture the onSubscriptionSucceeded callback so we can fire it manually + const mockSocket = MockedPusher.getInstance(); + let capturedOnSuccess: (() => void) | undefined; + + jest.spyOn(mockSocket, 'subscribe').mockImplementation(({channelName: cn, onEvent, onSubscriptionSucceeded}) => { + // Store the channel like the real mock, but DON'T call onSubscriptionSucceeded yet + mockSocket.channels.set(cn, {onEvent, onSubscriptionSucceeded}); + capturedOnSuccess = onSubscriptionSucceeded; + return Promise.resolve(); + }); + + const callback = jest.fn(); + const handle = Pusher.subscribe(CHANNEL, EVENT, callback); + + // Flush InteractionManager — socket.subscribe() fires, but onSubscriptionSucceeded is deferred + await jest.runAllTimersAsync(); + expect(capturedOnSuccess).toBeDefined(); + + // Dispose the handle mid-handshake (wrappedCb is still undefined) + handle.unsubscribe(); + + // Now fire onSubscriptionSucceeded — the disposed handle should trigger channel cleanup + capturedOnSuccess?.(); + await jest.runAllTimersAsync(); + await handle; + + // Channel should be cleaned up since no callbacks are bound + expect(mockSocket.channels.has(CHANNEL)).toBe(false); + + // Event should not reach the callback + triggerEvent(CHANNEL, EVENT); + expect(callback).not.toHaveBeenCalled(); + }); +});