diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 09e332d490f..130f11eec5f 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -1231,14 +1231,6 @@ "count": 1 } }, - "packages/multichain-account-service/src/MultichainAccountService.test.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 1 - }, - "@typescript-eslint/naming-convention": { - "count": 4 - } - }, "packages/multichain-account-service/src/MultichainAccountService.ts": { "id-length": { "count": 1 @@ -1296,14 +1288,6 @@ "count": 1 } }, - "packages/multichain-account-service/src/providers/SnapAccountProvider.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 1 - }, - "@typescript-eslint/naming-convention": { - "count": 1 - } - }, "packages/multichain-account-service/src/providers/SolAccountProvider.ts": { "@typescript-eslint/naming-convention": { "count": 2 @@ -1312,11 +1296,6 @@ "count": 1 } }, - "packages/multichain-account-service/src/providers/TrxAccountProvider.test.ts": { - "no-negated-condition": { - "count": 1 - } - }, "packages/multichain-account-service/src/providers/TrxAccountProvider.ts": { "@typescript-eslint/naming-convention": { "count": 2 diff --git a/packages/multichain-account-service/CHANGELOG.md b/packages/multichain-account-service/CHANGELOG.md index efc68d47f78..fc022335baa 100644 --- a/packages/multichain-account-service/CHANGELOG.md +++ b/packages/multichain-account-service/CHANGELOG.md @@ -7,8 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Wait for Snap platform to be ready before any wallet/group operations ([#7266](https://github.com/MetaMask/core/pull/7266)) +- Add `SnapAccountProvider.withSnap` protected helper ([#7266](https://github.com/MetaMask/core/pull/7266)) + - This is used to protect any Snap operation behind a guard that checks if the Snap platform is ready. +- Add `MultichainAccountService:ensureCanUseSnapPlatform` method and action. + - This will resolve once the Snap platform is ready for the first time and will throw afterward if Snap platform has been disabled dynamically. + - This action is mostly used internally by any Snap-based account providers. + ### Changed +- **BREAKING:** The `SnapAccountProvider.client` property is now private ([#7266](https://github.com/MetaMask/core/pull/7266)) + - You now need to use `SnapAccountProvider.withSnap` to access to it. - Bump `@metamask/snaps-controllers` from `^14.0.1` to `^17.2.0` ([#7550](https://github.com/MetaMask/core/pull/7550)) - Bump `@metamask/snaps-sdk` from `^9.0.0` to `^10.3.0` ([#7550](https://github.com/MetaMask/core/pull/7550)) - Bump `@metamask/snaps-utils` from `^11.0.0` to `^11.7.0` ([#7550](https://github.com/MetaMask/core/pull/7550)) diff --git a/packages/multichain-account-service/package.json b/packages/multichain-account-service/package.json index 36f71736a30..3bac73b455b 100644 --- a/packages/multichain-account-service/package.json +++ b/packages/multichain-account-service/package.json @@ -64,7 +64,8 @@ "@metamask/snaps-utils": "^11.7.0", "@metamask/superstruct": "^3.1.0", "@metamask/utils": "^11.9.0", - "async-mutex": "^0.5.0" + "async-mutex": "^0.5.0", + "lodash": "^4.17.21" }, "devDependencies": { "@metamask/account-api": "^0.12.0", diff --git a/packages/multichain-account-service/src/MultichainAccountService.test.ts b/packages/multichain-account-service/src/MultichainAccountService.test.ts index 307a2a5922f..6ec71a17763 100644 --- a/packages/multichain-account-service/src/MultichainAccountService.test.ts +++ b/packages/multichain-account-service/src/MultichainAccountService.test.ts @@ -17,6 +17,7 @@ import { SOL_ACCOUNT_PROVIDER_NAME, SolAccountProvider, } from './providers/SolAccountProvider'; +import { SnapPlatformWatcher } from './snaps/SnapPlatformWatcher'; import { MOCK_HARDWARE_ACCOUNT_1, MOCK_HD_ACCOUNT_1, @@ -54,25 +55,40 @@ jest.mock('./providers/SolAccountProvider', () => { }); type Mocks = { + // eslint-disable-next-line @typescript-eslint/naming-convention KeyringController: { keyrings: KeyringObject[]; getState: jest.Mock; getKeyringsByType: jest.Mock; addNewKeyring: jest.Mock; }; + // eslint-disable-next-line @typescript-eslint/naming-convention AccountsController: { listMultichainAccounts: jest.Mock; }; + // eslint-disable-next-line @typescript-eslint/naming-convention + SnapController: { + getState: jest.Mock; + }; + // eslint-disable-next-line @typescript-eslint/naming-convention EvmAccountProvider: MockAccountProvider; + // eslint-disable-next-line @typescript-eslint/naming-convention SolAccountProvider: MockAccountProvider; }; +type Spies = { + // eslint-disable-next-line @typescript-eslint/naming-convention + SnapPlatformWatcher: { + ensureCanUseSnapPlatform: jest.SpyInstance; + }; +}; + function mockAccountProvider( providerClass: new (messenger: MultichainAccountServiceMessenger) => Provider, mocks: MockAccountProvider, accounts: KeyringAccount[], type: KeyringAccount['type'], -) { +): void { jest.mocked(providerClass).mockImplementation((...args) => { mocks.constructor(...args); return mocks as unknown as Provider; @@ -100,6 +116,7 @@ async function setup({ rootMessenger: RootMessenger; messenger: MultichainAccountServiceMessenger; mocks: Mocks; + spies: Spies; }> { const mocks: Mocks = { KeyringController: { @@ -111,13 +128,34 @@ async function setup({ AccountsController: { listMultichainAccounts: jest.fn(), }, + SnapController: { + getState: jest.fn(), + }, EvmAccountProvider: makeMockAccountProvider(), SolAccountProvider: makeMockAccountProvider(), }; + const spies: Spies = { + SnapPlatformWatcher: { + ensureCanUseSnapPlatform: jest.spyOn( + SnapPlatformWatcher.prototype, + 'ensureCanUseSnapPlatform', + ), + }, + }; + // Required for the `assert` on `MultichainAccountWallet.createMultichainAccountGroup`. Object.setPrototypeOf(mocks.EvmAccountProvider, EvmAccountProvider.prototype); + mocks.SnapController.getState.mockImplementation(() => ({ + isReady: true, + })); + + rootMessenger.registerActionHandler( + 'SnapController:getState', + mocks.SnapController.getState, + ); + mocks.KeyringController.getState.mockImplementation(() => ({ isUnlocked: true, keyrings: mocks.KeyringController.keyrings, @@ -181,6 +219,7 @@ async function setup({ rootMessenger, messenger, mocks, + spies, }; } @@ -1004,6 +1043,21 @@ describe('MultichainAccountService', () => { await messenger.call('MultichainAccountService:resyncAccounts'); expect(resyncAccountsSpy).toHaveBeenCalled(); }); + + it('checks for Snap platform readiness with MultichainAccountService:ensureCanUseSnapPlatform', async () => { + const { messenger, service } = await setup({ + accounts: [], + }); + + await service.ensureCanUseSnapPlatform(); + + const ensureCanUseSnapPlatformSpy = jest.spyOn( + service, + 'ensureCanUseSnapPlatform', + ); + await messenger.call('MultichainAccountService:ensureCanUseSnapPlatform'); + expect(ensureCanUseSnapPlatformSpy).toHaveBeenCalled(); + }); }); describe('resyncAccounts', () => { @@ -1247,4 +1301,18 @@ describe('MultichainAccountService', () => { expect(mocks.KeyringController.addNewKeyring).not.toHaveBeenCalled(); }); }); + + describe('ensureCanUseSnapPlatform', () => { + it('delegates Snap platform readiness check to SnapPlatformWatcher (method)', async () => { + const { service, spies } = await setup({ + accounts: [], + }); + + await service.ensureCanUseSnapPlatform(); + + expect( + spies.SnapPlatformWatcher.ensureCanUseSnapPlatform, + ).toHaveBeenCalledTimes(1); + }); + }); }); diff --git a/packages/multichain-account-service/src/MultichainAccountService.ts b/packages/multichain-account-service/src/MultichainAccountService.ts index 6b46735ea8f..55f04428212 100644 --- a/packages/multichain-account-service/src/MultichainAccountService.ts +++ b/packages/multichain-account-service/src/MultichainAccountService.ts @@ -31,6 +31,7 @@ import { SOL_ACCOUNT_PROVIDER_NAME, SolAccountProviderConfig, } from './providers/SolAccountProvider'; +import { SnapPlatformWatcher } from './snaps/SnapPlatformWatcher'; import type { MultichainAccountServiceConfig, MultichainAccountServiceMessenger, @@ -64,6 +65,8 @@ type AccountContext> = { export class MultichainAccountService { readonly #messenger: MultichainAccountServiceMessenger; + readonly #watcher: SnapPlatformWatcher; + readonly #providers: Bip44AccountProvider[]; readonly #wallets: Map< @@ -124,6 +127,8 @@ export class MultichainAccountService { ...providers, ]; + this.#watcher = new SnapPlatformWatcher(messenger); + this.#messenger.registerActionHandler( 'MultichainAccountService:getMultichainAccountGroup', (...args) => this.getMultichainAccountGroup(...args), @@ -168,6 +173,10 @@ export class MultichainAccountService { 'MultichainAccountService:resyncAccounts', (...args) => this.resyncAccounts(...args), ); + this.#messenger.registerActionHandler( + 'MultichainAccountService:ensureCanUseSnapPlatform', + (...args) => this.ensureCanUseSnapPlatform(...args), + ); this.#messenger.subscribe('AccountsController:accountAdded', (account) => this.#handleOnAccountAdded(account), @@ -261,6 +270,10 @@ export class MultichainAccountService { log('Providers got re-synced!'); } + ensureCanUseSnapPlatform(): Promise { + return this.#watcher.ensureCanUseSnapPlatform(); + } + #handleOnAccountAdded(account: KeyringAccount): void { // We completely omit non-BIP-44 accounts! if (!isBip44Account(account)) { diff --git a/packages/multichain-account-service/src/providers/BtcAccountProvider.test.ts b/packages/multichain-account-service/src/providers/BtcAccountProvider.test.ts index 6425be5a907..4c84cf451bb 100644 --- a/packages/multichain-account-service/src/providers/BtcAccountProvider.test.ts +++ b/packages/multichain-account-service/src/providers/BtcAccountProvider.test.ts @@ -6,6 +6,7 @@ import type { EthKeyring, InternalAccount, } from '@metamask/keyring-internal-api'; +import { SnapControllerState } from '@metamask/snaps-controllers'; import { AccountProviderWrapper } from './AccountProviderWrapper'; import { @@ -98,6 +99,11 @@ class MockBtcKeyring { return account; }); } +class MockBtcAccountProvider extends BtcAccountProvider { + override async ensureCanUseSnapPlatform(): Promise { + // Override to avoid waiting during tests. + } +} /** * Sets up a BtcAccountProvider for testing. @@ -129,6 +135,11 @@ function setup({ } { const keyring = new MockBtcKeyring(accounts); + messenger.registerActionHandler( + 'SnapController:getState', + () => ({ isReady: true }) as SnapControllerState, + ); + messenger.registerActionHandler( 'AccountsController:listMultichainAccounts', () => accounts, @@ -158,7 +169,7 @@ function setup({ const multichainMessenger = getMultichainAccountServiceMessenger(messenger); const provider = new AccountProviderWrapper( multichainMessenger, - new BtcAccountProvider(multichainMessenger, config), + new MockBtcAccountProvider(multichainMessenger, config), ); return { @@ -373,7 +384,7 @@ describe('BtcAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const btcProvider = new BtcAccountProvider( + const btcProvider = new MockBtcAccountProvider( multichainMessenger, undefined, mockTrace, @@ -425,7 +436,7 @@ describe('BtcAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const btcProvider = new BtcAccountProvider( + const btcProvider = new MockBtcAccountProvider( multichainMessenger, undefined, mockTrace, @@ -458,7 +469,7 @@ describe('BtcAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const btcProvider = new BtcAccountProvider( + const btcProvider = new MockBtcAccountProvider( multichainMessenger, undefined, mockTrace, diff --git a/packages/multichain-account-service/src/providers/BtcAccountProvider.ts b/packages/multichain-account-service/src/providers/BtcAccountProvider.ts index 7e9f059f205..fd9e78cef2e 100644 --- a/packages/multichain-account-service/src/providers/BtcAccountProvider.ts +++ b/packages/multichain-account-service/src/providers/BtcAccountProvider.ts @@ -7,7 +7,10 @@ import type { InternalAccount } from '@metamask/keyring-internal-api'; import type { SnapId } from '@metamask/snaps-sdk'; import { SnapAccountProvider } from './SnapAccountProvider'; -import type { SnapAccountProviderConfig } from './SnapAccountProvider'; +import type { + RestrictedSnapKeyring, + SnapAccountProviderConfig, +} from './SnapAccountProvider'; import { withRetry, withTimeout } from './utils'; import { traceFallback } from '../analytics'; import { TraceName } from '../constants/traces'; @@ -54,18 +57,18 @@ export class BtcAccountProvider extends SnapAccountProvider { ); } - async createAccounts({ + async #createAccounts({ + keyring, entropySource, groupIndex: index, }: { + keyring: RestrictedSnapKeyring; entropySource: EntropySourceId; groupIndex: number; }): Promise[]> { return this.withMaxConcurrency(async () => { - const createAccount = await this.getRestrictedSnapAccountCreator(); - const account = await withTimeout( - createAccount({ + keyring.createAccount({ entropySource, index, addressType: BtcAccountType.P2wpkh, @@ -79,6 +82,18 @@ export class BtcAccountProvider extends SnapAccountProvider { }); } + async createAccounts({ + entropySource, + groupIndex, + }: { + entropySource: EntropySourceId; + groupIndex: number; + }): Promise[]> { + return this.withSnap(async ({ keyring }) => { + return this.#createAccounts({ keyring, entropySource, groupIndex }); + }); + } + async discoverAccounts({ entropySource, groupIndex, @@ -86,48 +101,51 @@ export class BtcAccountProvider extends SnapAccountProvider { entropySource: EntropySourceId; groupIndex: number; }): Promise[]> { - return await super.trace( - { - name: TraceName.SnapDiscoverAccounts, - data: { - provider: this.getName(), + return this.withSnap(async ({ client, keyring }) => { + return await super.trace( + { + name: TraceName.SnapDiscoverAccounts, + data: { + provider: this.getName(), + }, }, - }, - async () => { - if (!this.config.discovery.enabled) { - return []; - } - - const discoveredAccounts = await withRetry( - () => - withTimeout( - this.client.discoverAccounts( - [BtcScope.Mainnet], - entropySource, - groupIndex, + async () => { + if (!this.config.discovery.enabled) { + return []; + } + + const discoveredAccounts = await withRetry( + () => + withTimeout( + client.discoverAccounts( + [BtcScope.Mainnet], + entropySource, + groupIndex, + ), + this.config.discovery.timeoutMs, ), - this.config.discovery.timeoutMs, - ), - { - maxAttempts: this.config.discovery.maxAttempts, - backOffMs: this.config.discovery.backOffMs, - }, - ); - - if ( - !Array.isArray(discoveredAccounts) || - discoveredAccounts.length === 0 - ) { - return []; - } - - const createdAccounts = await this.createAccounts({ - entropySource, - groupIndex, - }); - - return createdAccounts; - }, - ); + { + maxAttempts: this.config.discovery.maxAttempts, + backOffMs: this.config.discovery.backOffMs, + }, + ); + + if ( + !Array.isArray(discoveredAccounts) || + discoveredAccounts.length === 0 + ) { + return []; + } + + const createdAccounts = await this.#createAccounts({ + keyring, + entropySource, + groupIndex, + }); + + return createdAccounts; + }, + ); + }); } } diff --git a/packages/multichain-account-service/src/providers/SnapAccountProvider.test.ts b/packages/multichain-account-service/src/providers/SnapAccountProvider.test.ts index 7cf7b5e9d09..20be597d904 100644 --- a/packages/multichain-account-service/src/providers/SnapAccountProvider.test.ts +++ b/packages/multichain-account-service/src/providers/SnapAccountProvider.test.ts @@ -1,11 +1,14 @@ import { isBip44Account } from '@metamask/account-api'; import type { Bip44Account } from '@metamask/account-api'; import type { TraceCallback, TraceRequest } from '@metamask/controller-utils'; +import { KeyringRpcMethod } from '@metamask/keyring-api'; +import type { GetAccountRequest } from '@metamask/keyring-api'; import type { EntropySourceId, KeyringAccount } from '@metamask/keyring-api'; import type { InternalAccount } from '@metamask/keyring-internal-api'; -import type { SnapId } from '@metamask/snaps-sdk'; +import type { JsonRpcRequest, SnapId } from '@metamask/snaps-sdk'; import { BtcAccountProvider } from './BtcAccountProvider'; +import type { SnapAccountProviderConfig } from './SnapAccountProvider'; import { isSnapAccountProvider, SnapAccountProvider, @@ -36,28 +39,64 @@ const THROTTLED_OPERATION_DELAY_MS = 10; const TEST_SNAP_ID = 'npm:@metamask/test-snap' as SnapId; const TEST_ENTROPY_SOURCE = 'test-entropy-source' as EntropySourceId; -// Helper to create a test provider that exposes protected trace method -class TestSnapAccountProvider extends SnapAccountProvider { +class MockSnapAccountProvider extends SnapAccountProvider { + readonly tracker: { + startLog: number[]; + endLog: number[]; + activeCount: number; + maxActiveCount: number; + }; + + constructor( + snapId: SnapId, + messenger: MultichainAccountServiceMessenger, + config: SnapAccountProviderConfig, + /* istanbul ignore next */ + trace: TraceCallback = traceFallback, + ) { + super(snapId, messenger, config, trace); + + // Tracker to monitor concurrent executions. + this.tracker = { + startLog: [], + endLog: [], + activeCount: 0, + maxActiveCount: 0, + }; + } + getName(): string { return 'Test Provider'; } - isAccountCompatible(_account: Bip44Account): boolean { + isAccountCompatible(): boolean { return true; } - async discoverAccounts(_options: { - entropySource: EntropySourceId; - groupIndex: number; - }): Promise[]> { + async discoverAccounts(): Promise[]> { return []; } - async createAccounts(_options: { + async createAccounts(options: { entropySource: EntropySourceId; groupIndex: number; }): Promise[]> { - return []; + const { tracker } = this; + + return this.withMaxConcurrency(async () => { + tracker.startLog.push(options.groupIndex); + tracker.activeCount += 1; + tracker.maxActiveCount = Math.max( + tracker.maxActiveCount, + tracker.activeCount, + ); + await new Promise((resolve) => + setTimeout(resolve, THROTTLED_OPERATION_DELAY_MS), + ); + tracker.activeCount -= 1; + tracker.endLog.push(options.groupIndex); + return []; + }); } // Expose protected trace method as public for testing @@ -73,52 +112,69 @@ class TestSnapAccountProvider extends SnapAccountProvider { const setup = ({ maxConcurrency, messenger = getRootMessenger(), -}: { maxConcurrency?: number; messenger?: RootMessenger } = {}) => { - const tracker: { - startLog: number[]; - endLog: number[]; - activeCount: number; - maxActiveCount: number; - } = { - startLog: [], - endLog: [], - activeCount: 0, - maxActiveCount: 0, + accounts = [], +}: { + maxConcurrency?: number; + messenger?: RootMessenger; + accounts?: InternalAccount[]; +} = {}) => { + const mocks = { + AccountsController: { + listMultichainAccounts: jest.fn(), + }, + ErrorReportingService: { + captureException: jest.fn(), + }, + SnapController: { + handleKeyringRequest: { + getAccount: jest.fn(), + listAccounts: jest.fn(), + }, + handleRequest: jest.fn(), + }, + MultichainAccountService: { + ensureCanUseSnapPlatform: jest.fn(), + }, }; - class MockSnapAccountProvider extends SnapAccountProvider { - getName(): string { - return 'Test Provider'; - } + messenger.registerActionHandler( + 'AccountsController:listMultichainAccounts', + mocks.AccountsController.listMultichainAccounts, + ); + mocks.AccountsController.listMultichainAccounts.mockReturnValue(accounts); - isAccountCompatible(): boolean { - return true; - } + messenger.registerActionHandler( + 'MultichainAccountService:ensureCanUseSnapPlatform', + mocks.MultichainAccountService.ensureCanUseSnapPlatform, + ); + // Make the platform ready right away (having a resolved promise is enough). + mocks.MultichainAccountService.ensureCanUseSnapPlatform.mockResolvedValue( + undefined, + ); - async discoverAccounts(): Promise[]> { - return []; - } - - async createAccounts(options: { - entropySource: EntropySourceId; - groupIndex: number; - }): Promise[]> { - return this.withMaxConcurrency(async () => { - tracker.startLog.push(options.groupIndex); - tracker.activeCount += 1; - tracker.maxActiveCount = Math.max( - tracker.maxActiveCount, - tracker.activeCount, - ); - await new Promise((resolve) => - setTimeout(resolve, THROTTLED_OPERATION_DELAY_MS), + messenger.registerActionHandler( + 'SnapController:handleRequest', + mocks.SnapController.handleRequest, + ); + mocks.SnapController.handleRequest.mockImplementation( + async ({ request }: { request: JsonRpcRequest }) => { + if (request.method === String(KeyringRpcMethod.GetAccount)) { + return await mocks.SnapController.handleKeyringRequest.getAccount( + (request as GetAccountRequest).params.id, ); - tracker.activeCount -= 1; - tracker.endLog.push(options.groupIndex); - return []; - }); - } - } + } else if (request.method === String(KeyringRpcMethod.ListAccounts)) { + return await mocks.SnapController.handleKeyringRequest.listAccounts(); + } + throw new Error(`Unhandled method: ${request.method}`); + }, + ); + mocks.SnapController.handleKeyringRequest.getAccount.mockImplementation( + async (id) => + accounts.map(asKeyringAccount).find((account) => account.id === id), + ); + mocks.SnapController.handleKeyringRequest.listAccounts.mockImplementation( + async () => accounts.map(asKeyringAccount), + ); const keyring = { createAccount: jest.fn(), @@ -134,7 +190,10 @@ const setup = ({ ), ); - const serviceMessenger = getMultichainAccountServiceMessenger(messenger); + const serviceMessenger = getMultichainAccountServiceMessenger(messenger, { + // We need this extra action to be able to mock it. + actions: ['MultichainAccountService:ensureCanUseSnapPlatform'], + }); const config = { ...(maxConcurrency !== undefined && { maxConcurrency }), createAccounts: { @@ -152,42 +211,44 @@ const setup = ({ config, ); - return { messenger, provider, tracker, keyring }; + return { + messenger, + provider, + tracker: provider.tracker, + keyring, + mocks, + }; }; describe('SnapAccountProvider', () => { describe('constructor default parameters', () => { - const mockMessenger = { - call: jest.fn().mockResolvedValue({}), - registerActionHandler: jest.fn(), - subscribe: jest.fn(), - registerMethodActionHandlers: jest.fn(), - unregisterActionHandler: jest.fn(), - registerInitialEventPayload: jest.fn(), - publish: jest.fn(), - clearEventSubscriptions: jest.fn(), - } as unknown as MultichainAccountServiceMessenger; - - beforeEach(() => { - jest.clearAllMocks(); - }); - it('creates SolAccountProvider with default trace using 1 parameter', () => { - const provider = new SolAccountProvider(mockMessenger); + const { messenger } = setup(); + + const provider = new SolAccountProvider( + getMultichainAccountServiceMessenger(messenger), + ); expect(provider).toBeDefined(); expect(provider.snapId).toBe(SolAccountProvider.SOLANA_SNAP_ID); }); it('creates SolAccountProvider with default trace using 2 parameters', () => { - const provider = new SolAccountProvider(mockMessenger, undefined); + const { messenger } = setup(); + + const provider = new SolAccountProvider( + getMultichainAccountServiceMessenger(messenger), + undefined, + ); expect(provider).toBeDefined(); expect(provider.snapId).toBe(SolAccountProvider.SOLANA_SNAP_ID); }); it('creates SolAccountProvider with custom trace using 3 parameters', () => { + const { messenger } = setup(); + const customTrace = jest.fn(); const provider = new SolAccountProvider( - mockMessenger, + getMultichainAccountServiceMessenger(messenger), undefined, customTrace, ); @@ -196,6 +257,8 @@ describe('SnapAccountProvider', () => { }); it('creates SolAccountProvider with custom config and default trace', () => { + const { messenger } = setup(); + const customConfig = { discovery: { timeoutMs: 3000, @@ -206,25 +269,34 @@ describe('SnapAccountProvider', () => { timeoutMs: 5000, }, }; - const provider = new SolAccountProvider(mockMessenger, customConfig); + const provider = new SolAccountProvider( + getMultichainAccountServiceMessenger(messenger), + customConfig, + ); expect(provider).toBeDefined(); expect(provider.snapId).toBe(SolAccountProvider.SOLANA_SNAP_ID); }); it('creates BtcAccountProvider with default trace', () => { + const { messenger } = setup(); + // Test other subclasses to ensure branch coverage - const btcProvider = new BtcAccountProvider(mockMessenger); + const btcProvider = new BtcAccountProvider( + getMultichainAccountServiceMessenger(messenger), + ); expect(btcProvider).toBeDefined(); expect(isSnapAccountProvider(btcProvider)).toBe(true); }); it('creates TrxAccountProvider with custom trace', () => { + const { messenger } = setup(); + const customTrace = jest.fn(); // Explicitly test with all three parameters const trxProvider = new TrxAccountProvider( - mockMessenger, + getMultichainAccountServiceMessenger(messenger), undefined, customTrace, ); @@ -234,20 +306,29 @@ describe('SnapAccountProvider', () => { }); it('creates provider without trace parameter', () => { + const { messenger } = setup(); + // Test creating provider without passing trace parameter - const provider = new SolAccountProvider(mockMessenger, undefined); + const provider = new SolAccountProvider( + getMultichainAccountServiceMessenger(messenger), + undefined, + ); expect(provider).toBeDefined(); }); it('tests parameter spreading to trigger branch coverage', () => { + const { messenger } = setup(); + type SolConfig = ConstructorParameters[1]; type ProviderArgs = [ MultichainAccountServiceMessenger, SolConfig?, TraceCallback?, ]; - const args: ProviderArgs = [mockMessenger]; + const args: ProviderArgs = [ + getMultichainAccountServiceMessenger(messenger), + ]; const provider1 = new SolAccountProvider(...args); args.push(undefined); @@ -287,35 +368,16 @@ describe('SnapAccountProvider', () => { }); it('returns true for actual SnapAccountProvider instance', () => { - // Create a mock messenger with required methods - const mockMessenger = { - call: jest.fn(), - registerActionHandler: jest.fn(), - subscribe: jest.fn(), - registerMethodActionHandlers: jest.fn(), - unregisterActionHandler: jest.fn(), - registerInitialEventPayload: jest.fn(), - publish: jest.fn(), - clearEventSubscriptions: jest.fn(), - } as unknown as MultichainAccountServiceMessenger; - - const solProvider = new SolAccountProvider(mockMessenger); + const { messenger } = setup(); + + const solProvider = new SolAccountProvider( + getMultichainAccountServiceMessenger(messenger), + ); expect(isSnapAccountProvider(solProvider)).toBe(true); }); }); describe('trace functionality', () => { - const mockMessenger = { - call: jest.fn().mockResolvedValue({}), - registerActionHandler: jest.fn(), - subscribe: jest.fn(), - registerMethodActionHandlers: jest.fn(), - unregisterActionHandler: jest.fn(), - registerInitialEventPayload: jest.fn(), - publish: jest.fn(), - clearEventSubscriptions: jest.fn(), - } as unknown as MultichainAccountServiceMessenger; - const traceFallbackMock = traceFallback as jest.MockedFunction< typeof traceFallback >; @@ -326,6 +388,8 @@ describe('SnapAccountProvider', () => { }); it('uses default trace parameter when only messenger is provided', async () => { + const { messenger } = setup(); + traceFallbackMock.mockImplementation(async (_request, fn) => fn?.()); // Test with default config and trace @@ -339,9 +403,9 @@ describe('SnapAccountProvider', () => { timeoutMs: 3000, }, }; - const testProvider = new TestSnapAccountProvider( + const testProvider = new MockSnapAccountProvider( TEST_SNAP_ID, - mockMessenger, + getMultichainAccountServiceMessenger(messenger), defaultConfig, ); const request = { name: 'Test Request', data: {} }; @@ -358,14 +422,16 @@ describe('SnapAccountProvider', () => { }); it('uses custom trace when explicitly provided with all parameters', async () => { + const { messenger } = setup(); + const customTrace = jest.fn().mockImplementation(async (_request, fn) => { return await fn(); }); // Test with all parameters including custom trace - const testProvider = new TestSnapAccountProvider( + const testProvider = new MockSnapAccountProvider( TEST_SNAP_ID, - mockMessenger, + getMultichainAccountServiceMessenger(messenger), { discovery: { timeoutMs: 2000, @@ -390,6 +456,8 @@ describe('SnapAccountProvider', () => { }); it('calls trace callback with the correct arguments', async () => { + const { messenger } = setup(); + const mockTrace = jest.fn().mockImplementation(async (request, fn) => { expect(request).toStrictEqual({ name: 'Test Request', @@ -408,9 +476,9 @@ describe('SnapAccountProvider', () => { timeoutMs: 3000, }, }; - const testProvider = new TestSnapAccountProvider( + const testProvider = new MockSnapAccountProvider( TEST_SNAP_ID, - mockMessenger, + getMultichainAccountServiceMessenger(messenger), defaultConfig, mockTrace, ); @@ -425,6 +493,8 @@ describe('SnapAccountProvider', () => { }); it('propagates errors through trace callback', async () => { + const { messenger } = setup(); + const mockError = new Error('Test error'); const mockTrace = jest.fn().mockImplementation(async (_request, fn) => { return await fn(); @@ -440,9 +510,9 @@ describe('SnapAccountProvider', () => { timeoutMs: 3000, }, }; - const testProvider = new TestSnapAccountProvider( + const testProvider = new MockSnapAccountProvider( TEST_SNAP_ID, - mockMessenger, + getMultichainAccountServiceMessenger(messenger), defaultConfig, mockTrace, ); @@ -456,6 +526,8 @@ describe('SnapAccountProvider', () => { }); it('handles trace callback returning undefined', async () => { + const { messenger } = setup(); + const mockTrace = jest.fn().mockImplementation(async (_request, fn) => { return await fn(); }); @@ -470,9 +542,9 @@ describe('SnapAccountProvider', () => { timeoutMs: 3000, }, }; - const testProvider = new TestSnapAccountProvider( + const testProvider = new MockSnapAccountProvider( TEST_SNAP_ID, - mockMessenger, + getMultichainAccountServiceMessenger(messenger), defaultConfig, mockTrace, ); @@ -592,12 +664,7 @@ describe('SnapAccountProvider', () => { ].filter(isBip44Account); it('does not create any accounts if already in-sync', async () => { - const { provider, messenger } = setup(); - - messenger.registerActionHandler( - 'SnapController:handleRequest', - jest.fn().mockResolvedValue(mockAccounts.map(asKeyringAccount)), - ); + const { provider } = setup({ accounts: mockAccounts }); const createAccountsSpy = jest.spyOn(provider, 'createAccounts'); @@ -607,14 +674,11 @@ describe('SnapAccountProvider', () => { }); it('creates new accounts if de-synced', async () => { - const { provider, messenger } = setup(); - const captureExceptionSpy = jest.spyOn(messenger, 'captureException'); - - messenger.registerActionHandler( - 'SnapController:handleRequest', - jest.fn().mockResolvedValue([mockAccounts[0]].map(asKeyringAccount)), - ); + const { provider, messenger } = setup({ + accounts: [mockAccounts[0]], + }); + const captureExceptionSpy = jest.spyOn(messenger, 'captureException'); const createAccountsSpy = jest.spyOn(provider, 'createAccounts'); await provider.resyncAccounts(mockAccounts); @@ -633,13 +697,9 @@ describe('SnapAccountProvider', () => { }); it('reports an error if a Snap has more accounts than MetaMask', async () => { - const { provider, messenger } = setup(); - const captureExceptionSpy = jest.spyOn(messenger, 'captureException'); + const { provider, messenger } = setup({ accounts: mockAccounts }); - messenger.registerActionHandler( - 'SnapController:handleRequest', - jest.fn().mockResolvedValue(mockAccounts.map(asKeyringAccount)), - ); + const captureExceptionSpy = jest.spyOn(messenger, 'captureException'); await provider.resyncAccounts([mockAccounts[0]]); // Less accounts than the Snap @@ -651,14 +711,9 @@ describe('SnapAccountProvider', () => { }); it('does not throw errors if any provider is not able to re-sync', async () => { - const { provider, messenger } = setup(); - const captureExceptionSpy = jest.spyOn(messenger, 'captureException'); - - messenger.registerActionHandler( - 'SnapController:handleRequest', - jest.fn().mockResolvedValue([mockAccounts[0]].map(asKeyringAccount)), - ); + const { provider, messenger } = setup({ accounts: [mockAccounts[0]] }); + const captureExceptionSpy = jest.spyOn(messenger, 'captureException'); const createAccountsSpy = jest.spyOn(provider, 'createAccounts'); const providerError = new Error('Unable to create accounts'); @@ -684,4 +739,16 @@ describe('SnapAccountProvider', () => { ); }); }); + + describe('ensureCanUseSnapPlatform', () => { + it('delegates Snap platform readiness check to SnapPlatformWatcher', async () => { + const { provider, mocks } = setup(); + + await provider.ensureCanUseSnapPlatform(); + + expect( + mocks.MultichainAccountService.ensureCanUseSnapPlatform, + ).toHaveBeenCalledTimes(1); + }); + }); }); diff --git a/packages/multichain-account-service/src/providers/SnapAccountProvider.ts b/packages/multichain-account-service/src/providers/SnapAccountProvider.ts index b8a14343693..a6cd24bfd05 100644 --- a/packages/multichain-account-service/src/providers/SnapAccountProvider.ts +++ b/packages/multichain-account-service/src/providers/SnapAccountProvider.ts @@ -15,9 +15,10 @@ import { traceFallback } from '../analytics'; import type { MultichainAccountServiceMessenger } from '../types'; import { createSentryError } from '../utils'; -export type RestrictedSnapKeyringCreateAccount = ( - options: Record, -) => Promise; +export type RestrictedSnapKeyring = { + createAccount: (options: Record) => Promise; + removeAccount: (address: string) => Promise; +}; export type SnapAccountProviderConfig = { maxConcurrency?: number; @@ -37,7 +38,7 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { protected readonly config: SnapAccountProviderConfig; - protected readonly client: KeyringClient; + readonly #client: KeyringClient; readonly #queue?: Semaphore; @@ -53,7 +54,7 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { super(messenger); this.snapId = snapId; - this.client = this.#getKeyringClientFromSnapId(snapId); + this.#client = this.#getKeyringClientFromSnapId(snapId); const maxConcurrency = config.maxConcurrency ?? Infinity; this.config = { @@ -73,6 +74,18 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { this.#trace = trace; } + /** + * Ensures that the Snap platform is ready to be used. + * + * @returns A promise that resolves when the platform is ready. + * @throws An error if the platform is not ready (only effective once the platform has been ready at least once). + */ + async ensureCanUseSnapPlatform(): Promise { + return this.messenger.call( + 'MultichainAccountService:ensureCanUseSnapPlatform', + ); + } + /** * Wraps an async operation with concurrency limiting based on maxConcurrency config. * If maxConcurrency is Infinity (the default), the operation runs immediately without throttling. @@ -81,9 +94,9 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { * @param operation - The async operation to execute. * @returns The result of the operation. */ - protected async withMaxConcurrency( - operation: () => Promise, - ): Promise { + protected async withMaxConcurrency( + operation: () => Promise, + ): Promise { if (this.#queue) { return this.#queue.runExclusive(operation); } @@ -97,7 +110,7 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { return this.#trace(request, fn); } - protected async getRestrictedSnapAccountCreator(): Promise { + async #getRestrictedSnapKeyring(): Promise { // NOTE: We're not supposed to make the keyring instance escape `withKeyring` but // we have to use the `SnapKeyring` instance to be able to create Solana account // without triggering UI confirmation. @@ -108,17 +121,25 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { SnapKeyring['createAccount'] >(async ({ keyring }) => keyring.createAccount.bind(keyring)); - return (options) => - createAccount(this.snapId, options, { - displayAccountNameSuggestion: false, - displayConfirmation: false, - setSelectedAccount: false, - }); + return { + createAccount: async (options) => + // We use the "unguarded" account creation here (see explanation above). + await createAccount(this.snapId, options, { + displayAccountNameSuggestion: false, + displayConfirmation: false, + setSelectedAccount: false, + }), + removeAccount: async (address: string) => + // Though, when removing account, we can use the normal flow. + await this.#withSnapKeyring(async ({ keyring }) => { + await keyring.removeAccount(address); + }), + }; } #getKeyringClientFromSnapId(snapId: string): KeyringClient { return new KeyringClient({ - send: async (request: JsonRpcRequest) => { + send: async (request: JsonRpcRequest): Promise => { const response = await this.messenger.call( 'SnapController:handleRequest', { @@ -136,73 +157,71 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { async resyncAccounts( accounts: Bip44Account[], ): Promise { - const localSnapAccounts = accounts.filter( - (account) => - account.metadata.snap && account.metadata.snap.id === this.snapId, - ); - const snapAccounts = new Set( - (await this.client.listAccounts()).map((account) => account.id), - ); - - // NOTE: This should never happen, but we want to report that kind of errors still - // in case states are de-sync. - if (localSnapAccounts.length < snapAccounts.size) { - this.messenger.captureException?.( - new Error( - `Snap "${this.snapId}" has de-synced accounts, Snap has more accounts than MetaMask!`, - ), + await this.withSnap(async ({ keyring }) => { + const localSnapAccounts = accounts.filter( + (account) => + account.metadata.snap && account.metadata.snap.id === this.snapId, + ); + const snapAccounts = new Set( + (await this.#client.listAccounts()).map((account) => account.id), ); - // We don't recover from this case yet. - return; - } + // NOTE: This should never happen, but we want to report that kind of errors still + // in case states are de-sync. + if (localSnapAccounts.length < snapAccounts.size) { + this.messenger.captureException?.( + new Error( + `Snap "${this.snapId}" has de-synced accounts, Snap has more accounts than MetaMask!`, + ), + ); - // We want this part to be fast, so we only check for sizes, but we might need - // to make a real "diff" between the 2 states to not miss any de-sync. - if (localSnapAccounts.length > snapAccounts.size) { - // Accounts should never really be de-synced, so we want to log this to see how often this - // happens, cause that means that something else is buggy elsewhere... - this.messenger.captureException?.( - new Error( - `Snap "${this.snapId}" has de-synced accounts, we'll attempt to re-sync them...`, - ), - ); + // We don't recover from this case yet. + return; + } - // We always use the MetaMask list as the main reference here. - await Promise.all( - localSnapAccounts.map(async (account) => { - const { id: entropySource, groupIndex } = account.options.entropy; - - try { - if (!snapAccounts.has(account.id)) { - // We still need to remove the accounts from the Snap keyring since we're - // about to create the same account again, which will use a new ID, but will - // keep using the same address, and the Snap keyring does not allow this. - await this.#withSnapKeyring( - async ({ keyring }) => - await keyring.removeAccount(account.address), - ); + // We want this part to be fast, so we only check for sizes, but we might need + // to make a real "diff" between the 2 states to not miss any de-sync. + if (localSnapAccounts.length > snapAccounts.size) { + // Accounts should never really be de-synced, so we want to log this to see how often this + // happens, cause that means that something else is buggy elsewhere... + this.messenger.captureException?.( + new Error( + `Snap "${this.snapId}" has de-synced accounts, we'll attempt to re-sync them...`, + ), + ); + + // We always use the MetaMask list as the main reference here. + await Promise.all( + localSnapAccounts.map(async (account) => { + const { id: entropySource, groupIndex } = account.options.entropy; - // The Snap has no account in its state for this one, we re-create it. - await this.createAccounts({ - entropySource, - groupIndex, - }); + try { + if (!snapAccounts.has(account.id)) { + // We still need to remove the accounts from the Snap keyring since we're + // about to create the same account again, which will use a new ID, but will + // keep using the same address, and the Snap keyring does not allow this. + await keyring.removeAccount(account.address); + // The Snap has no account in its state for this one, we re-create it. + await this.createAccounts({ + entropySource, + groupIndex, + }); + } + } catch (error) { + const sentryError = createSentryError( + `Unable to re-sync account: ${groupIndex}`, + error as Error, + { + provider: this.getName(), + groupIndex, + }, + ); + this.messenger.captureException?.(sentryError); } - } catch (error) { - const sentryError = createSentryError( - `Unable to re-sync account: ${groupIndex}`, - error as Error, - { - provider: this.getName(), - groupIndex, - }, - ); - this.messenger.captureException?.(sentryError); - } - }), - ); - } + }), + ); + } + }); } async #withSnapKeyring( @@ -222,6 +241,20 @@ export abstract class SnapAccountProvider extends BaseBip44AccountProvider { ); } + protected async withSnap( + operation: (snap: { + client: KeyringClient; + keyring: RestrictedSnapKeyring; + }) => Promise, + ): Promise { + await this.ensureCanUseSnapPlatform(); + + return await operation({ + client: this.#client, + keyring: await this.#getRestrictedSnapKeyring(), + }); + } + abstract isAccountCompatible(account: Bip44Account): boolean; abstract createAccounts(options: { diff --git a/packages/multichain-account-service/src/providers/SolAccountProvider.test.ts b/packages/multichain-account-service/src/providers/SolAccountProvider.test.ts index 8d218513630..5234c8ac85d 100644 --- a/packages/multichain-account-service/src/providers/SolAccountProvider.test.ts +++ b/packages/multichain-account-service/src/providers/SolAccountProvider.test.ts @@ -5,6 +5,7 @@ import type { EthKeyring, InternalAccount, } from '@metamask/keyring-internal-api'; +import { SnapControllerState } from '@metamask/snaps-controllers'; import { AccountProviderWrapper } from './AccountProviderWrapper'; import { SnapAccountProviderConfig } from './SnapAccountProvider'; @@ -82,6 +83,12 @@ class MockSolanaKeyring { }); } +class MockSolAccountProvider extends SolAccountProvider { + override async ensureCanUseSnapPlatform(): Promise { + // Override to avoid waiting during tests. + } +} + /** * Sets up a SolAccountProvider for testing. * @@ -113,6 +120,11 @@ function setup({ } { const keyring = new MockSolanaKeyring(accounts); + messenger.registerActionHandler( + 'SnapController:getState', + () => ({ isReady: true }) as SnapControllerState, + ); + messenger.registerActionHandler( 'AccountsController:listMultichainAccounts', () => accounts, @@ -151,7 +163,7 @@ function setup({ const multichainMessenger = getMultichainAccountServiceMessenger(messenger); const provider = new AccountProviderWrapper( multichainMessenger, - new SolAccountProvider(multichainMessenger, config, mockTrace), + new MockSolAccountProvider(multichainMessenger, config, mockTrace), ); return { @@ -358,7 +370,7 @@ describe('SolAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const solProvider = new SolAccountProvider( + const solProvider = new MockSolAccountProvider( multichainMessenger, undefined, mocks.trace, @@ -401,7 +413,7 @@ describe('SolAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const solProvider = new SolAccountProvider( + const solProvider = new MockSolAccountProvider( multichainMessenger, undefined, mocks.trace, @@ -430,7 +442,7 @@ describe('SolAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const solProvider = new SolAccountProvider( + const solProvider = new MockSolAccountProvider( multichainMessenger, undefined, mocks.trace, diff --git a/packages/multichain-account-service/src/providers/SolAccountProvider.ts b/packages/multichain-account-service/src/providers/SolAccountProvider.ts index 0201cee6b86..8d3a2032056 100644 --- a/packages/multichain-account-service/src/providers/SolAccountProvider.ts +++ b/packages/multichain-account-service/src/providers/SolAccountProvider.ts @@ -12,7 +12,10 @@ import type { InternalAccount } from '@metamask/keyring-internal-api'; import type { SnapId } from '@metamask/snaps-sdk'; import { SnapAccountProvider } from './SnapAccountProvider'; -import type { SnapAccountProviderConfig } from './SnapAccountProvider'; +import type { + RestrictedSnapKeyring, + SnapAccountProviderConfig, +} from './SnapAccountProvider'; import { withRetry, withTimeout } from './utils'; import { traceFallback } from '../analytics'; import { TraceName } from '../constants/traces'; @@ -60,17 +63,18 @@ export class SolAccountProvider extends SnapAccountProvider { } async #createAccount({ + keyring, entropySource, groupIndex, derivationPath, }: { + keyring: RestrictedSnapKeyring; entropySource: EntropySourceId; groupIndex: number; derivationPath: string; }): Promise> { - const createAccount = await this.getRestrictedSnapAccountCreator(); const account = await withTimeout( - createAccount({ entropySource, derivationPath }), + keyring.createAccount({ entropySource, derivationPath }), this.config.createAccounts.timeoutMs, ); @@ -93,15 +97,18 @@ export class SolAccountProvider extends SnapAccountProvider { entropySource: EntropySourceId; groupIndex: number; }): Promise[]> { - return this.withMaxConcurrency(async () => { - const derivationPath = `m/44'/501'/${groupIndex}'/0'`; - const account = await this.#createAccount({ - entropySource, - groupIndex, - derivationPath, + return this.withSnap(async ({ keyring }) => { + return this.withMaxConcurrency(async () => { + const derivationPath = `m/44'/501'/${groupIndex}'/0'`; + const account = await this.#createAccount({ + keyring, + entropySource, + groupIndex, + derivationPath, + }); + + return [account]; }); - - return [account]; }); } @@ -112,50 +119,53 @@ export class SolAccountProvider extends SnapAccountProvider { entropySource: EntropySourceId; groupIndex: number; }): Promise[]> { - return await super.trace( - { - name: TraceName.SnapDiscoverAccounts, - data: { - provider: this.getName(), + return this.withSnap(async ({ client, keyring }) => { + return await super.trace( + { + name: TraceName.SnapDiscoverAccounts, + data: { + provider: this.getName(), + }, }, - }, - async () => { - if (!this.config.discovery.enabled) { - return []; - } - - const discoveredAccounts = await withRetry( - () => - withTimeout( - this.client.discoverAccounts( - [SolScope.Mainnet], + async () => { + if (!this.config.discovery.enabled) { + return []; + } + + const discoveredAccounts = await withRetry( + () => + withTimeout( + client.discoverAccounts( + [SolScope.Mainnet], + entropySource, + groupIndex, + ), + this.config.discovery.timeoutMs, + ), + { + maxAttempts: this.config.discovery.maxAttempts, + backOffMs: this.config.discovery.backOffMs, + }, + ); + + if (!discoveredAccounts.length) { + return []; + } + + const createdAccounts = await Promise.all( + discoveredAccounts.map((d) => + this.#createAccount({ + keyring, entropySource, groupIndex, - ), - this.config.discovery.timeoutMs, + derivationPath: d.derivationPath, + }), ), - { - maxAttempts: this.config.discovery.maxAttempts, - backOffMs: this.config.discovery.backOffMs, - }, - ); - - if (!discoveredAccounts.length) { - return []; - } - - const createdAccounts = await Promise.all( - discoveredAccounts.map((d) => - this.#createAccount({ - entropySource, - groupIndex, - derivationPath: d.derivationPath, - }), - ), - ); - - return createdAccounts; - }, - ); + ); + + return createdAccounts; + }, + ); + }); } } diff --git a/packages/multichain-account-service/src/providers/TrxAccountProvider.test.ts b/packages/multichain-account-service/src/providers/TrxAccountProvider.test.ts index 51406f0917a..4af11bdeabf 100644 --- a/packages/multichain-account-service/src/providers/TrxAccountProvider.test.ts +++ b/packages/multichain-account-service/src/providers/TrxAccountProvider.test.ts @@ -5,6 +5,7 @@ import type { EthKeyring, InternalAccount, } from '@metamask/keyring-internal-api'; +import { SnapControllerState } from '@metamask/snaps-controllers'; import { AccountProviderWrapper } from './AccountProviderWrapper'; import { SnapAccountProviderConfig } from './SnapAccountProvider'; @@ -43,7 +44,7 @@ class MockTronKeyring { .fn() .mockImplementation((_, { index }) => { // Use the provided index or fallback to accounts length - const groupIndex = index !== undefined ? index : this.accounts.length; + const groupIndex = index ?? this.accounts.length; // Check if an account already exists for this group index (idempotent behavior) const found = this.accounts.find( @@ -70,6 +71,11 @@ class MockTronKeyring { // Add discoverAccounts method to match the provider's usage discoverAccounts = jest.fn().mockResolvedValue([]); } +class MockTrxAccountProvider extends TrxAccountProvider { + override async ensureCanUseSnapPlatform(): Promise { + // Override to avoid waiting during tests. + } +} /** * Sets up a TrxAccountProvider for testing. @@ -102,6 +108,11 @@ function setup({ } { const keyring = new MockTronKeyring(accounts); + messenger.registerActionHandler( + 'SnapController:getState', + () => ({ isReady: true }) as SnapControllerState, + ); + messenger.registerActionHandler( 'AccountsController:listMultichainAccounts', () => accounts, @@ -137,7 +148,7 @@ function setup({ const multichainMessenger = getMultichainAccountServiceMessenger(messenger); const provider = new AccountProviderWrapper( multichainMessenger, - new TrxAccountProvider(multichainMessenger, config), + new MockTrxAccountProvider(multichainMessenger, config), ); return { @@ -359,7 +370,7 @@ describe('TrxAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const trxProvider = new TrxAccountProvider( + const trxProvider = new MockTrxAccountProvider( multichainMessenger, undefined, mockTrace, @@ -413,7 +424,7 @@ describe('TrxAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const trxProvider = new TrxAccountProvider( + const trxProvider = new MockTrxAccountProvider( multichainMessenger, undefined, mockTrace, @@ -446,7 +457,7 @@ describe('TrxAccountProvider', () => { const multichainMessenger = getMultichainAccountServiceMessenger(messenger); - const trxProvider = new TrxAccountProvider( + const trxProvider = new MockTrxAccountProvider( multichainMessenger, undefined, mockTrace, diff --git a/packages/multichain-account-service/src/providers/TrxAccountProvider.ts b/packages/multichain-account-service/src/providers/TrxAccountProvider.ts index 6541a3aeb72..382b4c04a97 100644 --- a/packages/multichain-account-service/src/providers/TrxAccountProvider.ts +++ b/packages/multichain-account-service/src/providers/TrxAccountProvider.ts @@ -8,7 +8,10 @@ import type { InternalAccount } from '@metamask/keyring-internal-api'; import type { SnapId } from '@metamask/snaps-sdk'; import { SnapAccountProvider } from './SnapAccountProvider'; -import type { SnapAccountProviderConfig } from './SnapAccountProvider'; +import type { + RestrictedSnapKeyring, + SnapAccountProviderConfig, +} from './SnapAccountProvider'; import { withRetry, withTimeout } from './utils'; import { traceFallback } from '../analytics'; import { TraceName } from '../constants/traces'; @@ -55,18 +58,18 @@ export class TrxAccountProvider extends SnapAccountProvider { ); } - async createAccounts({ + async #createAccounts({ + keyring, entropySource, groupIndex: index, }: { + keyring: RestrictedSnapKeyring; entropySource: EntropySourceId; groupIndex: number; }): Promise[]> { return this.withMaxConcurrency(async () => { - const createAccount = await this.getRestrictedSnapAccountCreator(); - const account = await withTimeout( - createAccount({ + keyring.createAccount({ entropySource, index, addressType: TrxAccountType.Eoa, @@ -80,6 +83,22 @@ export class TrxAccountProvider extends SnapAccountProvider { }); } + async createAccounts({ + entropySource, + groupIndex: index, + }: { + entropySource: EntropySourceId; + groupIndex: number; + }): Promise[]> { + return this.withSnap(async ({ keyring }) => { + return this.#createAccounts({ + keyring, + entropySource, + groupIndex: index, + }); + }); + } + async discoverAccounts({ entropySource, groupIndex, @@ -87,45 +106,48 @@ export class TrxAccountProvider extends SnapAccountProvider { entropySource: EntropySourceId; groupIndex: number; }): Promise[]> { - return await super.trace( - { - name: TraceName.SnapDiscoverAccounts, - data: { - provider: this.getName(), + return this.withSnap(async ({ client, keyring }) => { + return await super.trace( + { + name: TraceName.SnapDiscoverAccounts, + data: { + provider: this.getName(), + }, }, - }, - async () => { - if (!this.config.discovery.enabled) { - return []; - } - - const discoveredAccounts = await withRetry( - () => - withTimeout( - this.client.discoverAccounts( - [TrxScope.Mainnet], - entropySource, - groupIndex, + async () => { + if (!this.config.discovery.enabled) { + return []; + } + + const discoveredAccounts = await withRetry( + () => + withTimeout( + client.discoverAccounts( + [TrxScope.Mainnet], + entropySource, + groupIndex, + ), + this.config.discovery.timeoutMs, ), - this.config.discovery.timeoutMs, - ), - { - maxAttempts: this.config.discovery.maxAttempts, - backOffMs: this.config.discovery.backOffMs, - }, - ); - - if (!discoveredAccounts.length) { - return []; - } - - const createdAccounts = await this.createAccounts({ - entropySource, - groupIndex, - }); - - return createdAccounts; - }, - ); + { + maxAttempts: this.config.discovery.maxAttempts, + backOffMs: this.config.discovery.backOffMs, + }, + ); + + if (!discoveredAccounts.length) { + return []; + } + + const createdAccounts = await this.#createAccounts({ + keyring, + entropySource, + groupIndex, + }); + + return createdAccounts; + }, + ); + }); } } diff --git a/packages/multichain-account-service/src/snaps/SnapPlatformWatcher.test.ts b/packages/multichain-account-service/src/snaps/SnapPlatformWatcher.test.ts new file mode 100644 index 00000000000..39f4030e541 --- /dev/null +++ b/packages/multichain-account-service/src/snaps/SnapPlatformWatcher.test.ts @@ -0,0 +1,208 @@ +/* eslint-disable no-void */ +import { SnapControllerState } from '@metamask/snaps-controllers'; + +import { SnapPlatformWatcher } from './SnapPlatformWatcher'; +import { + getMultichainAccountServiceMessenger, + getRootMessenger, +} from '../tests'; +import type { RootMessenger } from '../tests'; +import { MultichainAccountServiceMessenger } from '../types'; + +function setup( + { + rootMessenger, + }: { + rootMessenger: RootMessenger; + } = { + rootMessenger: getRootMessenger(), + }, +): { + rootMessenger: RootMessenger; + messenger: MultichainAccountServiceMessenger; + mocks: { + // eslint-disable-next-line @typescript-eslint/naming-convention + SnapController: { + getState: jest.Mock; + }; + }; + watcher: SnapPlatformWatcher; +} { + const mocks = { + SnapController: { + getState: jest.fn(), + }, + }; + + rootMessenger.registerActionHandler( + 'SnapController:getState', + mocks.SnapController.getState, + ); + mocks.SnapController.getState.mockReturnValue({ isReady: false }); + + const messenger = getMultichainAccountServiceMessenger(rootMessenger); + + const watcher = new SnapPlatformWatcher(messenger); + + return { rootMessenger, messenger, watcher, mocks }; +} + +function publishIsReadyState(messenger: RootMessenger, isReady: boolean): void { + messenger.publish( + 'SnapController:stateChange', + { isReady } as SnapControllerState, + [], + ); +} + +describe('SnapPlatformWatcher', () => { + describe('constructor', () => { + it('initializes with isReady as false', () => { + const { messenger } = setup(); + const watcher = new SnapPlatformWatcher(messenger); + + expect(watcher).toBeDefined(); + expect(watcher.isReady).toBe(false); + }); + }); + + describe('ensureCanUsePlatform', () => { + it('waits for platform to be ready at least once before resolving', async () => { + const { rootMessenger, messenger } = setup(); + const watcher = new SnapPlatformWatcher(messenger); + + // Start the promise but don't await immediately. + const ensurePromise = watcher.ensureCanUseSnapPlatform(); + + // Should not resolve yet since platform is not ready. + let resolved = false; + void ensurePromise.then(() => { + resolved = true; + return null; + }); + + expect(resolved).toBe(false); + + // Publish state change with isReady: true. + publishIsReadyState(rootMessenger, true); + + await ensurePromise; + expect(resolved).toBe(true); + }); + + it('throws error if platform becomes unavailable after being ready once', async () => { + const { rootMessenger, messenger } = setup(); + const watcher = new SnapPlatformWatcher(messenger); + + // Make platform ready first. + publishIsReadyState(rootMessenger, true); + + // Make platform unavailable + publishIsReadyState(rootMessenger, false); + + // Should throw error since platform is not ready now. + await expect(watcher.ensureCanUseSnapPlatform()).rejects.toThrow( + 'Snap platform cannot be used now.', + ); + }); + + it('handles multiple state changes correctly', async () => { + const { rootMessenger, messenger } = setup(); + const watcher = new SnapPlatformWatcher(messenger); + + // Make platform ready + publishIsReadyState(rootMessenger, true); + + // Should work + expect(await watcher.ensureCanUseSnapPlatform()).toBeUndefined(); + + // Make platform unavailable. + publishIsReadyState(rootMessenger, false); + + // Should fail. + await expect(watcher.ensureCanUseSnapPlatform()).rejects.toThrow( + 'Snap platform cannot be used now.', + ); + + // Make platform ready again. + publishIsReadyState(rootMessenger, true); + + // Should work again. + expect(await watcher.ensureCanUseSnapPlatform()).toBeUndefined(); + }); + + it('handles concurrent calls correctly', async () => { + const { rootMessenger, messenger } = setup(); + const watcher = new SnapPlatformWatcher(messenger); + + // Start multiple concurrent calls. + const promise1 = watcher.ensureCanUseSnapPlatform(); + const promise2 = watcher.ensureCanUseSnapPlatform(); + const promise3 = watcher.ensureCanUseSnapPlatform(); + + // Make platform ready. + publishIsReadyState(rootMessenger, true); + + // All promises should resolve. + expect(await Promise.all([promise1, promise2, promise3])).toStrictEqual([ + undefined, + undefined, + undefined, + ]); + }); + + it('resolves deferred promise only once when platform becomes ready', async () => { + const { rootMessenger, messenger } = setup(); + const watcher = new SnapPlatformWatcher(messenger); + const resolveSpy = jest.fn(); + + // Access the private deferred promise through ensureCanUsePlatform. + const ensurePromise = watcher.ensureCanUseSnapPlatform(); + void ensurePromise.then(resolveSpy); + + // Make platform ready multiple times. + publishIsReadyState(rootMessenger, true); + publishIsReadyState(rootMessenger, true); + + // Should only resolve once. + await ensurePromise; + expect(resolveSpy).toHaveBeenCalledTimes(1); + }); + + it('ignores state changes with isReady: false before first ready state', async () => { + const { rootMessenger, messenger } = setup(); + const watcher = new SnapPlatformWatcher(messenger); + + // Start the promise + const ensurePromise = watcher.ensureCanUseSnapPlatform(); + let resolved = false; + void ensurePromise.then(() => { + resolved = true; + return null; + }); + + // Publish false state (should be ignored since we haven't been ready yet). + publishIsReadyState(rootMessenger, false); + expect(resolved).toBe(false); + expect(watcher.isReady).toBe(false); + + // Now make it ready.. + publishIsReadyState(rootMessenger, true); + await ensurePromise; + expect(resolved).toBe(true); + }); + + it('resolves immediately if platform is already ready', async () => { + const { messenger, mocks } = setup(); + + // Make the platform ready before creating the watcher. + mocks.SnapController.getState.mockReturnValue({ + isReady: true, + } as SnapControllerState); + + const watcher = new SnapPlatformWatcher(messenger); + + expect(watcher.isReady).toBe(true); + }); + }); +}); diff --git a/packages/multichain-account-service/src/snaps/SnapPlatformWatcher.ts b/packages/multichain-account-service/src/snaps/SnapPlatformWatcher.ts new file mode 100644 index 00000000000..8e0f2304775 --- /dev/null +++ b/packages/multichain-account-service/src/snaps/SnapPlatformWatcher.ts @@ -0,0 +1,61 @@ +import { createDeferredPromise, DeferredPromise } from '@metamask/utils'; +import { once } from 'lodash'; + +import { projectLogger as log } from '../logger'; +import { MultichainAccountServiceMessenger } from '../types'; + +export class SnapPlatformWatcher { + readonly #messenger: MultichainAccountServiceMessenger; + + readonly #isReadyOnce: DeferredPromise; + + #isReady: boolean; + + constructor(messenger: MultichainAccountServiceMessenger) { + this.#messenger = messenger; + + this.#isReady = false; + this.#isReadyOnce = createDeferredPromise(); + + this.#watch(); + } + + get isReady(): boolean { + return this.#isReady; + } + + async ensureCanUseSnapPlatform(): Promise { + // We always wait for the Snap platform to be ready at least once. + await this.#isReadyOnce.promise; + + // Then, we check for the current state and see if we can use it. + if (!this.#isReady) { + throw new Error('Snap platform cannot be used now.'); + } + } + + #watch(): void { + const logReadyOnce = once(() => log('Snap platform is ready!')); + + // If already ready, resolve immediately. + const initialState = this.#messenger.call('SnapController:getState'); + if (initialState.isReady) { + this.#isReady = true; + this.#isReadyOnce.resolve(); + } + + // We still subscribe to state changes to keep track of the platform's readiness. + this.#messenger.subscribe( + 'SnapController:stateChange', + (isReady: boolean) => { + this.#isReady = isReady; + + if (isReady) { + logReadyOnce(); + this.#isReadyOnce.resolve(); + } + }, + (state) => state.isReady, + ); + } +} diff --git a/packages/multichain-account-service/src/tests/messenger.ts b/packages/multichain-account-service/src/tests/messenger.ts index 64fb6f2abf5..f7de1afff77 100644 --- a/packages/multichain-account-service/src/tests/messenger.ts +++ b/packages/multichain-account-service/src/tests/messenger.ts @@ -7,10 +7,10 @@ import type { import type { MultichainAccountServiceMessenger } from '../types'; -type AllMultichainAccountServiceActions = +export type AllMultichainAccountServiceActions = MessengerActions; -type AllMultichainAccountServiceEvents = +export type AllMultichainAccountServiceEvents = MessengerEvents; export type RootMessenger = Messenger< @@ -35,10 +35,17 @@ export function getRootMessenger(): RootMessenger { * Retrieves a restricted messenger for the MultichainAccountService. * * @param rootMessenger - The root messenger instance. Defaults to a new Messenger created by getRootMessenger(). + * @param extra - Extra messenger options. + * @param extra.actions - Extra actions to delegate. + * @param extra.events - Extra events to delegate. * @returns The restricted messenger for the MultichainAccountService. */ export function getMultichainAccountServiceMessenger( rootMessenger: RootMessenger, + extra?: { + actions?: AllMultichainAccountServiceActions['type'][]; + events?: AllMultichainAccountServiceEvents['type'][]; + }, ): MultichainAccountServiceMessenger { const messenger = new Messenger< 'MultichainAccountService', @@ -55,6 +62,7 @@ export function getMultichainAccountServiceMessenger( 'AccountsController:getAccount', 'AccountsController:getAccountByAddress', 'AccountsController:listMultichainAccounts', + 'SnapController:getState', 'SnapController:handleRequest', 'KeyringController:withKeyring', 'KeyringController:getState', @@ -62,11 +70,14 @@ export function getMultichainAccountServiceMessenger( 'KeyringController:addNewKeyring', 'NetworkController:findNetworkClientIdByChainId', 'NetworkController:getNetworkClientById', + ...(extra?.actions ?? []), ], events: [ 'KeyringController:stateChange', + 'SnapController:stateChange', 'AccountsController:accountAdded', 'AccountsController:accountRemoved', + ...(extra?.events ?? []), ], }); return messenger; diff --git a/packages/multichain-account-service/src/types.ts b/packages/multichain-account-service/src/types.ts index 72c97cb841e..a392d70a0b0 100644 --- a/packages/multichain-account-service/src/types.ts +++ b/packages/multichain-account-service/src/types.ts @@ -25,7 +25,11 @@ import type { NetworkControllerFindNetworkClientIdByChainIdAction, NetworkControllerGetNetworkClientByIdAction, } from '@metamask/network-controller'; -import type { HandleSnapRequest as SnapControllerHandleSnapRequestAction } from '@metamask/snaps-controllers'; +import type { + HandleSnapRequest as SnapControllerHandleSnapRequestAction, + SnapControllerGetStateAction, + SnapStateChange as SnapControllerStateChangeEvent, +} from '@metamask/snaps-controllers'; import type { MultichainAccountService, @@ -87,6 +91,11 @@ export type MultichainAccountServiceResyncAccountsAction = { handler: MultichainAccountService['resyncAccounts']; }; +export type MultichainAccountServiceEnsureCanUseSnapPlatformAction = { + type: `${typeof serviceName}:ensureCanUseSnapPlatform`; + handler: MultichainAccountService['ensureCanUseSnapPlatform']; +}; + /** * All actions that {@link MultichainAccountService} registers so that other * modules can call them. @@ -102,7 +111,8 @@ export type MultichainAccountServiceActions = | MultichainAccountServiceAlignWalletAction | MultichainAccountServiceAlignWalletsAction | MultichainAccountServiceCreateMultichainAccountWalletAction - | MultichainAccountServiceResyncAccountsAction; + | MultichainAccountServiceResyncAccountsAction + | MultichainAccountServiceEnsureCanUseSnapPlatformAction; export type MultichainAccountServiceMultichainAccountGroupCreatedEvent = { type: `${typeof serviceName}:multichainAccountGroupCreated`; @@ -136,6 +146,7 @@ type AllowedActions = | AccountsControllerListMultichainAccountsAction | AccountsControllerGetAccountAction | AccountsControllerGetAccountByAddressAction + | SnapControllerGetStateAction | SnapControllerHandleSnapRequestAction | KeyringControllerWithKeyringAction | KeyringControllerGetStateAction @@ -149,6 +160,7 @@ type AllowedActions = * subscribes to. */ type AllowedEvents = + | SnapControllerStateChangeEvent | KeyringControllerStateChangeEvent | AccountsControllerAccountAddedEvent | AccountsControllerAccountRemovedEvent; diff --git a/yarn.lock b/yarn.lock index 0d4fc4e3f44..24e14829d70 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4036,6 +4036,7 @@ __metadata: async-mutex: "npm:^0.5.0" deepmerge: "npm:^4.2.2" jest: "npm:^27.5.1" + lodash: "npm:^4.17.21" ts-jest: "npm:^27.1.4" typedoc: "npm:^0.24.8" typedoc-plugin-missing-exports: "npm:^2.0.0"