diff --git a/packages/accounts-controller/src/AccountsController.test.ts b/packages/accounts-controller/src/AccountsController.test.ts index 930a569cd4a..2be244ac888 100644 --- a/packages/accounts-controller/src/AccountsController.test.ts +++ b/packages/accounts-controller/src/AccountsController.test.ts @@ -2023,33 +2023,6 @@ describe('AccountsController', () => { `Account Id "${accountId}" not found`, ); }); - - it('handle the edge case of undefined accountId during onboarding', async () => { - const { accountsController } = setupAccountsController({ - initialState: { - internalAccounts: { - accounts: { [mockAccount.id]: mockAccount }, - selectedAccount: mockAccount.id, - }, - }, - }); - - // @ts-expect-error forcing undefined accountId - expect(accountsController.getAccountExpect(undefined)).toStrictEqual({ - id: '', - address: '', - options: {}, - methods: [], - type: EthAccountType.Eoa, - metadata: { - name: '', - keyring: { - type: '', - }, - importTime: 0, - }, - }); - }); }); describe('setSelectedAccount', () => { @@ -2113,6 +2086,32 @@ describe('AccountsController', () => { mockNonEvmAccount, ); }); + + it('handle the edge case of undefined accountId during onboarding', async () => { + const { accountsController } = setupAccountsController({ + initialState: { + internalAccounts: { + accounts: {}, + selectedAccount: '', + }, + }, + }); + + expect(accountsController.getSelectedAccount()).toStrictEqual({ + id: '', + address: '', + options: {}, + methods: [], + type: EthAccountType.Eoa, + metadata: { + name: '', + keyring: { + type: '', + }, + importTime: 0, + }, + }); + }); }); describe('setAccountName', () => { diff --git a/packages/accounts-controller/src/AccountsController.ts b/packages/accounts-controller/src/AccountsController.ts index 820ab6447f7..113e0bc1911 100644 --- a/packages/accounts-controller/src/AccountsController.ts +++ b/packages/accounts-controller/src/AccountsController.ts @@ -265,9 +265,22 @@ export class AccountsController extends BaseController< * @throws An error if the account ID is not found. */ getAccountExpect(accountId: string): InternalAccount { + const account = this.getAccount(accountId); + if (account === undefined) { + throw new Error(`Account Id "${accountId}" not found`); + } + return account; + } + + /** + * Returns the last selected evm account. + * + * @returns The selected internal account. + */ + getSelectedAccount(): InternalAccount { // Edge case where the extension is setup but the srp is not yet created // certain ui elements will query the selected address before any accounts are created. - if (!accountId) { + if (this.state.internalAccounts.selectedAccount === '') { return { id: '', address: '', @@ -284,19 +297,6 @@ export class AccountsController extends BaseController< }; } - const account = this.getAccount(accountId); - if (account === undefined) { - throw new Error(`Account Id "${accountId}" not found`); - } - return account; - } - - /** - * Returns the last selected evm account. - * - * @returns The selected internal account. - */ - getSelectedAccount(): InternalAccount { const selectedAccount = this.getAccountExpect( this.state.internalAccounts.selectedAccount, ); diff --git a/packages/assets-controllers/package.json b/packages/assets-controllers/package.json index 4536f6db481..509ecd5d4ad 100644 --- a/packages/assets-controllers/package.json +++ b/packages/assets-controllers/package.json @@ -53,12 +53,15 @@ "@metamask/contract-metadata": "^2.4.0", "@metamask/controller-utils": "^11.0.0", "@metamask/eth-query": "^4.0.0", + "@metamask/keyring-api": "^6.4.0", "@metamask/keyring-controller": "^17.0.0", "@metamask/metamask-eth-abis": "^3.1.1", "@metamask/network-controller": "^19.0.0", "@metamask/polling-controller": "^8.0.0", "@metamask/preferences-controller": "^13.0.0", "@metamask/rpc-errors": "^6.2.1", + "@metamask/snaps-sdk": "^4.2.0", + "@metamask/snaps-utils": "^7.4.0", "@metamask/utils": "^8.3.0", "@types/bn.js": "^5.1.5", "@types/uuid": "^8.3.0", @@ -73,11 +76,11 @@ "devDependencies": { "@metamask/auto-changelog": "^3.4.4", "@metamask/ethjs-provider-http": "^0.3.0", - "@metamask/keyring-api": "^6.4.0", "@types/jest": "^27.4.1", "@types/lodash": "^4.14.191", "@types/node": "^16.18.54", "deepmerge": "^4.2.2", + "immer": "^9.0.6", "jest": "^27.5.1", "jest-environment-jsdom": "^27.5.1", "nock": "^13.3.1", @@ -92,7 +95,8 @@ "@metamask/approval-controller": "^7.0.0", "@metamask/keyring-controller": "^17.0.0", "@metamask/network-controller": "^19.0.0", - "@metamask/preferences-controller": "^13.0.0" + "@metamask/preferences-controller": "^13.0.0", + "@metamask/snaps-controllers": "^8.1.1" }, "engines": { "node": "^18.18 || >=20" diff --git a/packages/assets-controllers/src/AccountTrackerController.test.ts b/packages/assets-controllers/src/AccountTrackerController.test.ts index 7f32c4ad550..93212d287fe 100644 --- a/packages/assets-controllers/src/AccountTrackerController.test.ts +++ b/packages/assets-controllers/src/AccountTrackerController.test.ts @@ -1,13 +1,10 @@ import { query } from '@metamask/controller-utils'; import HttpProvider from '@metamask/ethjs-provider-http'; -import { - getDefaultPreferencesState, - type Identity, - type PreferencesState, -} from '@metamask/preferences-controller'; +import type { InternalAccount } from '@metamask/keyring-api'; import * as sinon from 'sinon'; import { advanceTime } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import { AccountTrackerController } from './AccountTrackerController'; jest.mock('@metamask/controller-utils', () => { @@ -18,7 +15,9 @@ jest.mock('@metamask/controller-utils', () => { }); const ADDRESS_1 = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; +const ACCOUNT_1 = createMockInternalAccount({ address: ADDRESS_1 }); const ADDRESS_2 = '0x742d35Cc6634C0532925a3b844Bc454e4438f44e'; +const ACCOUNT_2 = createMockInternalAccount({ address: ADDRESS_2 }); const mockedQuery = query as jest.Mock< ReturnType, @@ -44,9 +43,9 @@ describe('AccountTrackerController', () => { it('should set default state', () => { const controller = new AccountTrackerController({ - onPreferencesStateChange: sinon.stub(), - getIdentities: () => ({}), - getSelectedAddress: () => '', + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -61,9 +60,11 @@ describe('AccountTrackerController', () => { it('should throw when provider property is accessed', () => { const controller = new AccountTrackerController({ - onPreferencesStateChange: sinon.stub(), - getIdentities: () => ({}), - getSelectedAddress: () => '', + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [], + getSelectedAccount: () => { + return {} as InternalAccount; + }, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -73,31 +74,31 @@ describe('AccountTrackerController', () => { ); }); - it('should refresh when preferences state changes', async () => { - const preferencesStateChangeListeners: (( - state: PreferencesState, + it('should refresh when selectedAccount changes', async () => { + const selectedAccountChangeListeners: (( + internalAccount: InternalAccount, ) => void)[] = []; const controller = new AccountTrackerController( { - onPreferencesStateChange: (listener) => { - preferencesStateChangeListeners.push(listener); + onSelectedAccountChange: (listener) => { + selectedAccountChangeListeners.push(listener); }, - getIdentities: () => ({}), - getSelectedAddress: () => '0x0', + getInternalAccounts: () => [], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), }, { provider }, ); - const triggerPreferencesStateChange = (state: PreferencesState) => { - for (const listener of preferencesStateChangeListeners) { - listener(state); + const triggerSelectedAccountChange = (internalAccount: InternalAccount) => { + for (const listener of selectedAccountChangeListeners) { + listener(internalAccount); } }; controller.refresh = sinon.stub(); - triggerPreferencesStateChange(getDefaultPreferencesState()); + triggerSelectedAccountChange(ACCOUNT_1); // TODO: Replace `any` with type // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -113,16 +114,13 @@ describe('AccountTrackerController', () => { describe('without networkClientId', () => { it('should sync addresses', async () => { + const bazAccount = createMockInternalAccount({ address: 'baz' }); + const barAccount = createMockInternalAccount({ address: 'bar' }); const controller = new AccountTrackerController( { - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { - bar: {} as Identity, - baz: {} as Identity, - }; - }, - getSelectedAddress: () => '0x0', + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [bazAccount, barAccount], + getSelectedAccount: () => barAccount, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -169,11 +167,9 @@ describe('AccountTrackerController', () => { const controller = new AccountTrackerController( { - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { [ADDRESS_1]: {} as Identity }; - }, - getSelectedAddress: () => ADDRESS_1, + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [ACCOUNT_1], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -206,14 +202,9 @@ describe('AccountTrackerController', () => { const controller = new AccountTrackerController( { - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { - [ADDRESS_1]: {} as Identity, - [ADDRESS_2]: {} as Identity, - }; - }, - getSelectedAddress: () => ADDRESS_1, + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [ACCOUNT_1, ACCOUNT_2], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => false, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -244,14 +235,9 @@ describe('AccountTrackerController', () => { const controller = new AccountTrackerController( { - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { - [ADDRESS_1]: {} as Identity, - [ADDRESS_2]: {} as Identity, - }; - }, - getSelectedAddress: () => ADDRESS_1, + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [ACCOUNT_1, ACCOUNT_2], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -278,16 +264,13 @@ describe('AccountTrackerController', () => { describe('with networkClientId', () => { it('should sync addresses', async () => { + const bazAccount = createMockInternalAccount({ address: 'baz' }); + const barAccount = createMockInternalAccount({ address: 'bar' }); const controller = new AccountTrackerController( { - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { - bar: {} as Identity, - baz: {} as Identity, - }; - }, - getSelectedAddress: () => '0x0', + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [bazAccount, barAccount], + getSelectedAccount: () => bazAccount, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn().mockReturnValue({ @@ -342,11 +325,9 @@ describe('AccountTrackerController', () => { mockedQuery.mockReturnValueOnce(Promise.resolve('0x10')); const controller = new AccountTrackerController({ - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { [ADDRESS_1]: {} as Identity }; - }, - getSelectedAddress: () => ADDRESS_1, + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [ACCOUNT_1], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn().mockReturnValue({ @@ -386,14 +367,11 @@ describe('AccountTrackerController', () => { .mockReturnValueOnce(Promise.resolve('0x11')); const controller = new AccountTrackerController({ - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { - [ADDRESS_1]: {} as Identity, - [ADDRESS_2]: {} as Identity, - }; + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => { + return [ACCOUNT_1, ACCOUNT_2]; }, - getSelectedAddress: () => ADDRESS_1, + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => false, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn().mockReturnValue({ @@ -430,14 +408,11 @@ describe('AccountTrackerController', () => { .mockReturnValueOnce(Promise.resolve('0x12')); const controller = new AccountTrackerController({ - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return { - [ADDRESS_1]: {} as Identity, - [ADDRESS_2]: {} as Identity, - }; + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => { + return [ACCOUNT_1, ACCOUNT_2]; }, - getSelectedAddress: () => ADDRESS_1, + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn().mockReturnValue({ @@ -474,11 +449,9 @@ describe('AccountTrackerController', () => { it('should sync balance with addresses', async () => { const controller = new AccountTrackerController( { - onPreferencesStateChange: sinon.stub(), - getIdentities: () => { - return {}; - }, - getSelectedAddress: () => ADDRESS_1, + onSelectedAccountChange: sinon.stub(), + getInternalAccounts: () => [], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -501,9 +474,9 @@ describe('AccountTrackerController', () => { const poll = sinon.spy(AccountTrackerController.prototype, 'poll'); const controller = new AccountTrackerController( { - onPreferencesStateChange: jest.fn(), - getIdentities: () => ({}), - getSelectedAddress: () => '', + onSelectedAccountChange: jest.fn(), + getInternalAccounts: () => [], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), @@ -523,9 +496,9 @@ describe('AccountTrackerController', () => { sinon.stub(AccountTrackerController.prototype, 'poll'); const controller = new AccountTrackerController( { - onPreferencesStateChange: jest.fn(), - getIdentities: () => ({}), - getSelectedAddress: () => '', + onSelectedAccountChange: jest.fn(), + getInternalAccounts: () => [], + getSelectedAccount: () => ACCOUNT_1, getMultiAccountBalancesEnabled: () => true, getCurrentChainId: () => '0x1', getNetworkClientById: jest.fn(), diff --git a/packages/assets-controllers/src/AccountTrackerController.ts b/packages/assets-controllers/src/AccountTrackerController.ts index 3020b41c7fa..881ecc775a8 100644 --- a/packages/assets-controllers/src/AccountTrackerController.ts +++ b/packages/assets-controllers/src/AccountTrackerController.ts @@ -1,7 +1,9 @@ +import type { AccountsController } from '@metamask/accounts-controller'; import type { BaseConfig, BaseState } from '@metamask/base-controller'; import { query, safelyExecuteWithTimeout } from '@metamask/controller-utils'; import EthQuery from '@metamask/eth-query'; import type { Provider } from '@metamask/eth-query'; +import { type InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkController, @@ -79,7 +81,11 @@ export class AccountTrackerController extends StaticIntervalPollingControllerV1< }); } - const addresses = Object.keys(this.getIdentities()); + const addresses = Object.values( + this.getInternalAccounts().map( + (internalAccount) => internalAccount.address, + ), + ); const newAddresses = addresses.filter( (address) => !existing.includes(address), ); @@ -114,9 +120,9 @@ export class AccountTrackerController extends StaticIntervalPollingControllerV1< */ override name = 'AccountTrackerController' as const; - private readonly getIdentities: () => PreferencesState['identities']; + private readonly getInternalAccounts: AccountsController['listAccounts']; - private readonly getSelectedAddress: () => PreferencesState['selectedAddress']; + private readonly getSelectedAccount: AccountsController['getSelectedAccount']; private readonly getMultiAccountBalancesEnabled: () => PreferencesState['isMultiAccountBalancesEnabled']; @@ -128,29 +134,29 @@ export class AccountTrackerController extends StaticIntervalPollingControllerV1< * Creates an AccountTracker instance. * * @param options - The controller options. - * @param options.onPreferencesStateChange - Allows subscribing to preference controller state changes. - * @param options.getIdentities - Gets the identities from the Preferences store. - * @param options.getSelectedAddress - Gets the selected address from the Preferences store. * @param options.getMultiAccountBalancesEnabled - Gets the multi account balances enabled flag from the Preferences store. * @param options.getCurrentChainId - Gets the chain ID for the current network from the Network store. * @param options.getNetworkClientById - Gets the network client with the given id from the NetworkController. + * @param options.onSelectedAccountChange - A function that subscribes to selected account changes. + * @param options.getInternalAccounts - A function that returns the internal accounts. + * @param options.getSelectedAccount - A function that returns the selected account. * @param config - Initial options used to configure this controller. * @param state - Initial state to set on this controller. */ constructor( { - onPreferencesStateChange, - getIdentities, - getSelectedAddress, + onSelectedAccountChange, + getInternalAccounts, + getSelectedAccount, getMultiAccountBalancesEnabled, getCurrentChainId, getNetworkClientById, }: { - onPreferencesStateChange: ( - listener: (preferencesState: PreferencesState) => void, + onSelectedAccountChange: ( + listener: (internalAccount: InternalAccount) => void, ) => void; - getIdentities: () => PreferencesState['identities']; - getSelectedAddress: () => PreferencesState['selectedAddress']; + getInternalAccounts: AccountsController['listAccounts']; + getSelectedAccount: AccountsController['getSelectedAccount']; getMultiAccountBalancesEnabled: () => PreferencesState['isMultiAccountBalancesEnabled']; getCurrentChainId: () => Hex; getNetworkClientById: NetworkController['getNetworkClientById']; @@ -170,12 +176,12 @@ export class AccountTrackerController extends StaticIntervalPollingControllerV1< }; this.initialize(); this.setIntervalLength(this.config.interval); - this.getIdentities = getIdentities; - this.getSelectedAddress = getSelectedAddress; this.getMultiAccountBalancesEnabled = getMultiAccountBalancesEnabled; this.getCurrentChainId = getCurrentChainId; this.getNetworkClientById = getNetworkClientById; - onPreferencesStateChange(() => { + this.getSelectedAccount = getSelectedAccount; + this.getInternalAccounts = getInternalAccounts; + onSelectedAccountChange(() => { this.refresh(); }); this.poll(); @@ -253,6 +259,7 @@ export class AccountTrackerController extends StaticIntervalPollingControllerV1< * @param networkClientId - Optional networkClientId to fetch a network client with */ refresh = async (networkClientId?: NetworkClientId) => { + const selectedAccount = this.getSelectedAccount(); const releaseLock = await this.refreshMutex.acquire(); try { const { chainId, ethQuery } = @@ -264,7 +271,7 @@ export class AccountTrackerController extends StaticIntervalPollingControllerV1< const accountsToUpdate = isMultiAccountBalancesEnabled ? Object.keys(accounts) - : [this.getSelectedAddress()]; + : [selectedAccount.address]; const accountsForChain = { ...accountsByChainId[chainId] }; for (const address of accountsToUpdate) { diff --git a/packages/assets-controllers/src/BalancesController.ts b/packages/assets-controllers/src/BalancesController.ts new file mode 100644 index 00000000000..78ebd096b2e --- /dev/null +++ b/packages/assets-controllers/src/BalancesController.ts @@ -0,0 +1,186 @@ +import { type AccountsControllerGetAccountAction } from '@metamask/accounts-controller'; +import { + BaseController, + type ControllerGetStateAction, + type ControllerStateChangeEvent, + type RestrictedControllerMessenger, +} from '@metamask/base-controller'; +import { + KeyringRpcMethod, + type Balance, + type CaipAssetType, +} from '@metamask/keyring-api'; +import type { HandleSnapRequest } from '@metamask/snaps-controllers'; +import type { SnapId } from '@metamask/snaps-sdk'; +import { HandlerType } from '@metamask/snaps-utils'; +import type { Draft } from 'immer'; +import { v4 as uuid } from 'uuid'; + +const controllerName = 'BalancesController'; + +/** + * State used by the {@link BalancesController} to cache account balances. + */ +export type BalancesControllerState = { + balances: { + [account: string]: { + [asset: string]: { + amount: string; + unit: string; + }; + }; + }; +}; + +/** + * Default state of the {@link BalancesController}. + */ +const defaultState: BalancesControllerState = { balances: {} }; + +/** + * Returns the state of the {@link BalancesController}. + */ +export type GetBalancesControllerState = ControllerGetStateAction< + typeof controllerName, + BalancesControllerState +>; + +/** + * Returns the balances of an account. + */ +export type GetBalances = { + type: `${typeof controllerName}:getBalances`; + handler: BalancesController['getBalances']; +}; + +/** + * Event emitted when the state of the {@link BalancesController} changes. + */ +export type BalancesControllerStateChange = ControllerStateChangeEvent< + typeof controllerName, + BalancesControllerState +>; + +/** + * Actions exposed by the {@link BalancesController}. + */ +export type BalancesControllerActions = + | GetBalancesControllerState + | GetBalances; + +/** + * Events emitted by {@link BalancesController}. + */ +export type BalancesControllerEvents = BalancesControllerStateChange; + +/** + * Actions that this controller is allowed to call. + */ +export type AllowedActions = + | HandleSnapRequest + | AccountsControllerGetAccountAction; + +/** + * Messenger type for the BalancesController. + */ +export type BalancesControllerMessenger = RestrictedControllerMessenger< + typeof controllerName, + BalancesControllerActions | AllowedActions, + BalancesControllerEvents, + AllowedActions['type'], + never +>; + +/** + * {@link BalancesController}'s metadata. + * + * This allows us to choose if fields of the state should be persisted or not + * using the `persist` flag; and if they can be sent to Sentry or not, using + * the `anonymous` flag. + */ +const balancesControllerMetadata = { + balances: { + persist: true, + anonymous: false, + }, +}; + +/** + * The BalancesController is responsible for fetching and caching account + * balances. + */ +export class BalancesController extends BaseController< + typeof controllerName, + BalancesControllerState, + BalancesControllerMessenger +> { + constructor({ + messenger, + state, + }: { + messenger: BalancesControllerMessenger; + state: BalancesControllerState; + }) { + super({ + messenger, + name: controllerName, + metadata: balancesControllerMetadata, + state: { + ...defaultState, + ...state, + }, + }); + } + + /** + * Get the balances for an account. + * + * @param accountId - ID of the account to get balances for. + * @param assetTypes - Array of asset types to get balances for. + * @returns A map of asset types to balances. + */ + async getBalances( + accountId: string, + assetTypes: CaipAssetType[], + ): Promise> { + console.log('!!! Getting balances for account', accountId); + console.log('!!! Assets:', assetTypes); + + const account = this.messagingSystem.call( + 'AccountsController:getAccount', + accountId, + ); + if (!account) { + return {}; + } + + const snapId = account.metadata.snap?.id; + if (!snapId) { + return {}; + } + + const balances = (await this.messagingSystem.call( + 'SnapController:handleRequest', + { + snapId: snapId as SnapId, + origin: 'metamask', + handler: HandlerType.OnRpcRequest, + request: { + jsonrpc: '2.0', + id: uuid(), + method: KeyringRpcMethod.GetAccountBalances, + params: { + id: account.id, + assets: assetTypes, + }, + }, + }, + )) as Record; + + this.update((state: Draft) => { + state.balances[accountId] = balances; + }); + + return balances; + } +} diff --git a/packages/assets-controllers/src/NftController.test.ts b/packages/assets-controllers/src/NftController.test.ts index 2c5004e1c3a..80853834c33 100644 --- a/packages/assets-controllers/src/NftController.test.ts +++ b/packages/assets-controllers/src/NftController.test.ts @@ -1,5 +1,14 @@ import type { Network } from '@ethersproject/providers'; -import type { ApprovalControllerMessenger } from '@metamask/approval-controller'; +import type { + AccountsControllerGetAccountAction, + AccountsControllerGetSelectedAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, +} from '@metamask/accounts-controller'; +import type { + AddApprovalRequest, + ApprovalStateChange, + ApprovalControllerMessenger, +} from '@metamask/approval-controller'; import { ApprovalController } from '@metamask/approval-controller'; import { ControllerMessenger } from '@metamask/base-controller'; import { @@ -15,11 +24,15 @@ import { NFT_API_BASE_URL, InfuraNetworkType, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientConfiguration, NetworkClientId, + NetworkControllerGetNetworkClientByIdAction, + NetworkControllerNetworkDidChangeEvent, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; +import type { PreferencesControllerStateChangeEvent } from '@metamask/preferences-controller'; import { getDefaultPreferencesState, type PreferencesState, @@ -29,6 +42,7 @@ import nock from 'nock'; import * as sinon from 'sinon'; import { v4 } from 'uuid'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import type { ExtractAvailableAction, ExtractAvailableEvent, @@ -62,6 +76,11 @@ const ERC721_DEPRESSIONIST_ADDRESS = '0x18E8E76aeB9E2d9FA2A2b88DD9CF3C8ED45c3660'; const ERC721_DEPRESSIONIST_ID = '36'; const OWNER_ADDRESS = '0x5a3CA5cD63807Ce5e4d7841AB32Ce6B6d9BbBa2D'; +const OWNER_ID = '54d1e7bc-1dce-4220-a15f-2f454bae7869'; +const OWNER_ACCOUNT = createMockInternalAccount({ + id: OWNER_ID, + address: OWNER_ADDRESS, +}); const SECOND_OWNER_ADDRESS = '0x500017171kasdfbou081'; const DEPRESSIONIST_CID_V1 = @@ -84,6 +103,17 @@ const GOERLI = { ticker: NetworksTicker.goerli, }; +type ApprovalActions = + | AddApprovalRequest + | AccountsControllerGetAccountAction + | AccountsControllerGetSelectedAccountAction + | NetworkControllerGetNetworkClientByIdAction; +type ApprovalEvents = + | ApprovalStateChange + | PreferencesControllerStateChangeEvent + | NetworkControllerNetworkDidChangeEvent + | AccountsControllerSelectedEvmAccountChangeEvent; + const controllerName = 'NftController' as const; // Mock out detectNetwork function for cleaner tests, Ethers calls this a bunch of times because the Web3Provider is paranoid. @@ -149,6 +179,20 @@ function setupController({ getNetworkClientById, ); + const getInternalAccountMock = jest.fn().mockReturnValue(OWNER_ACCOUNT); + + messenger.registerActionHandler( + 'AccountsController:getAccount', + getInternalAccountMock, + ); + + const getSelectedAccountMock = jest.fn().mockReturnValue(OWNER_ACCOUNT); + + messenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + getSelectedAccountMock, + ); + const approvalControllerMessenger = messenger.getRestricted({ name: 'ApprovalController', allowedActions: [], @@ -160,15 +204,27 @@ function setupController({ showApprovalRequest: jest.fn(), }); - const nftControllerMessenger = messenger.getRestricted({ + const nftControllerMessenger = messenger.getRestricted< + typeof controllerName, + ApprovalActions['type'], + Extract< + ApprovalEvents, + | PreferencesControllerStateChangeEvent + | AccountsControllerSelectedEvmAccountChangeEvent + | NetworkControllerNetworkDidChangeEvent + >['type'] + >({ name: controllerName, allowedActions: [ 'ApprovalController:addRequest', + 'AccountsController:getSelectedAccount', + 'AccountsController:getAccount', 'NetworkController:getNetworkClientById', ], allowedEvents: [ - 'NetworkController:networkDidChange', + 'AccountsController:selectedEvmAccountChange', 'PreferencesController:stateChange', + 'NetworkController:networkDidChange', ], }); @@ -203,15 +259,30 @@ function setupController({ triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); + const triggerSelectedAccountChange = ( + internalAccount: InternalAccount, + ): void => { + messenger.publish( + 'AccountsController:selectedEvmAccountChange', + internalAccount, + ); + }; + + if (!options.selectedAccountId) { + triggerSelectedAccountChange(OWNER_ACCOUNT); + } + return { nftController, messenger, approvalController, changeNetwork, triggerPreferencesStateChange, + triggerSelectedAccountChange, + getInternalAccountMock, + getSelectedAccountMock, }; } @@ -402,12 +473,17 @@ describe('NftController', () => { }, }); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest.spyOn(messenger, 'call'); await expect(() => nftController.watchNft(ERC721_NFT, ERC721, 'https://test-dapp.com'), ).rejects.toThrow('Suggested NFT is not owned by the selected account'); - expect(callActionSpy).toHaveBeenCalledTimes(0); + // First call is getInternalAccount. Second call is the approval request. + expect(callActionSpy).not.toHaveBeenNthCalledWith( + 2, + 'ApprovalController:addRequest', + expect.any(Object), + ); }); it('should error if the call to isNftOwner fail', async function () { @@ -432,12 +508,13 @@ describe('NftController', () => { }, }); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest.spyOn(messenger, 'call'); await expect(() => nftController.watchNft(ERC1155_NFT, ERC1155, 'https://test-dapp.com'), ).rejects.toThrow('Suggested NFT is not owned by the selected account'); - expect(callActionSpy).toHaveBeenCalledTimes(0); + // First call is to get InternalAccount + expect(callActionSpy).toHaveBeenCalledTimes(1); }); it('should handle ERC721 type and add pending request to ApprovalController with the OpenSea API disabled and IPFS gateway enabled', async function () { @@ -451,20 +528,24 @@ describe('NftController', () => { description: 'testERC721Description', }), ); - const { nftController, messenger, triggerPreferencesStateChange } = - setupController({ - options: { - getERC721TokenURI: jest - .fn() - .mockImplementation(() => 'https://testtokenuri.com'), - getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), - }, - }); + const { + nftController, + messenger, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: jest + .fn() + .mockImplementation(() => 'https://testtokenuri.com'), + getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), + }, + }); + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), isIpfsGatewayEnabled: true, openSeaEnabled: false, - selectedAddress: OWNER_ADDRESS, }); const requestId = 'approval-request-id-1'; @@ -473,11 +554,17 @@ describe('NftController', () => { (v4 as jest.Mock).mockImplementationOnce(() => requestId); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest + .spyOn(messenger, 'call') + .mockReturnValueOnce(OWNER_ACCOUNT) + .mockResolvedValueOnce({}) + .mockReturnValueOnce(OWNER_ACCOUNT); await nftController.watchNft(ERC721_NFT, ERC721, 'https://test-dapp.com'); - expect(callActionSpy).toHaveBeenCalledTimes(1); - expect(callActionSpy).toHaveBeenCalledWith( + // First call is getInternalAccount. Second call is the approval request. + expect(callActionSpy).toHaveBeenCalledTimes(3); + expect(callActionSpy).toHaveBeenNthCalledWith( + 2, 'ApprovalController:addRequest', { id: requestId, @@ -512,20 +599,24 @@ describe('NftController', () => { description: 'testERC721Description', }), ); - const { nftController, messenger, triggerPreferencesStateChange } = - setupController({ - options: { - getERC721TokenURI: jest - .fn() - .mockImplementation(() => 'https://testtokenuri.com'), - getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), - }, - }); + const { + nftController, + messenger, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: jest + .fn() + .mockImplementation(() => 'https://testtokenuri.com'), + getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), + }, + }); + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), isIpfsGatewayEnabled: true, openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); const requestId = 'approval-request-id-1'; @@ -534,11 +625,17 @@ describe('NftController', () => { (v4 as jest.Mock).mockImplementationOnce(() => requestId); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest + .spyOn(messenger, 'call') + .mockReturnValueOnce(OWNER_ACCOUNT) + .mockResolvedValueOnce({}) + .mockReturnValueOnce(OWNER_ACCOUNT); await nftController.watchNft(ERC721_NFT, ERC721, 'https://test-dapp.com'); - expect(callActionSpy).toHaveBeenCalledTimes(1); - expect(callActionSpy).toHaveBeenCalledWith( + // First call is getInternalAccount. Second call is the approval request. + expect(callActionSpy).toHaveBeenCalledTimes(3); + expect(callActionSpy).toHaveBeenNthCalledWith( + 2, 'ApprovalController:addRequest', { id: requestId, @@ -573,20 +670,24 @@ describe('NftController', () => { description: 'testERC721Description', }), ); - const { nftController, messenger, triggerPreferencesStateChange } = - setupController({ - options: { - getERC721TokenURI: jest - .fn() - .mockImplementation(() => 'ipfs://testtokenuri.com'), - getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), - }, - }); + const { + nftController, + messenger, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: jest + .fn() + .mockImplementation(() => 'ipfs://testtokenuri.com'), + getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), + }, + }); + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), isIpfsGatewayEnabled: false, openSeaEnabled: false, - selectedAddress: OWNER_ADDRESS, }); const requestId = 'approval-request-id-1'; @@ -595,11 +696,17 @@ describe('NftController', () => { (v4 as jest.Mock).mockImplementationOnce(() => requestId); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest + .spyOn(messenger, 'call') + .mockReturnValueOnce(OWNER_ACCOUNT) + .mockResolvedValueOnce({}) + .mockReturnValueOnce(OWNER_ACCOUNT); await nftController.watchNft(ERC721_NFT, ERC721, 'https://test-dapp.com'); - expect(callActionSpy).toHaveBeenCalledTimes(1); - expect(callActionSpy).toHaveBeenCalledWith( + // First call is getInternalAccount. Second call is the approval request. + expect(callActionSpy).toHaveBeenCalledTimes(3); + expect(callActionSpy).toHaveBeenNthCalledWith( + 2, 'ApprovalController:addRequest', { id: requestId, @@ -634,20 +741,25 @@ describe('NftController', () => { description: 'testERC721Description', }), ); - const { nftController, messenger, triggerPreferencesStateChange } = - setupController({ - options: { - getERC721TokenURI: jest - .fn() - .mockImplementation(() => 'ipfs://testtokenuri.com'), - getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), - }, - }); + const { + nftController, + messenger, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: jest + .fn() + .mockImplementation(() => 'ipfs://testtokenuri.com'), + getERC721OwnerOf: jest.fn().mockImplementation(() => OWNER_ADDRESS), + }, + }); + + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), isIpfsGatewayEnabled: false, openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); const requestId = 'approval-request-id-1'; @@ -656,11 +768,17 @@ describe('NftController', () => { (v4 as jest.Mock).mockImplementationOnce(() => requestId); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest + .spyOn(messenger, 'call') + .mockReturnValueOnce(OWNER_ACCOUNT) + .mockResolvedValueOnce({}) + .mockReturnValueOnce(OWNER_ACCOUNT); await nftController.watchNft(ERC721_NFT, ERC721, 'https://test-dapp.com'); - expect(callActionSpy).toHaveBeenCalledTimes(1); - expect(callActionSpy).toHaveBeenCalledWith( + // First call is getInternalAccount. Second call is the approval request. + expect(callActionSpy).toHaveBeenCalledTimes(3); + expect(callActionSpy).toHaveBeenNthCalledWith( + 2, 'ApprovalController:addRequest', { id: requestId, @@ -696,23 +814,28 @@ describe('NftController', () => { }), ); - const { nftController, messenger, triggerPreferencesStateChange } = - setupController({ - options: { - getERC721TokenURI: jest - .fn() - .mockRejectedValue(new Error('Not an ERC721 contract')), - getERC1155TokenURI: jest - .fn() - .mockImplementation(() => 'https://testtokenuri.com'), - getERC1155BalanceOf: jest.fn().mockImplementation(() => new BN(1)), - }, - }); + const { + nftController, + messenger, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: jest + .fn() + .mockRejectedValue(new Error('Not an ERC721 contract')), + getERC1155TokenURI: jest + .fn() + .mockImplementation(() => 'https://testtokenuri.com'), + getERC1155BalanceOf: jest.fn().mockImplementation(() => new BN(1)), + }, + }); + + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), isIpfsGatewayEnabled: true, openSeaEnabled: false, - selectedAddress: OWNER_ADDRESS, }); const requestId = 'approval-request-id-1'; @@ -720,15 +843,21 @@ describe('NftController', () => { (v4 as jest.Mock).mockImplementationOnce(() => requestId); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest + .spyOn(messenger, 'call') + .mockReturnValueOnce(OWNER_ACCOUNT) + .mockResolvedValueOnce({}) + .mockReturnValueOnce(OWNER_ACCOUNT); await nftController.watchNft( ERC1155_NFT, ERC1155, 'https://etherscan.io', ); - expect(callActionSpy).toHaveBeenCalledTimes(1); - expect(callActionSpy).toHaveBeenCalledWith( + // First call is getInternalAccount. Second call is the approval request. + expect(callActionSpy).toHaveBeenCalledTimes(3); + expect(callActionSpy).toHaveBeenNthCalledWith( + 2, 'ApprovalController:addRequest', { id: requestId, @@ -780,7 +909,6 @@ describe('NftController', () => { ...getDefaultPreferencesState(), isIpfsGatewayEnabled: true, openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); const requestId = 'approval-request-id-1'; @@ -788,15 +916,21 @@ describe('NftController', () => { (v4 as jest.Mock).mockImplementationOnce(() => requestId); - const callActionSpy = jest.spyOn(messenger, 'call').mockResolvedValue({}); + const callActionSpy = jest + .spyOn(messenger, 'call') + .mockReturnValueOnce(OWNER_ACCOUNT) + .mockResolvedValueOnce({}) + .mockReturnValue(OWNER_ACCOUNT); await nftController.watchNft( ERC1155_NFT, ERC1155, 'https://etherscan.io', ); - expect(callActionSpy).toHaveBeenCalledTimes(1); - expect(callActionSpy).toHaveBeenCalledWith( + // First call is getInternalAccount. Second call is the approval request. + expect(callActionSpy).toHaveBeenCalledTimes(3); + expect(callActionSpy).toHaveBeenNthCalledWith( + 2, 'ApprovalController:addRequest', { id: requestId, @@ -838,6 +972,7 @@ describe('NftController', () => { approvalController, changeNetwork, triggerPreferencesStateChange, + triggerSelectedAccountChange, } = setupController({ options: { getERC721OwnerOf: jest @@ -882,10 +1017,10 @@ describe('NftController', () => { expect(nftController.state.allNfts).toStrictEqual({}); // this is our account and network status when the watchNFT request is made + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); @@ -938,6 +1073,7 @@ describe('NftController', () => { messenger, approvalController, triggerPreferencesStateChange, + triggerSelectedAccountChange, changeNetwork, } = setupController({ options: { @@ -981,6 +1117,7 @@ describe('NftController', () => { expect(nftController.state.allNfts).toStrictEqual({}); // this is our account and network status when the watchNFT request is made + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, @@ -994,10 +1131,13 @@ describe('NftController', () => { await pendingRequest; // change the network and selectedAddress before accepting the request + const differentAccount = createMockInternalAccount({ + address: '0xDifferentAddress', + }); + triggerSelectedAccountChange(differentAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: '0xDifferentAddress', }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); // now accept the request @@ -1048,15 +1188,23 @@ describe('NftController', () => { "Unable to verify ownership. Possibly because the standard is not supported or the user's currently selected network does not match the chain of the asset in question.", ); }); + + // it('handle unset selectedAccount', async function () { + // const { nftController } = setupController({ + // options: { selectedAccountId: '' }, + // }); + // jest.spyOn(nftController, 'addNft'); + + // const erc721Result = nftController.watchNft(ERC721_NFT, type); + // }); }); describe('addNft', () => { it('should add NFT and NFT contract', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { chainId: ChainId.mainnet, - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721AssetName: jest.fn().mockResolvedValue('Name'), }, }); @@ -1076,7 +1224,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: '0x01', description: 'description', @@ -1093,7 +1241,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNftContracts[selectedAddress][ + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ ChainId.mainnet ][0], ).toStrictEqual({ @@ -1165,15 +1313,26 @@ describe('NftController', () => { const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); const mockGetERC1155TokenURI = jest.fn().mockRejectedValue(''); - const { nftController, triggerPreferencesStateChange } = setupController({ + const { + nftController, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + getInternalAccountMock, + } = setupController({ options: { getERC721TokenURI: mockGetERC721TokenURI, getERC1155TokenURI: mockGetERC1155TokenURI, }, }); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ address: firstAddress }); const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + }); + getInternalAccountMock.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); nock('https://url').get('/').reply(200, { name: 'name', image: 'url', @@ -1182,19 +1341,20 @@ describe('NftController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); await nftController.addNft('0x01', '1234'); + getInternalAccountMock.mockReturnValue(secondAccount); + triggerSelectedAccountChange(secondAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: secondAddress, }); await nftController.addNft('0x02', '4321'); + getInternalAccountMock.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); expect( nftController.state.allNfts[firstAddress][ChainId.mainnet][0], @@ -1212,10 +1372,9 @@ describe('NftController', () => { }); it('should update NFT if image is different', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -1230,7 +1389,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: '0x01', description: 'description', @@ -1253,7 +1412,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: '0x01', description: 'description', @@ -1267,10 +1426,9 @@ describe('NftController', () => { }); it('should not duplicate NFT nor NFT contract if already added', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -1295,19 +1453,20 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(1); expect( - nftController.state.allNftContracts[selectedAddress][ChainId.mainnet], + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ + ChainId.mainnet + ], ).toHaveLength(1); }); it('should add NFT and get information from NFT-API', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721TokenURI: jest .fn() .mockRejectedValue(new Error('Not an ERC721 contract')), @@ -1319,7 +1478,7 @@ describe('NftController', () => { await nftController.addNft('0x01', '1'); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: '0x01', description: 'Description', @@ -1336,10 +1495,9 @@ describe('NftController', () => { }); it('should add NFT erc721 and aggregate NFT data from both contract and NFT-API', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721AssetName: jest.fn().mockResolvedValue('KudosToken'), getERC721AssetSymbol: jest.fn().mockResolvedValue('KDO'), getERC721TokenURI: jest @@ -1377,7 +1535,7 @@ describe('NftController', () => { await nftController.addNft(ERC721_KUDOSADDRESS, ERC721_KUDOS_TOKEN_ID); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: ERC721_KUDOSADDRESS, image: 'Kudos Image (directly from tokenURI)', @@ -1392,7 +1550,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNftContracts[selectedAddress][ + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ ChainId.mainnet ][0], ).toStrictEqual({ @@ -1404,10 +1562,9 @@ describe('NftController', () => { }); it('should add NFT erc1155 and get NFT information from contract when NFT API call fail', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721TokenURI: jest .fn() .mockRejectedValue(new Error('Not a 721 contract')), @@ -1433,7 +1590,7 @@ describe('NftController', () => { await nftController.addNft(ERC1155_NFT_ADDRESS, ERC1155_NFT_ID); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: ERC1155_NFT_ADDRESS, image: 'image (directly from tokenURI)', @@ -1449,10 +1606,9 @@ describe('NftController', () => { }); it('should add NFT erc721 and get NFT information only from contract', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721AssetName: jest.fn().mockResolvedValue('KudosToken'), getERC721AssetSymbol: jest.fn().mockResolvedValue('KDO'), getERC721TokenURI: jest.fn().mockImplementation((tokenAddress) => { @@ -1482,7 +1638,7 @@ describe('NftController', () => { await nftController.addNft(ERC721_KUDOSADDRESS, ERC721_KUDOS_TOKEN_ID); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: ERC721_KUDOSADDRESS, image: 'Kudos Image (directly from tokenURI)', @@ -1497,7 +1653,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNftContracts[selectedAddress][ + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ ChainId.mainnet ][0], ).toStrictEqual({ @@ -1509,12 +1665,11 @@ describe('NftController', () => { }); it('should add NFT by provider type', async () => { - const selectedAddress = OWNER_ADDRESS; const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); const { nftController, changeNetwork } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721TokenURI: mockGetERC721TokenURI, }, }); @@ -1530,11 +1685,15 @@ describe('NftController', () => { changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); expect( - nftController.state.allNfts[selectedAddress]?.[ChainId[GOERLI.type]], + nftController.state.allNfts[OWNER_ACCOUNT.address]?.[ + ChainId[GOERLI.type] + ], ).toBeUndefined(); expect( - nftController.state.allNfts[selectedAddress][ChainId[SEPOLIA.type]][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ + ChainId[SEPOLIA.type] + ][0], ).toStrictEqual({ address: '0x01', description: 'description', @@ -1554,10 +1713,9 @@ describe('NftController', () => { const mockGetERC721AssetSymbol = jest.fn().mockResolvedValue(''); const mockGetERC721AssetName = jest.fn().mockResolvedValue(''); const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, onNftAdded: mockOnNftAdded, getERC721AssetSymbol: mockGetERC721AssetSymbol, getERC721AssetName: mockGetERC721AssetName, @@ -1574,7 +1732,7 @@ describe('NftController', () => { await nftController.addNft('0x01234abcdefg', '1234'); expect(nftController.state.allNftContracts).toStrictEqual({ - [selectedAddress]: { + [OWNER_ACCOUNT.address]: { [ChainId.mainnet]: [ { address: '0x01234abcdefg', @@ -1585,7 +1743,7 @@ describe('NftController', () => { }); expect(nftController.state.allNfts).toStrictEqual({ - [selectedAddress]: { + [OWNER_ACCOUNT.address]: { [ChainId.mainnet]: [ { address: '0x01234abcdefg', @@ -1676,11 +1834,10 @@ describe('NftController', () => { }); it('should add an nft and nftContract when there is valid contract information and source is "detected"', async () => { - const selectedAddress = OWNER_ADDRESS; const mockOnNftAdded = jest.fn(); const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, onNftAdded: mockOnNftAdded, getERC721AssetName: jest .fn() @@ -1716,26 +1873,28 @@ describe('NftController', () => { '0x6EbeAf8e8E946F0716E6533A6f2cefc83f60e8Ab', '123', { - userAddress: selectedAddress, + userAddress: OWNER_ACCOUNT.address, source: Source.Detected, }, ); expect( - nftController.state.allNfts[selectedAddress]?.[ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address]?.[ChainId.mainnet], ).toBeUndefined(); expect( - nftController.state.allNftContracts[selectedAddress]?.[ChainId.mainnet], + nftController.state.allNftContracts[OWNER_ACCOUNT.address]?.[ + ChainId.mainnet + ], ).toBeUndefined(); await nftController.addNft(ERC721_KUDOSADDRESS, ERC721_KUDOS_TOKEN_ID, { - userAddress: selectedAddress, + userAddress: OWNER_ACCOUNT.address, source: Source.Detected, }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toStrictEqual([ { address: ERC721_KUDOSADDRESS, @@ -1756,7 +1915,9 @@ describe('NftController', () => { ]); expect( - nftController.state.allNftContracts[selectedAddress][ChainId.mainnet], + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ + ChainId.mainnet + ], ).toStrictEqual([ { address: ERC721_KUDOSADDRESS, @@ -1776,11 +1937,10 @@ describe('NftController', () => { }); it('should not add an nft and nftContract when there is not valid contract information (or an issue fetching it) and source is "detected"', async () => { - const selectedAddress = OWNER_ADDRESS; const mockOnNftAdded = jest.fn(); const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, onNftAdded: mockOnNftAdded, getERC721AssetName: jest .fn() @@ -1799,12 +1959,12 @@ describe('NftController', () => { '0x6EbeAf8e8E946F0716E6533A6f2cefc83f60e8Ab', '123', { - userAddress: selectedAddress, + userAddress: OWNER_ACCOUNT.address, source: Source.Detected, }, ); await nftController.addNft(ERC721_KUDOSADDRESS, ERC721_KUDOS_TOKEN_ID, { - userAddress: selectedAddress, + userAddress: OWNER_ACCOUNT.address, source: Source.Detected, }); @@ -1814,10 +1974,9 @@ describe('NftController', () => { }); it('should not add duplicate NFTs to the ignoredNfts list', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -1840,13 +1999,13 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(2); expect(nftController.state.ignoredNfts).toHaveLength(0); nftController.removeAndIgnoreNft('0x01', '1'); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(1); expect(nftController.state.ignoredNfts).toHaveLength(1); @@ -1860,20 +2019,23 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(2); expect(nftController.state.ignoredNfts).toHaveLength(1); nftController.removeAndIgnoreNft('0x01', '1'); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(1); expect(nftController.state.ignoredNfts).toHaveLength(1); }); it('should add NFT with metadata hosted in IPFS', async () => { - const selectedAddress = OWNER_ADDRESS; - const { nftController, triggerPreferencesStateChange } = setupController({ + const { + nftController, + triggerPreferencesStateChange, + getInternalAccountMock, + } = setupController({ options: { getERC721AssetName: jest .fn() @@ -1892,9 +2054,9 @@ describe('NftController', () => { .mockRejectedValue(new Error('Not an ERC1155 token')), }, }); + getInternalAccountMock.mockReturnValue(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, ipfsGateway: IPFS_DEFAULT_GATEWAY_URL, }); @@ -1904,7 +2066,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNftContracts[selectedAddress][ + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ ChainId.mainnet ][0], ).toStrictEqual({ @@ -1914,7 +2076,7 @@ describe('NftController', () => { schemaName: ERC721, }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: ERC721_DEPRESSIONIST_ADDRESS, tokenId: '36', @@ -1930,7 +2092,6 @@ describe('NftController', () => { }); it('should add NFT erc721 when call to NFT API fail', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController(); nock(NFT_API_BASE_URL) .get( @@ -1941,7 +2102,7 @@ describe('NftController', () => { await nftController.addNft(ERC721_NFT_ADDRESS, ERC721_NFT_ID); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: ERC721_NFT_ADDRESS, image: null, @@ -2191,6 +2352,34 @@ describe('NftController', () => { }, ]); }); + + it('should handle unset selectedAccount', async () => { + const { nftController, getInternalAccountMock } = setupController({ + options: { + chainId: ChainId.mainnet, + selectedAccountId: '', + getERC721AssetName: jest.fn().mockResolvedValue('Name'), + }, + }); + + getInternalAccountMock.mockReturnValue(null); + + await nftController.addNft('0x01', '1', { + nftMetadata: { + name: 'name', + image: 'image', + description: 'description', + standard: 'standard', + favorite: false, + collection: { + tokenCount: '0', + image: 'url', + }, + }, + }); + + expect(nftController.state.allNftContracts['']).toBeUndefined(); + }); }); describe('addNftVerifyOwnership', () => { @@ -2198,13 +2387,28 @@ describe('NftController', () => { const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController, triggerPreferencesStateChange } = setupController({ + const { + nftController, + getInternalAccountMock, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController({ options: { getERC721TokenURI: mockGetERC721TokenURI, }, }); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + id: '22c022b5-309c-45e4-a82d-64bb11fc0e74', + }); const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + id: 'f9a42417-6071-4b51-8ecd-f7b14abd8851', + }); + getInternalAccountMock.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(true); nock('https://url').get('/').reply(200, { @@ -2215,22 +2419,23 @@ describe('NftController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); await nftController.addNftVerifyOwnership('0x01', '1234'); + getInternalAccountMock.mockReturnValue(secondAccount); + triggerSelectedAccountChange(secondAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: secondAddress, }); await nftController.addNftVerifyOwnership('0x02', '4321'); + getInternalAccountMock.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); expect( - nftController.state.allNfts[firstAddress][ChainId.mainnet][0], + nftController.state.allNfts[firstAccount.address][ChainId.mainnet][0], ).toStrictEqual({ address: '0x01', description: 'description', @@ -2245,14 +2450,24 @@ describe('NftController', () => { }); it('should throw an error if selected address is not owner of input NFT', async () => { - const { nftController, triggerPreferencesStateChange } = - setupController(); + const { + nftController, + getInternalAccountMock, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController(); + // TODO: Replace `any` with type jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(false); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + id: '22c022b5-309c-45e4-a82d-64bb11fc0e74', + }); + getInternalAccountMock.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); const result = async () => await nftController.addNftVerifyOwnership('0x01', '1234'); @@ -2263,14 +2478,27 @@ describe('NftController', () => { it('should verify ownership by selected address and add NFT by the correct chainId when passed networkClientId', async () => { const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController, triggerPreferencesStateChange } = setupController({ + const { + nftController, + triggerPreferencesStateChange, + getInternalAccountMock, + triggerSelectedAccountChange, + } = setupController({ options: { getERC721TokenURI: mockGetERC721TokenURI, }, }); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + id: '22c022b5-309c-45e4-a82d-64bb11fc0e74', + }); const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + id: 'f9a42417-6071-4b51-8ecd-f7b14abd8851', + }); jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(true); @@ -2282,25 +2510,27 @@ describe('NftController', () => { description: 'description', }) .persist(); + getInternalAccountMock.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); await nftController.addNftVerifyOwnership('0x01', '1234', { networkClientId: 'sepolia', }); + getInternalAccountMock.mockReturnValue(secondAccount); + triggerSelectedAccountChange(secondAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: secondAddress, }); await nftController.addNftVerifyOwnership('0x02', '4321', { networkClientId: 'goerli', }); expect( - nftController.state.allNfts[firstAddress][SEPOLIA.chainId][0], + nftController.state.allNfts[firstAccount.address][SEPOLIA.chainId][0], ).toStrictEqual({ address: '0x01', description: 'description', @@ -2313,7 +2543,7 @@ describe('NftController', () => { tokenURI, }); expect( - nftController.state.allNfts[secondAddress][GOERLI.chainId][0], + nftController.state.allNfts[secondAccount.address][GOERLI.chainId][0], ).toStrictEqual({ address: '0x02', description: 'description', @@ -2330,17 +2560,21 @@ describe('NftController', () => { it('should verify ownership by selected address and add NFT by the correct userAddress when passed userAddress', async () => { const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController, changeNetwork, triggerPreferencesStateChange } = - setupController({ - options: { - getERC721TokenURI: mockGetERC721TokenURI, - }, - }); + const { + nftController, + changeNetwork, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: mockGetERC721TokenURI, + }, + }); // Ensure that the currently selected address is not the same as either of the userAddresses + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); const firstAddress = '0x123'; @@ -2396,10 +2630,9 @@ describe('NftController', () => { describe('removeNft', () => { it('should remove NFT and NFT contract', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -2413,16 +2646,17 @@ describe('NftController', () => { }); nftController.removeNft('0x01', '1'); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(0); expect( - nftController.state.allNftContracts[selectedAddress][ChainId.mainnet], + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ + ChainId.mainnet + ], ).toHaveLength(0); }); it('should not remove NFT contract if NFT still exists', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController(); await nftController.addNft('0x01', '1', { @@ -2444,18 +2678,25 @@ describe('NftController', () => { }); nftController.removeNft('0x01', '1'); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(1); expect( - nftController.state.allNftContracts[selectedAddress][ChainId.mainnet], + nftController.state.allNftContracts[OWNER_ACCOUNT.address][ + ChainId.mainnet + ], ).toHaveLength(1); }); it('should remove NFT by selected address', async () => { const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController, triggerPreferencesStateChange } = setupController({ + const { + nftController, + triggerPreferencesStateChange, + getInternalAccountMock, + triggerSelectedAccountChange, + } = setupController({ options: { getERC721TokenURI: mockGetERC721TokenURI, }, @@ -2466,30 +2707,39 @@ describe('NftController', () => { description: 'description', }); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + id: '22c022b5-309c-45e4-a82d-64bb11fc0e74', + }); const secondAddress = '0x321'; + const secondAccount = createMockInternalAccount({ + address: secondAddress, + id: 'f9a42417-6071-4b51-8ecd-f7b14abd8851', + }); + getInternalAccountMock.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); await nftController.addNft('0x02', '4321'); + getInternalAccountMock.mockReturnValue(secondAccount); + triggerSelectedAccountChange(secondAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: secondAddress, }); await nftController.addNft('0x01', '1234'); nftController.removeNft('0x01', '1234'); expect( - nftController.state.allNfts[secondAddress][ChainId.mainnet], + nftController.state.allNfts[secondAccount.address][ChainId.mainnet], ).toHaveLength(0); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstAddress, }); expect( - nftController.state.allNfts[firstAddress][ChainId.mainnet][0], + nftController.state.allNfts[firstAccount.address][ChainId.mainnet][0], ).toStrictEqual({ address: '0x02', description: 'description', @@ -2504,12 +2754,11 @@ describe('NftController', () => { }); it('should remove NFT by provider type', async () => { - const selectedAddress = OWNER_ADDRESS; const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); const { nftController, changeNetwork } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721TokenURI: mockGetERC721TokenURI, }, }); @@ -2525,13 +2774,13 @@ describe('NftController', () => { await nftController.addNft('0x01', '1234'); nftController.removeNft('0x01', '1234'); expect( - nftController.state.allNfts[selectedAddress][GOERLI.chainId], + nftController.state.allNfts[OWNER_ACCOUNT.address][GOERLI.chainId], ).toHaveLength(0); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); expect( - nftController.state.allNfts[selectedAddress][SEPOLIA.chainId][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][SEPOLIA.chainId][0], ).toStrictEqual({ address: '0x02', description: 'description', @@ -2546,17 +2795,31 @@ describe('NftController', () => { }); it('should remove correct NFT and NFT contract when passed networkClientId and userAddress in options', async () => { - const { nftController, changeNetwork, triggerPreferencesStateChange } = - setupController(); + const { + nftController, + changeNetwork, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + getInternalAccountMock, + } = setupController(); const userAddress1 = '0x123'; + const userAccount1 = createMockInternalAccount({ + address: userAddress1, + id: '5fd59cae-95d3-4a1d-ba97-657c8f83c300', + }); const userAddress2 = '0x321'; + const userAccount2 = createMockInternalAccount({ + address: userAddress2, + id: '9ea40063-a95c-4f79-a4b6-0c065549245e', + }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); + getInternalAccountMock.mockReturnValue(userAccount1); + triggerSelectedAccountChange(userAccount1); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: userAddress1, }); await nftController.addNft('0x01', '1', { @@ -2582,10 +2845,11 @@ describe('NftController', () => { }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); + getInternalAccountMock.mockReturnValue(userAccount2); + triggerSelectedAccountChange(userAccount2); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: userAddress2, }); // now remove the nft after changing to a different network and account from the one where it was added @@ -2605,10 +2869,9 @@ describe('NftController', () => { }); it('should be able to clear the ignoredNfts list', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -2623,13 +2886,13 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(1); expect(nftController.state.ignoredNfts).toHaveLength(0); nftController.removeAndIgnoreNft('0x02', '1'); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(0); expect(nftController.state.ignoredNfts).toHaveLength(1); @@ -2766,17 +3029,19 @@ describe('NftController', () => { }); it('should add NFT with null metadata if the ipfs gateway is disabled and opensea is disabled', async () => { - const selectedAddress = OWNER_ADDRESS; - const { nftController, triggerPreferencesStateChange } = setupController({ + const { + nftController, + triggerPreferencesStateChange, + getInternalAccountMock, + } = setupController({ options: { getERC721TokenURI: jest.fn().mockRejectedValue(''), getERC1155TokenURI: jest.fn().mockResolvedValue('ipfs://*'), }, }); - + getInternalAccountMock.mockReturnValue(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, isIpfsGatewayEnabled: false, openSeaEnabled: false, }); @@ -2784,7 +3049,7 @@ describe('NftController', () => { await nftController.addNft(ERC1155_NFT_ADDRESS, ERC1155_NFT_ID); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual({ address: ERC1155_NFT_ADDRESS, name: null, @@ -2801,10 +3066,9 @@ describe('NftController', () => { describe('updateNftFavoriteStatus', () => { it('should not set NFT as favorite if nft not found', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -2821,7 +3085,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -2831,10 +3095,9 @@ describe('NftController', () => { ); }); it('should set NFT as favorite', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -2851,7 +3114,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -2862,10 +3125,9 @@ describe('NftController', () => { }); it('should set NFT as favorite and then unset it', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -2882,7 +3144,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -2898,7 +3160,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -2909,10 +3171,9 @@ describe('NftController', () => { }); it('should keep the favorite status as true after updating metadata', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -2929,7 +3190,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -2952,7 +3213,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ image: 'new_image', @@ -2966,15 +3227,14 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(1); }); it('should keep the favorite status as false after updating metadata', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -2985,7 +3245,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -3008,7 +3268,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual( expect.objectContaining({ image: 'new_image', @@ -3022,22 +3282,36 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet], ).toHaveLength(1); }); it('should set NFT as favorite when passed networkClientId and userAddress in options', async () => { - const { nftController, triggerPreferencesStateChange, changeNetwork } = - setupController(); + const { + nftController, + triggerPreferencesStateChange, + changeNetwork, + triggerSelectedAccountChange, + getInternalAccountMock, + } = setupController(); const userAddress1 = '0x123'; + const userAccount1 = createMockInternalAccount({ + address: userAddress1, + id: '0a2a9a41-2b35-4863-8f36-baceec4e9686', + }); const userAddress2 = '0x321'; + const userAccount2 = createMockInternalAccount({ + address: userAddress2, + id: '09b239a4-c229-4a2b-9739-1cb4b9dea7b9', + }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); + getInternalAccountMock.mockReturnValue(userAccount1); + triggerSelectedAccountChange(userAccount1); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: userAddress1, }); await nftController.addNft( @@ -3047,7 +3321,7 @@ describe('NftController', () => { ); expect( - nftController.state.allNfts[userAddress1][SEPOLIA.chainId][0], + nftController.state.allNfts[userAccount1.address][SEPOLIA.chainId][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -3057,10 +3331,11 @@ describe('NftController', () => { ); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); + getInternalAccountMock.mockReturnValue(userAccount2); + triggerSelectedAccountChange(userAccount2); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: userAddress2, }); // now favorite the nft after changing to a different account from the one where it was added @@ -3070,12 +3345,12 @@ describe('NftController', () => { true, { networkClientId: SEPOLIA.type, - userAddress: userAddress1, + userAddress: userAccount1.address, }, ); expect( - nftController.state.allNfts[userAddress1][SEPOLIA.chainId][0], + nftController.state.allNfts[userAccount1.address][SEPOLIA.chainId][0], ).toStrictEqual( expect.objectContaining({ address: ERC721_DEPRESSIONIST_ADDRESS, @@ -3088,11 +3363,10 @@ describe('NftController', () => { describe('checkAndUpdateNftsOwnershipStatus', () => { describe('checkAndUpdateAllNftsOwnershipStatus', () => { - it('should check whether NFTs for the current selectedAddress/chainId combination are still owned by the selectedAddress and update the isCurrentlyOwned value to false when NFT is not still owned', async () => { - const selectedAddress = OWNER_ADDRESS; + it('should check whether NFTs for the current selectedAccount/chainId combination are still owned by the selectedAccount and update the isCurrentlyOwned value to false when NFT is not still owned', async () => { const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(false); @@ -3107,23 +3381,22 @@ describe('NftController', () => { }, }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); await nftController.checkAndUpdateAllNftsOwnershipStatus(); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(false); }); - it('should check whether NFTs for the current selectedAddress/chainId combination are still owned by the selectedAddress and leave/set the isCurrentlyOwned value to true when NFT is still owned', async () => { - const selectedAddress = OWNER_ADDRESS; + it('should check whether NFTs for the current selectedAccount/chainId combination are still owned by the selectedAccount and leave/set the isCurrentlyOwned value to true when NFT is still owned', async () => { const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(true); @@ -3139,22 +3412,21 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); await nftController.checkAndUpdateAllNftsOwnershipStatus(); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); }); - it('should check whether NFTs for the current selectedAddress/chainId combination are still owned by the selectedAddress and leave the isCurrentlyOwned value as is when NFT ownership check fails', async () => { - const selectedAddress = OWNER_ADDRESS; + it('should check whether NFTs for the current selectedAccount/chainId combination are still owned by the selectedAccount and leave the isCurrentlyOwned value as is when NFT ownership check fails', async () => { const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); jest @@ -3172,26 +3444,29 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); await nftController.checkAndUpdateAllNftsOwnershipStatus(); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); }); - it('should check whether NFTs for the current selectedAddress/chainId combination are still owned by the selectedAddress and update the isCurrentlyOwned value to false when NFT is not still owned, when the currently configured selectedAddress/chainId are different from those passed', async () => { - const selectedAddress = OWNER_ADDRESS; - const { nftController, changeNetwork, triggerPreferencesStateChange } = - setupController(); + it('should check whether NFTs for the current selectedAccount/chainId combination are still owned by the selectedAccount and update the isCurrentlyOwned value to false when NFT is not still owned, when the currently configured selectedAccount/chainId are different from those passed', async () => { + const { + nftController, + changeNetwork, + triggerPreferencesStateChange, + getInternalAccountMock, + } = setupController(); + getInternalAccountMock.mockReturnValue(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress, }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); @@ -3206,7 +3481,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.sepolia][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.sepolia][0] .isCurrentlyOwned, ).toBe(true); @@ -3215,7 +3490,6 @@ describe('NftController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: SECOND_OWNER_ADDRESS, }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); @@ -3229,14 +3503,38 @@ describe('NftController', () => { .isCurrentlyOwned, ).toBe(false); }); + + it('should handle default case where selectedAccount is not set', async () => { + const { nftController, getInternalAccountMock } = setupController({ + options: { + selectedAccountId: '', + }, + }); + getInternalAccountMock.mockReturnValue(null); + jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(false); + + await nftController.addNft('0x02', '1', { + nftMetadata: { + name: 'name', + image: 'image', + description: 'description', + standard: 'standard', + favorite: false, + }, + }); + expect(nftController.state.allNfts['']).toBeUndefined(); + + await nftController.checkAndUpdateAllNftsOwnershipStatus(); + + expect(nftController.state.allNfts['']).toBeUndefined(); + }); }); describe('checkAndUpdateSingleNftOwnershipStatus', () => { - it('should check whether the passed NFT is still owned by the the current selectedAddress/chainId combination and update its isCurrentlyOwned property in state if batch is false and isNftOwner returns false', async () => { - const selectedAddress = OWNER_ADDRESS; + it('should check whether the passed NFT is still owned by the the current selectedAccount/chainId combination and update its isCurrentlyOwned property in state if batch is false and isNftOwner returns false', async () => { const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -3255,7 +3553,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); @@ -3264,16 +3562,15 @@ describe('NftController', () => { await nftController.checkAndUpdateSingleNftOwnershipStatus(nft, false); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(false); }); it('should check whether the passed NFT is still owned by the the current selectedAddress/chainId combination and return the updated NFT object without updating state if batch is true', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -3292,7 +3589,7 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); @@ -3302,22 +3599,26 @@ describe('NftController', () => { await nftController.checkAndUpdateSingleNftOwnershipStatus(nft, true); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .isCurrentlyOwned, ).toBe(true); - expect(updatedNft.isCurrentlyOwned).toBe(false); + expect(updatedNft?.isCurrentlyOwned).toBe(false); }); it('should check whether the passed NFT is still owned by the the selectedAddress/chainId combination passed in the accountParams argument and update its isCurrentlyOwned property in state, when the currently configured selectedAddress/chainId are different from those passed', async () => { - const firstSelectedAddress = OWNER_ADDRESS; - const { nftController, changeNetwork, triggerPreferencesStateChange } = - setupController(); - + const firstSelectedAddress = OWNER_ACCOUNT.address; + const { + nftController, + changeNetwork, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController(); + + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: firstSelectedAddress, }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); @@ -3341,11 +3642,13 @@ describe('NftController', () => { ).toBe(true); jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(false); - + const secondAccount = createMockInternalAccount({ + address: SECOND_OWNER_ADDRESS, + }); + triggerSelectedAccountChange(secondAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: SECOND_OWNER_ADDRESS, }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); @@ -3361,14 +3664,18 @@ describe('NftController', () => { }); it('should check whether the passed NFT is still owned by the the selectedAddress/chainId combination passed in the accountParams argument and return the updated NFT object without updating state, when the currently configured selectedAddress/chainId are different from those passed and batch is true', async () => { - const firstSelectedAddress = OWNER_ADDRESS; - const { nftController, changeNetwork, triggerPreferencesStateChange } = - setupController(); - + const firstSelectedAddress = OWNER_ACCOUNT.address; + const { + nftController, + changeNetwork, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController(); + + triggerSelectedAccountChange(OWNER_ACCOUNT); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); @@ -3392,11 +3699,13 @@ describe('NftController', () => { ).toBe(true); jest.spyOn(nftController, 'isNftOwner').mockResolvedValue(false); - + const secondAccount = createMockInternalAccount({ + address: SECOND_OWNER_ADDRESS, + }); + triggerSelectedAccountChange(secondAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), openSeaEnabled: true, - selectedAddress: SECOND_OWNER_ADDRESS, }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); @@ -3435,10 +3744,9 @@ describe('NftController', () => { }; it('should return null if the NFT does not exist in the state', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -3446,20 +3754,19 @@ describe('NftController', () => { nftController.findNftByAddressAndTokenId( mockNft.address, mockNft.tokenId, - selectedAddress, + OWNER_ACCOUNT.address, ChainId.mainnet, ), ).toBeNull(); }); it('should return the NFT by the address and tokenId', () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, state: { allNfts: { - [OWNER_ADDRESS]: { [ChainId.mainnet]: [mockNft] }, + [OWNER_ACCOUNT.address]: { [ChainId.mainnet]: [mockNft] }, }, }, }, @@ -3469,7 +3776,7 @@ describe('NftController', () => { nftController.findNftByAddressAndTokenId( mockNft.address, mockNft.tokenId, - selectedAddress, + OWNER_ACCOUNT.address, ChainId.mainnet, ), ).toStrictEqual({ nft: mockNft, index: 0 }); @@ -3477,7 +3784,6 @@ describe('NftController', () => { }); describe('updateNftByAddressAndTokenId', () => { - const selectedAddress = OWNER_ADDRESS; const mockTransactionId = '60d36710-b150-11ec-8a49-c377fbd05e27'; const mockNft = { address: '0x02', @@ -3503,10 +3809,10 @@ describe('NftController', () => { it('should update the NFT if the NFT exist', async () => { const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, state: { allNfts: { - [OWNER_ADDRESS]: { [ChainId.mainnet]: [mockNft] }, + [OWNER_ACCOUNT.address]: { [ChainId.mainnet]: [mockNft] }, }, }, }, @@ -3517,19 +3823,19 @@ describe('NftController', () => { { transactionId: mockTransactionId, }, - selectedAddress, + OWNER_ACCOUNT.address, ChainId.mainnet, ); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0], ).toStrictEqual(expectedMockNft); }); it('should return undefined if the NFT does not exist', () => { const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); @@ -3539,7 +3845,7 @@ describe('NftController', () => { { transactionId: mockTransactionId, }, - selectedAddress, + OWNER_ACCOUNT.address, ChainId.mainnet, ), ).toBeUndefined(); @@ -3562,27 +3868,25 @@ describe('NftController', () => { }; it('should not update any NFT state and should return false when passed a transaction id that does not match that of any NFT', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, }, }); expect( nftController.resetNftTransactionStatusByTransactionId( nonExistTransactionId, - selectedAddress, + OWNER_ACCOUNT.address, ChainId.mainnet, ), ).toBe(false); }); it('should set the transaction id of an NFT in state to undefined, and return true when it has successfully updated this state', async () => { - const selectedAddress = OWNER_ADDRESS; const { nftController } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, state: { allNfts: { [OWNER_ADDRESS]: { [ChainId.mainnet]: [mockNft] }, @@ -3592,20 +3896,20 @@ describe('NftController', () => { }); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .transactionId, ).toBe(mockTransactionId); expect( nftController.resetNftTransactionStatusByTransactionId( mockTransactionId, - selectedAddress, + OWNER_ACCOUNT.address, ChainId.mainnet, ), ).toBe(true); expect( - nftController.state.allNfts[selectedAddress][ChainId.mainnet][0] + nftController.state.allNfts[OWNER_ACCOUNT.address][ChainId.mainnet][0] .transactionId, ).toBeUndefined(); }); @@ -3613,17 +3917,17 @@ describe('NftController', () => { describe('updateNftMetadata', () => { it('should update Nft metadata successfully', async () => { - const selectedAddress = OWNER_ADDRESS; const tokenURI = 'https://api.pudgypenguins.io/lil/4'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController } = setupController({ + const { nftController, getInternalAccountMock } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721TokenURI: mockGetERC721TokenURI, }, }); const spy = jest.spyOn(nftController, 'updateNft'); const testNetworkClientId = 'sepolia'; + getInternalAccountMock.mockReturnValue(OWNER_ACCOUNT); await nftController.addNft('0xtest', '3', { nftMetadata: { name: '', description: '', image: '', standard: '' }, networkClientId: testNetworkClientId, @@ -3655,7 +3959,7 @@ describe('NftController', () => { expect(spy).toHaveBeenCalledTimes(1); expect( - nftController.state.allNfts[selectedAddress][SEPOLIA.chainId][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][SEPOLIA.chainId][0], ).toStrictEqual({ address: '0xtest', description: 'description pudgy', @@ -3670,17 +3974,17 @@ describe('NftController', () => { }); it('should not update metadata when state nft and fetched nft are the same', async () => { - const selectedAddress = OWNER_ADDRESS; const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController } = setupController({ + const { nftController, getInternalAccountMock } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721TokenURI: mockGetERC721TokenURI, }, }); const updateNftSpy = jest.spyOn(nftController, 'updateNft'); const testNetworkClientId = 'sepolia'; + getInternalAccountMock.mockReturnValue(OWNER_ACCOUNT); await nftController.addNft('0xtest', '3', { nftMetadata: { name: 'toto', @@ -3713,6 +4017,7 @@ describe('NftController', () => { }, ]; + getInternalAccountMock.mockReturnValue(OWNER_ACCOUNT); await nftController.updateNftMetadata({ nfts: testInputNfts, networkClientId: testNetworkClientId, @@ -3720,7 +4025,7 @@ describe('NftController', () => { expect(updateNftSpy).toHaveBeenCalledTimes(0); expect( - nftController.state.allNfts[selectedAddress][SEPOLIA.chainId][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][SEPOLIA.chainId][0], ).toStrictEqual({ address: '0xtest', description: 'description', @@ -3735,17 +4040,17 @@ describe('NftController', () => { }); it('should trigger update metadata when state nft and fetched nft are not the same', async () => { - const selectedAddress = OWNER_ADDRESS; const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController } = setupController({ + const { nftController, getInternalAccountMock } = setupController({ options: { - selectedAddress, + selectedAccountId: OWNER_ACCOUNT.id, getERC721TokenURI: mockGetERC721TokenURI, }, }); const spy = jest.spyOn(nftController, 'updateNft'); const testNetworkClientId = 'sepolia'; + getInternalAccountMock.mockReturnValue(OWNER_ACCOUNT); await nftController.addNft('0xtest', '3', { nftMetadata: { name: 'toto', @@ -3781,7 +4086,7 @@ describe('NftController', () => { expect(spy).toHaveBeenCalledTimes(1); expect( - nftController.state.allNfts[selectedAddress][SEPOLIA.chainId][0], + nftController.state.allNfts[OWNER_ACCOUNT.address][SEPOLIA.chainId][0], ).toStrictEqual({ address: '0xtest', description: 'description', @@ -3796,8 +4101,11 @@ describe('NftController', () => { }); it('should not update metadata when nfts has image/name/description already', async () => { - const { nftController, triggerPreferencesStateChange } = - setupController(); + const { + nftController, + triggerPreferencesStateChange, + triggerSelectedAccountChange, + } = setupController(); const spy = jest.spyOn(nftController, 'updateNftMetadata'); const testNetworkClientId = 'sepolia'; @@ -3813,12 +4121,12 @@ describe('NftController', () => { networkClientId: testNetworkClientId, }); + triggerSelectedAccountChange(OWNER_ACCOUNT); // trigger preference change triggerPreferencesStateChange({ ...getDefaultPreferencesState(), isIpfsGatewayEnabled: false, openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); expect(spy).toHaveBeenCalledTimes(0); @@ -3827,12 +4135,16 @@ describe('NftController', () => { it('should trigger calling updateNftMetadata when preferences change - openseaEnabled', async () => { const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController, triggerPreferencesStateChange, changeNetwork } = - setupController({ - options: { - getERC721TokenURI: mockGetERC721TokenURI, - }, - }); + const { + nftController, + triggerPreferencesStateChange, + changeNetwork, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: mockGetERC721TokenURI, + }, + }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); const spy = jest.spyOn(nftController, 'updateNftMetadata'); @@ -3865,21 +4177,24 @@ describe('NftController', () => { ...getDefaultPreferencesState(), isIpfsGatewayEnabled: false, openSeaEnabled: true, - selectedAddress: OWNER_ADDRESS, }); - + triggerSelectedAccountChange(OWNER_ACCOUNT); expect(spy).toHaveBeenCalledTimes(1); }); it('should trigger calling updateNftMetadata when preferences change - ipfs enabled', async () => { const tokenURI = 'https://url/'; const mockGetERC721TokenURI = jest.fn().mockResolvedValue(tokenURI); - const { nftController, triggerPreferencesStateChange, changeNetwork } = - setupController({ - options: { - getERC721TokenURI: mockGetERC721TokenURI, - }, - }); + const { + nftController, + triggerPreferencesStateChange, + changeNetwork, + triggerSelectedAccountChange, + } = setupController({ + options: { + getERC721TokenURI: mockGetERC721TokenURI, + }, + }); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); const spy = jest.spyOn(nftController, 'updateNftMetadata'); @@ -3912,8 +4227,8 @@ describe('NftController', () => { ...getDefaultPreferencesState(), isIpfsGatewayEnabled: true, openSeaEnabled: false, - selectedAddress: OWNER_ADDRESS, }); + triggerSelectedAccountChange(OWNER_ACCOUNT); expect(spy).toHaveBeenCalledTimes(1); }); diff --git a/packages/assets-controllers/src/NftController.ts b/packages/assets-controllers/src/NftController.ts index 83fdeae592e..6d79aaba7cc 100644 --- a/packages/assets-controllers/src/NftController.ts +++ b/packages/assets-controllers/src/NftController.ts @@ -1,4 +1,9 @@ import { isAddress } from '@ethersproject/address'; +import { + type AccountsControllerSelectedEvmAccountChangeEvent, + type AccountsControllerGetAccountAction, + type AccountsControllerGetSelectedAccountAction, +} from '@metamask/accounts-controller'; import type { AddApprovalRequest } from '@metamask/approval-controller'; import type { RestrictedControllerMessenger, @@ -20,6 +25,7 @@ import { ApprovalType, NFT_API_BASE_URL, } from '@metamask/controller-utils'; +import { type InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkControllerGetNetworkClientByIdAction, @@ -214,11 +220,14 @@ export type NftControllerActions = NftControllerGetStateAction; */ export type AllowedActions = | AddApprovalRequest + | AccountsControllerGetAccountAction + | AccountsControllerGetSelectedAccountAction | NetworkControllerGetNetworkClientByIdAction; export type AllowedEvents = | PreferencesControllerStateChangeEvent - | NetworkControllerNetworkDidChangeEvent; + | NetworkControllerNetworkDidChangeEvent + | AccountsControllerSelectedEvmAccountChangeEvent; export type NftControllerStateChangeEvent = ControllerStateChangeEvent< typeof controllerName, @@ -259,7 +268,7 @@ export class NftController extends BaseController< */ openSeaApiKey?: string; - #selectedAddress: string; + #selectedAccountId: string; #chainId: Hex; @@ -296,7 +305,7 @@ export class NftController extends BaseController< * * @param options - The controller options. * @param options.chainId - The chain ID of the current network. - * @param options.selectedAddress - The currently selected address. + * @param options.selectedAccountId - The currently selected account id. * @param options.ipfsGateway - The configured IPFS gateway. * @param options.openSeaEnabled - Controls whether the OpenSea API is used. * @param options.useIpfsSubdomains - Controls whether IPFS subdomains are used. @@ -314,7 +323,7 @@ export class NftController extends BaseController< */ constructor({ chainId: initialChainId, - selectedAddress = '', + selectedAccountId = '', ipfsGateway = IPFS_DEFAULT_GATEWAY_URL, openSeaEnabled = false, useIpfsSubdomains = true, @@ -330,7 +339,7 @@ export class NftController extends BaseController< state = {}, }: { chainId: Hex; - selectedAddress?: string; + selectedAccountId?: string; ipfsGateway?: string; openSeaEnabled?: boolean; useIpfsSubdomains?: boolean; @@ -361,7 +370,7 @@ export class NftController extends BaseController< }, }); - this.#selectedAddress = selectedAddress; + this.#selectedAccountId = selectedAccountId; this.#chainId = initialChainId; this.#ipfsGateway = ipfsGateway; this.#openSeaEnabled = openSeaEnabled; @@ -385,6 +394,11 @@ export class NftController extends BaseController< 'NetworkController:networkDidChange', this.#onNetworkControllerNetworkDidChange.bind(this), ); + + this.messagingSystem.subscribe( + 'AccountsController:selectedEvmAccountChange', + this.#onSelectedAccountChange.bind(this), + ); } /** @@ -407,18 +421,19 @@ export class NftController extends BaseController< /** * Handles the state change of the preference controller. * @param preferencesState - The new state of the preference controller. - * @param preferencesState.selectedAddress - The current selected address. * @param preferencesState.ipfsGateway - The configured IPFS gateway. * @param preferencesState.openSeaEnabled - Controls whether the OpenSea API is used. * @param preferencesState.isIpfsGatewayEnabled - Controls whether IPFS is enabled or not. */ async #onPreferencesControllerStateChange({ - selectedAddress, ipfsGateway, openSeaEnabled, isIpfsGatewayEnabled, }: PreferencesState) { - this.#selectedAddress = selectedAddress; + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', + ); + this.#selectedAccountId = selectedAccount.id; this.#ipfsGateway = ipfsGateway; this.#openSeaEnabled = openSeaEnabled; this.#isIpfsGatewayEnabled = isIpfsGatewayEnabled; @@ -428,7 +443,37 @@ export class NftController extends BaseController< if (needsUpdateNftMetadata) { const nfts: Nft[] = - this.state.allNfts[selectedAddress]?.[this.#chainId] ?? []; + this.state.allNfts[selectedAccount?.address]?.[this.#chainId] ?? []; + // filter only nfts + const nftsToUpdate = nfts.filter( + (singleNft) => + !singleNft.name && !singleNft.description && !singleNft.image, + ); + if (nftsToUpdate.length !== 0) { + await this.updateNftMetadata({ + nfts: nftsToUpdate, + userAddress: selectedAccount?.address, + }); + } + } + } + + /** + * Handles the selected account change on the accounts controller. + * @param internalAccount - The new selected account. + */ + async #onSelectedAccountChange(internalAccount: InternalAccount) { + const oldSelectedAccountId = this.#selectedAccountId; + + this.#selectedAccountId = internalAccount.id; + const needsUpdateNftMetadata = + ((this.#isIpfsGatewayEnabled && this.#ipfsGateway !== '') || + this.#openSeaEnabled) && + oldSelectedAccountId !== internalAccount.id; + + if (needsUpdateNftMetadata) { + const nfts: Nft[] = + this.state.allNfts[internalAccount.address]?.[this.#chainId] ?? []; // filter only nfts const nftsToUpdate = nfts.filter( (singleNft) => @@ -437,7 +482,7 @@ export class NftController extends BaseController< if (nftsToUpdate.length !== 0) { await this.updateNftMetadata({ nfts: nftsToUpdate, - userAddress: selectedAddress, + userAddress: internalAccount.address, }); } } @@ -466,6 +511,12 @@ export class NftController extends BaseController< baseStateKey: Key, { userAddress, chainId }: { userAddress: string; chainId: Hex }, ) { + // userAddress can be an empty string if it is not set via an account change or in constructor + // while this doesn't cause any issues, we want to ensure that we don't store assets to an empty string address + if (!userAddress) { + return; + } + this.update((state) => { const oldState = state[baseStateKey]; const addressState = oldState[userAddress] || {}; @@ -1218,14 +1269,23 @@ export class NftController extends BaseController< origin: string, { networkClientId, - userAddress = this.#selectedAddress, + userAddress, }: { networkClientId?: NetworkClientId; userAddress?: string; - } = { - userAddress: this.#selectedAddress, - }, + } = {}, ) { + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + userAddress = userAddress || selectedAccount?.address || ''; + if (!userAddress) { + return; + } + await this.#validateWatchNft(asset, type, userAddress); const nftMetadata = await this.#getNftInformation( @@ -1341,17 +1401,21 @@ export class NftController extends BaseController< address: string, tokenId: string, { - userAddress = this.#selectedAddress, + userAddress, networkClientId, source, }: { userAddress?: string; networkClientId?: NetworkClientId; source?: Source; - } = { - userAddress: this.#selectedAddress, - }, + } = {}, ) { + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + userAddress = userAddress || selectedAccount?.address || ''; + if ( !(await this.isNftOwner(userAddress, address, tokenId, { networkClientId, @@ -1383,7 +1447,7 @@ export class NftController extends BaseController< tokenId: string, { nftMetadata, - userAddress = this.#selectedAddress, + userAddress, source = Source.Custom, networkClientId, }: { @@ -1391,8 +1455,18 @@ export class NftController extends BaseController< userAddress?: string; source?: Source; networkClientId?: NetworkClientId; - } = { userAddress: this.#selectedAddress }, + } = {}, ) { + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + userAddress = userAddress || selectedAccount?.address || ''; + if (!userAddress) { + return; + } + const checksumHexAddress = toChecksumHexAddress(tokenAddress); const chainId = this.#getCorrectChainId({ networkClientId }); @@ -1443,13 +1517,21 @@ export class NftController extends BaseController< */ async updateNftMetadata({ nfts, - userAddress = this.#selectedAddress, + userAddress, networkClientId, }: { nfts: Nft[]; userAddress?: string; networkClientId?: NetworkClientId; }) { + const userAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const addressToSearch = userAddress || userAccount?.address || ''; + const chainId = this.#getCorrectChainId({ networkClientId }); const nftsWithChecksumAdr = nfts.map((nft) => { @@ -1475,7 +1557,7 @@ export class NftController extends BaseController< // We want to avoid updating the state if the state and fetched nft info are the same const nftsWithDifferentMetadata: NftUpdate[] = []; const { allNfts } = this.state; - const stateNfts = allNfts[userAddress]?.[chainId] || []; + const stateNfts = allNfts[addressToSearch]?.[chainId] || []; nftMetadataResults.forEach((singleNft) => { const existingEntry: Nft | undefined = stateNfts.find( @@ -1498,7 +1580,7 @@ export class NftController extends BaseController< if (nftsWithDifferentMetadata.length !== 0) { nftsWithDifferentMetadata.forEach((elm) => - this.updateNft(elm.nft, elm.newMetadata, userAddress, chainId), + this.updateNft(elm.nft, elm.newMetadata, addressToSearch, chainId), ); } } @@ -1517,25 +1599,33 @@ export class NftController extends BaseController< tokenId: string, { networkClientId, - userAddress = this.#selectedAddress, - }: { networkClientId?: NetworkClientId; userAddress?: string } = { - userAddress: this.#selectedAddress, - }, + userAddress, + }: { networkClientId?: NetworkClientId; userAddress?: string } = {}, ) { + const userAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const addressToSearch = userAddress || userAccount?.address || ''; const chainId = this.#getCorrectChainId({ networkClientId }); const checksumHexAddress = toChecksumHexAddress(address); this.#removeIndividualNft(checksumHexAddress, tokenId, { chainId, - userAddress, + userAddress: addressToSearch, }); const { allNfts } = this.state; - const nfts = allNfts[userAddress]?.[chainId] || []; + const nfts = allNfts[addressToSearch]?.[chainId] || []; const remainingNft = nfts.find( (nft) => nft.address.toLowerCase() === checksumHexAddress.toLowerCase(), ); if (!remainingNft) { - this.#removeNftContract(checksumHexAddress, { chainId, userAddress }); + this.#removeNftContract(checksumHexAddress, { + chainId, + userAddress: addressToSearch, + }); } } @@ -1553,24 +1643,32 @@ export class NftController extends BaseController< tokenId: string, { networkClientId, - userAddress = this.#selectedAddress, - }: { networkClientId?: NetworkClientId; userAddress?: string } = { - userAddress: this.#selectedAddress, - }, + userAddress, + }: { networkClientId?: NetworkClientId; userAddress?: string } = {}, ) { + const userAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const addressToSearch = userAddress || userAccount?.address || ''; const chainId = this.#getCorrectChainId({ networkClientId }); const checksumHexAddress = toChecksumHexAddress(address); this.#removeAndIgnoreIndividualNft(checksumHexAddress, tokenId, { chainId, - userAddress, + userAddress: addressToSearch, }); const { allNfts } = this.state; - const nfts = allNfts[userAddress]?.[chainId] || []; + const nfts = allNfts[addressToSearch]?.[chainId] || []; const remainingNft = nfts.find( (nft) => nft.address.toLowerCase() === checksumHexAddress.toLowerCase(), ); if (!remainingNft) { - this.#removeNftContract(checksumHexAddress, { chainId, userAddress }); + this.#removeNftContract(checksumHexAddress, { + chainId, + userAddress: addressToSearch, + }); } } @@ -1598,17 +1696,22 @@ export class NftController extends BaseController< nft: Nft, batch: boolean, { - userAddress = this.#selectedAddress, + userAddress, networkClientId, - }: { networkClientId?: NetworkClientId; userAddress?: string } = { - userAddress: this.#selectedAddress, - }, + }: { networkClientId?: NetworkClientId; userAddress?: string } = {}, ) { + const userAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const addressToSearch = userAddress || userAccount?.address || ''; const chainId = this.#getCorrectChainId({ networkClientId }); const { address, tokenId } = nft; let isOwned = nft.isCurrentlyOwned; try { - isOwned = await this.isNftOwner(userAddress, address, tokenId, { + isOwned = await this.isNftOwner(addressToSearch, address, tokenId, { networkClientId, }); } catch { @@ -1628,7 +1731,7 @@ export class NftController extends BaseController< // if this is not part of a batched update we update this one NFT in state const { allNfts } = this.state; - const nfts = [...(allNfts[userAddress]?.[chainId] || [])]; + const nfts = [...(allNfts[addressToSearch]?.[chainId] || [])]; const indexToUpdate = nfts.findIndex( (item) => item.tokenId === tokenId && @@ -1638,16 +1741,16 @@ export class NftController extends BaseController< if (indexToUpdate !== -1) { nfts[indexToUpdate] = updatedNft; this.update((state) => { - state.allNfts[userAddress] = Object.assign( + state.allNfts[addressToSearch] = Object.assign( {}, - state.allNfts[userAddress], + state.allNfts[addressToSearch], { [chainId]: nfts, }, ); }); this.#updateNestedNftState(nfts, ALL_NFTS_STATE_KEY, { - userAddress, + userAddress: addressToSearch, chainId, }); } @@ -1662,17 +1765,23 @@ export class NftController extends BaseController< * @param options.networkClientId - The networkClientId that can be used to identify the network client to use for this request. * @param options.userAddress - The address of the account where the NFT ownership status is checked/updated. */ - async checkAndUpdateAllNftsOwnershipStatus( - { - networkClientId, - userAddress = this.#selectedAddress, - }: { networkClientId?: NetworkClientId; userAddress?: string } = { - userAddress: this.#selectedAddress, - }, - ) { + async checkAndUpdateAllNftsOwnershipStatus({ + networkClientId, + userAddress, + }: { + networkClientId?: NetworkClientId; + userAddress?: string; + } = {}) { + const userAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const addressToSearch = userAddress || userAccount?.address || ''; const chainId = this.#getCorrectChainId({ networkClientId }); const { allNfts } = this.state; - const nfts = allNfts[userAddress]?.[chainId] || []; + const nfts = allNfts[addressToSearch]?.[chainId] || []; const updatedNfts = await Promise.all( nfts.map(async (nft) => { return ( @@ -1685,7 +1794,7 @@ export class NftController extends BaseController< ); this.#updateNestedNftState(updatedNfts, ALL_NFTS_STATE_KEY, { - userAddress, + userAddress: addressToSearch, chainId, }); } @@ -1706,17 +1815,22 @@ export class NftController extends BaseController< favorite: boolean, { networkClientId, - userAddress = this.#selectedAddress, + userAddress, }: { networkClientId?: NetworkClientId; userAddress?: string; - } = { - userAddress: this.#selectedAddress, - }, + } = {}, ) { + const userAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const addressToSearch = userAddress || userAccount?.address || ''; const chainId = this.#getCorrectChainId({ networkClientId }); const { allNfts } = this.state; - const nfts = [...(allNfts[userAddress]?.[chainId] || [])]; + const nfts = [...(allNfts[addressToSearch]?.[chainId] || [])]; const index: number = nfts.findIndex( (nft) => nft.address === address && nft.tokenId === tokenId, ); @@ -1735,7 +1849,7 @@ export class NftController extends BaseController< this.#updateNestedNftState(nfts, ALL_NFTS_STATE_KEY, { chainId, - userAddress, + userAddress: addressToSearch, }); } diff --git a/packages/assets-controllers/src/NftDetectionController.test.ts b/packages/assets-controllers/src/NftDetectionController.test.ts index 8984134a0e6..4a989d5e155 100644 --- a/packages/assets-controllers/src/NftDetectionController.test.ts +++ b/packages/assets-controllers/src/NftDetectionController.test.ts @@ -1,3 +1,4 @@ +import type { AccountsController } from '@metamask/accounts-controller'; import { ControllerMessenger } from '@metamask/base-controller'; import { NFT_API_BASE_URL, ChainId } from '@metamask/controller-utils'; import { @@ -20,6 +21,7 @@ import * as sinon from 'sinon'; import { FakeBlockTracker } from '../../../tests/fake-block-tracker'; import { FakeProvider } from '../../../tests/fake-provider'; import { advanceTime } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import { buildCustomNetworkClientConfiguration, buildMockGetNetworkClientById, @@ -37,6 +39,8 @@ const DEFAULT_INTERVAL = 180000; const controllerName = 'NftDetectionController' as const; +const defaultSelectedAccount = createMockInternalAccount(); + describe('NftDetectionController', () => { let clock: sinon.SinonFakeTimers; @@ -288,8 +292,17 @@ describe('NftDetectionController', () => { }); it('should poll and detect NFTs on interval while on mainnet', async () => { + const mockGetSelectedAccount = jest + .fn() + .mockReturnValue(defaultSelectedAccount); await withController( - { options: { interval: 10 } }, + { + options: { + interval: 10, + }, + getSelectedAccount: mockGetSelectedAccount, + }, + async ({ controller, controllerEvents }) => { const mockNfts = sinon .stub(controller, 'detectNfts') @@ -317,51 +330,56 @@ describe('NftDetectionController', () => { }); it('should poll and detect NFTs by networkClientId on interval while on mainnet', async () => { - await withController(async ({ controller }) => { - const spy = jest - .spyOn(controller, 'detectNfts') - .mockImplementation(() => { - return Promise.resolve(); - }); + await withController( + { + options: {}, + }, + async ({ controller }) => { + const spy = jest + .spyOn(controller, 'detectNfts') + .mockImplementation(() => { + return Promise.resolve(); + }); - controller.startPollingByNetworkClientId('mainnet', { - address: '0x1', - }); + controller.startPollingByNetworkClientId('mainnet', { + address: '0x1', + }); - await advanceTime({ clock, duration: 0 }); - expect(spy.mock.calls).toHaveLength(1); - await advanceTime({ - clock, - duration: DEFAULT_INTERVAL / 2, - }); - expect(spy.mock.calls).toHaveLength(1); - await advanceTime({ - clock, - duration: DEFAULT_INTERVAL / 2, - }); - expect(spy.mock.calls).toHaveLength(2); - await advanceTime({ clock, duration: DEFAULT_INTERVAL }); - expect(spy.mock.calls).toMatchObject([ - [ - { - networkClientId: 'mainnet', - userAddress: '0x1', - }, - ], - [ - { - networkClientId: 'mainnet', - userAddress: '0x1', - }, - ], - [ - { - networkClientId: 'mainnet', - userAddress: '0x1', - }, - ], - ]); - }); + await advanceTime({ clock, duration: 0 }); + expect(spy.mock.calls).toHaveLength(1); + await advanceTime({ + clock, + duration: DEFAULT_INTERVAL / 2, + }); + expect(spy.mock.calls).toHaveLength(1); + await advanceTime({ + clock, + duration: DEFAULT_INTERVAL / 2, + }); + expect(spy.mock.calls).toHaveLength(2); + await advanceTime({ clock, duration: DEFAULT_INTERVAL }); + expect(spy.mock.calls).toMatchObject([ + [ + { + networkClientId: 'mainnet', + userAddress: '0x1', + }, + ], + [ + { + networkClientId: 'mainnet', + userAddress: '0x1', + }, + ], + [ + { + networkClientId: 'mainnet', + userAddress: '0x1', + }, + ], + ]); + }, + ); }); it('should not rely on the currently selected chain to poll for NFTs when a specific chain is being targeted for polling', async () => { @@ -498,7 +516,9 @@ describe('NftDetectionController', () => { it('should respond to chain ID changing when using legacy polling', async () => { const mockAddNft = jest.fn(); + const mockGetSelectedAccount = jest.fn(); const pollingInterval = 100; + const selectedAccount = createMockInternalAccount({ address: '0x1' }); await withController( { @@ -515,9 +535,9 @@ describe('NftDetectionController', () => { mockNetworkState: { selectedNetworkClientId: 'mainnet', }, - mockPreferencesState: { - selectedAddress: '0x1', - }, + mockPreferencesState: {}, + getSelectedAccount: + mockGetSelectedAccount.mockReturnValue(selectedAccount), }, async ({ controller, controllerEvents }) => { await controller.start(); @@ -595,17 +615,18 @@ describe('NftDetectionController', () => { it('should detect and add NFTs correctly when blockaid result is not included in response', async () => { const mockAddNft = jest.fn(); const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( { options: { addNft: mockAddNft }, - mockPreferencesState: { - selectedAddress, - }, + mockPreferencesState: {}, + getSelectedAccount: jest.fn().mockReturnValue(selectedAccount), }, async ({ controller, controllerEvents }) => { controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -628,7 +649,7 @@ describe('NftDetectionController', () => { standard: 'ERC721', imageOriginal: 'imageOriginal/2574.png', }, - userAddress: selectedAddress, + userAddress: selectedAccount.address, source: Source.Detected, networkClientId: undefined, }, @@ -640,15 +661,18 @@ describe('NftDetectionController', () => { it('should detect and add NFTs correctly when blockaid result is in response', async () => { const mockAddNft = jest.fn(); const selectedAddress = '0x123'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( { options: { addNft: mockAddNft }, - mockPreferencesState: { selectedAddress }, + mockPreferencesState: {}, + getSelectedAccount: jest.fn().mockReturnValue(selectedAccount), }, async ({ controller, controllerEvents }) => { controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -669,7 +693,7 @@ describe('NftDetectionController', () => { standard: 'ERC721', imageOriginal: 'imageOriginal/2574.png', }, - userAddress: selectedAddress, + userAddress: selectedAccount.address, source: Source.Detected, networkClientId: undefined, }); @@ -681,7 +705,7 @@ describe('NftDetectionController', () => { standard: 'ERC721', imageOriginal: 'imageOriginal/2575.png', }, - userAddress: selectedAddress, + userAddress: selectedAccount.address, source: Source.Detected, networkClientId: undefined, }); @@ -692,15 +716,18 @@ describe('NftDetectionController', () => { it('should detect and add NFTs and filter them correctly', async () => { const mockAddNft = jest.fn(); const selectedAddress = '0x12345'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( { options: { addNft: mockAddNft }, - mockPreferencesState: { selectedAddress }, + mockPreferencesState: {}, + getSelectedAccount: jest.fn().mockReturnValue(selectedAccount), }, async ({ controller, controllerEvents }) => { controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -727,7 +754,7 @@ describe('NftDetectionController', () => { standard: 'ERC721', imageOriginal: 'imageOriginal/1.png', }, - userAddress: selectedAddress, + userAddress: selectedAccount.address, source: Source.Detected, networkClientId: undefined, }, @@ -744,7 +771,7 @@ describe('NftDetectionController', () => { standard: 'ERC721', imageOriginal: 'imageOriginal/2.png', }, - userAddress: selectedAddress, + userAddress: selectedAccount.address, source: Source.Detected, networkClientId: undefined, }, @@ -755,13 +782,22 @@ describe('NftDetectionController', () => { it('should detect and add NFTs by networkClientId correctly', async () => { const mockAddNft = jest.fn(); + const mockGetSelectedAccount = jest.fn(); await withController( - { options: { addNft: mockAddNft } }, + { + options: { + addNft: mockAddNft, + }, + getSelectedAccount: mockGetSelectedAccount, + }, async ({ controller, controllerEvents }) => { const selectedAddress = '0x1'; + const updatedSelectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); + mockGetSelectedAccount.mockReturnValue(updatedSelectedAccount); controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -798,6 +834,7 @@ describe('NftDetectionController', () => { it('should not autodetect NFTs that exist in the ignoreList', async () => { const mockAddNft = jest.fn(); + const mockGetSelectedAccount = jest.fn(); const mockGetNftState = jest.fn().mockImplementation(() => { return { ...getDefaultNftControllerState(), @@ -813,15 +850,19 @@ describe('NftDetectionController', () => { }; }); const selectedAddress = '0x9'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( { options: { addNft: mockAddNft, getNftState: mockGetNftState }, mockPreferencesState: { selectedAddress }, + getSelectedAccount: mockGetSelectedAccount, }, async ({ controller, controllerEvents }) => { + mockGetSelectedAccount.mockReturnValue(selectedAccount); controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -840,17 +881,20 @@ describe('NftDetectionController', () => { it('should not detect and add NFTs if there is no selectedAddress', async () => { const mockAddNft = jest.fn(); - const selectedAddress = ''; // Emtpy selected address + // mock uninitialised selectedAccount when it is '' + const mockGetSelectedInternalAccount = jest + .fn() + .mockReturnValue({ address: '' }); await withController( { options: { addNft: mockAddNft }, - mockPreferencesState: { selectedAddress }, + mockPreferencesState: {}, + getSelectedAccount: mockGetSelectedInternalAccount, }, async ({ controller, controllerEvents }) => { controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, - useNftDetection: true, // auto-detect is enabled so it proceeds to check userAddress + useNftDetection: true, // auto-detect is enableds }); await controller.detectNfts(); @@ -913,16 +957,21 @@ describe('NftDetectionController', () => { it('should not detect and add NFTs if preferences controller useNftDetection is set to false', async () => { const mockAddNft = jest.fn(); + const mockGetSelectedAccount = jest.fn(); const selectedAddress = '0x9'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( { options: { addNft: mockAddNft, disabled: false }, - mockPreferencesState: { selectedAddress }, + mockPreferencesState: {}, + getSelectedAccount: mockGetSelectedAccount, }, async ({ controller, controllerEvents }) => { + mockGetSelectedAccount.mockReturnValue(selectedAccount); controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: false, }); // Wait for detect call triggered by preferences state change to settle @@ -940,9 +989,9 @@ describe('NftDetectionController', () => { }); it('should do nothing when the request to Nft API fails', async () => { - const selectedAddress = '0x3'; + const selectedAccount = createMockInternalAccount({ address: '0x3' }); nock(NFT_API_BASE_URL) - .get(`/users/${selectedAddress}/tokens`) + .get(`/users/${selectedAccount.address}/tokens`) .query({ continuation: '', limit: '50', @@ -952,12 +1001,17 @@ describe('NftDetectionController', () => { .replyWithError(new Error('Failed to fetch')) .persist(); const mockAddNft = jest.fn(); + const mockGetSelectedAccount = jest.fn().mockReturnValue(selectedAccount); await withController( - { options: { addNft: mockAddNft } }, + { + options: { + addNft: mockAddNft, + }, + getSelectedAccount: mockGetSelectedAccount, + }, async ({ controller, controllerEvents }) => { controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -976,8 +1030,14 @@ describe('NftDetectionController', () => { it('should rethrow error when Nft APi server fails with error other than fetch failure', async () => { const selectedAddress = '0x4'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( - { mockPreferencesState: { selectedAddress } }, + { + mockPreferencesState: {}, + getSelectedAccount: jest.fn().mockReturnValue(selectedAccount), + }, async ({ controller, controllerEvents }) => { // This mock is for the initial detect call after preferences change nock(NFT_API_BASE_URL) @@ -993,7 +1053,6 @@ describe('NftDetectionController', () => { }); controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -1021,16 +1080,21 @@ describe('NftDetectionController', () => { it('should rethrow error when attempt to add NFT fails', async () => { const mockAddNft = jest.fn(); + const mockGetSelectedAccount = jest.fn(); const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); await withController( { options: { addNft: mockAddNft }, - mockPreferencesState: { selectedAddress }, + mockPreferencesState: {}, + getSelectedAccount: mockGetSelectedAccount, }, async ({ controller, controllerEvents }) => { + mockGetSelectedAccount.mockReturnValue(selectedAccount); controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useNftDetection: true, }); // Wait for detect call triggered by preferences state change to settle @@ -1049,28 +1113,38 @@ describe('NftDetectionController', () => { }); it('should only re-detect when relevant settings change', async () => { - await withController({}, async ({ controller, controllerEvents }) => { - const detectNfts = sinon.stub(controller, 'detectNfts'); + const mockGetSelectedAccount = jest + .fn() + .mockReturnValue(defaultSelectedAccount); + await withController( + { + options: {}, + getSelectedAccount: mockGetSelectedAccount, + }, + async ({ controller, controllerEvents }) => { + const detectNfts = sinon.stub(controller, 'detectNfts'); + + // Repeated preference changes should only trigger 1 detection + for (let i = 0; i < 5; i++) { + controllerEvents.triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useNftDetection: true, + securityAlertsEnabled: true, + }); + } + await advanceTime({ clock, duration: 1 }); + expect(detectNfts.callCount).toBe(1); - // Repeated preference changes should only trigger 1 detection - for (let i = 0; i < 5; i++) { + // Irrelevant preference changes shouldn't trigger a detection controllerEvents.triggerPreferencesStateChange({ ...getDefaultPreferencesState(), useNftDetection: true, + securityAlertsEnabled: true, }); - } - await advanceTime({ clock, duration: 1 }); - expect(detectNfts.callCount).toBe(1); - - // Irrelevant preference changes shouldn't trigger a detection - controllerEvents.triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - useNftDetection: true, - securityAlertsEnabled: true, - }); - await advanceTime({ clock, duration: 1 }); - expect(detectNfts.callCount).toBe(1); - }); + await advanceTime({ clock, duration: 1 }); + expect(detectNfts.callCount).toBe(1); + }, + ); }); }); @@ -1097,6 +1171,7 @@ type WithControllerOptions = { >; mockNetworkState?: Partial; mockPreferencesState?: Partial; + getSelectedAccount?: jest.Mock; }; type WithControllerArgs = @@ -1121,6 +1196,7 @@ async function withController( mockNetworkClientConfigurationsByNetworkClientId = {}, mockNetworkState = {}, mockPreferencesState = {}, + getSelectedAccount = jest.fn().mockReturnValue(defaultSelectedAccount), }, testFunction, ] = args.length === 2 ? args : [{}, args[0]]; @@ -1135,6 +1211,11 @@ async function withController( }), ); + messenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + getSelectedAccount, + ); + const getNetworkClientById = buildMockGetNetworkClientById( mockNetworkClientConfigurationsByNetworkClientId, ); @@ -1157,6 +1238,7 @@ async function withController( 'NetworkController:getState', 'NetworkController:getNetworkClientById', 'PreferencesController:getState', + 'AccountsController:getSelectedAccount', ], allowedEvents: [ 'NetworkController:stateChange', diff --git a/packages/assets-controllers/src/NftDetectionController.ts b/packages/assets-controllers/src/NftDetectionController.ts index 25f7fd6ef4d..0754b890d76 100644 --- a/packages/assets-controllers/src/NftDetectionController.ts +++ b/packages/assets-controllers/src/NftDetectionController.ts @@ -1,3 +1,4 @@ +import type { AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller'; import type { AddApprovalRequest } from '@metamask/approval-controller'; import type { RestrictedControllerMessenger } from '@metamask/base-controller'; import { @@ -37,7 +38,8 @@ export type AllowedActions = | AddApprovalRequest | NetworkControllerGetStateAction | NetworkControllerGetNetworkClientByIdAction - | PreferencesControllerGetStateAction; + | PreferencesControllerGetStateAction + | AccountsControllerGetSelectedAccountAction; export type AllowedEvents = | PreferencesControllerStateChangeEvent @@ -542,8 +544,8 @@ export class NftDetectionController extends StaticIntervalPollingController< }) { const userAddress = options?.userAddress ?? - this.messagingSystem.call('PreferencesController:getState') - .selectedAddress; + this.messagingSystem.call('AccountsController:getSelectedAccount') + .address; /* istanbul ignore if */ if (!this.isMainnet() || this.#disabled) { return; diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 1d722b421ca..5ac19788a47 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -3,6 +3,7 @@ import { toHex } from '@metamask/controller-utils'; import BN from 'bn.js'; import { flushPromises } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import type { AllowedActions, AllowedEvents, @@ -31,7 +32,7 @@ function getMessenger( ): TokenBalancesControllerMessenger { return controllerMessenger.getRestricted({ name: controllerName, - allowedActions: ['PreferencesController:getState'], + allowedActions: ['AccountsController:getSelectedAccount'], allowedEvents: ['TokensController:stateChange'], }); } @@ -52,8 +53,10 @@ describe('TokenBalancesController', () => { it('should set default state', () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ getERC20BalanceOf: jest.fn(), @@ -65,8 +68,10 @@ describe('TokenBalancesController', () => { it('should poll and update balances in the right interval', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const updateBalancesSpy = jest.spyOn( TokenBalancesController.prototype, @@ -91,8 +96,10 @@ describe('TokenBalancesController', () => { it('should update balances if enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: false, @@ -112,8 +119,10 @@ describe('TokenBalancesController', () => { it('should not update balances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: true, @@ -131,8 +140,10 @@ describe('TokenBalancesController', () => { it('should update balances if controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: true, @@ -157,8 +168,10 @@ describe('TokenBalancesController', () => { it('should not update balances if controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: false, @@ -185,8 +198,10 @@ describe('TokenBalancesController', () => { it('should update balances if tokens change and controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: true, @@ -223,8 +238,10 @@ describe('TokenBalancesController', () => { it('should not update balances if tokens change and controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ disabled: false, @@ -262,8 +279,10 @@ describe('TokenBalancesController', () => { it('should clear previous interval', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ interval: 1337, @@ -292,8 +311,12 @@ describe('TokenBalancesController', () => { }, ]; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue( + createMockInternalAccount({ address: selectedAddress }), + ), ); const controller = new TokenBalancesController({ interval: 1337, @@ -327,8 +350,8 @@ describe('TokenBalancesController', () => { ]; controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({}), + 'AccountsController:getSelectedAccount', + jest.fn().mockReturnValue(createMockInternalAccount({ address })), ); const controller = new TokenBalancesController({ interval: 1337, @@ -355,8 +378,10 @@ describe('TokenBalancesController', () => { it('should update balances when tokens change', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ getERC20BalanceOf: jest.fn(), @@ -384,8 +409,10 @@ describe('TokenBalancesController', () => { it('should update token balances when detected tokens are added', async () => { controllerMessenger.registerActionHandler( - 'PreferencesController:getState', - jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + 'AccountsController:getSelectedAccount', + jest + .fn() + .mockReturnValue(createMockInternalAccount({ address: '0x1234' })), ); const controller = new TokenBalancesController({ interval: 1337, diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 280793ef68c..1251b4101c9 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -1,3 +1,4 @@ +import { type AccountsControllerGetSelectedAccountAction } from '@metamask/accounts-controller'; import { type RestrictedControllerMessenger, type ControllerGetStateAction, @@ -5,7 +6,6 @@ import { BaseController, } from '@metamask/base-controller'; import { safelyExecute, toHex } from '@metamask/controller-utils'; -import type { PreferencesControllerGetStateAction } from '@metamask/preferences-controller'; import type { AssetsContractController } from './AssetsContractController'; import type { Token } from './TokenRatesController'; @@ -56,7 +56,7 @@ export type TokenBalancesControllerGetStateAction = ControllerGetStateAction< export type TokenBalancesControllerActions = TokenBalancesControllerGetStateAction; -export type AllowedActions = PreferencesControllerGetStateAction; +export type AllowedActions = AccountsControllerGetSelectedAccountAction; export type TokenBalancesControllerStateChangeEvent = ControllerStateChangeEvent< @@ -195,16 +195,18 @@ export class TokenBalancesController extends BaseController< if (this.#disabled) { return; } - - const { selectedAddress } = this.messagingSystem.call( - 'PreferencesController:getState', + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', ); const newContractBalances: ContractBalances = {}; for (const token of this.#tokens) { const { address } = token; try { - const balance = await this.#getERC20BalanceOf(address, selectedAddress); + const balance = await this.#getERC20BalanceOf( + address, + selectedInternalAccount.address, + ); newContractBalances[address] = toHex(balance); token.hasBalanceError = false; } catch (error) { diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 7b040fefee4..a21c05ff462 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -27,6 +27,7 @@ import nock from 'nock'; import * as sinon from 'sinon'; import { advanceTime } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import { formatAggregatorNames } from './assetsUtil'; import { TOKEN_END_POINT_API } from './token-service'; import type { @@ -144,6 +145,7 @@ function buildTokenDetectionControllerMessenger( return controllerMessenger.getRestricted({ name: controllerName, allowedActions: [ + 'AccountsController:getAccount', 'AccountsController:getSelectedAccount', 'KeyringController:getState', 'NetworkController:getNetworkClientById', @@ -155,7 +157,7 @@ function buildTokenDetectionControllerMessenger( 'PreferencesController:getState', ], allowedEvents: [ - 'AccountsController:selectedAccountChange', + 'AccountsController:selectedEvmAccountChange', 'KeyringController:lock', 'KeyringController:unlock', 'NetworkController:networkDidChange', @@ -166,6 +168,8 @@ function buildTokenDetectionControllerMessenger( } describe('TokenDetectionController', () => { + const defaultSelectedAccount = createMockInternalAccount(); + beforeEach(async () => { nock(TOKEN_END_POINT_API) .get(getTokensPath(ChainId.mainnet)) @@ -203,6 +207,7 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, + options: { selectedAccountId: defaultSelectedAccount.id }, }, async ({ controller }) => { const mockTokens = sinon.stub(controller, 'detectTokens'); @@ -221,8 +226,12 @@ describe('TokenDetectionController', () => { await withController( { isKeyringUnlocked: false, + options: { + selectedAccountId: defaultSelectedAccount.id, + }, }, - async ({ controller, triggerKeyringUnlock }) => { + async ({ controller, mockGetAccount, triggerKeyringUnlock }) => { + mockGetAccount(defaultSelectedAccount); const mockTokens = sinon.stub(controller, 'detectTokens'); await controller.start(); @@ -255,16 +264,24 @@ describe('TokenDetectionController', () => { }); it('should poll and detect tokens on interval while on supported networks', async () => { - await withController(async ({ controller }) => { - const mockTokens = sinon.stub(controller, 'detectTokens'); - controller.setIntervalLength(10); + await withController( + { + options: { + selectedAccountId: defaultSelectedAccount.id, + }, + }, + async ({ controller, mockGetAccount }) => { + mockGetAccount(defaultSelectedAccount); + const mockTokens = sinon.stub(controller, 'detectTokens'); + controller.setIntervalLength(10); - await controller.start(); + await controller.start(); - expect(mockTokens.calledOnce).toBe(true); - await advanceTime({ clock, duration: 15 }); - expect(mockTokens.calledTwice).toBe(true); - }); + expect(mockTokens.calledOnce).toBe(true); + await advanceTime({ clock, duration: 15 }); + expect(mockTokens.calledTwice).toBe(true); + }, + ); }); it('should not autodetect while not on supported networks', async () => { @@ -275,9 +292,11 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + selectedAccountId: defaultSelectedAccount.id, }, }, - async ({ controller, mockNetworkState }) => { + async ({ controller, mockGetAccount, mockNetworkState }) => { + mockGetAccount(defaultSelectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: NetworkType.goerli, @@ -293,15 +312,23 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -329,7 +356,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -340,21 +367,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, + mockGetAccount, mockTokenListGetState, mockNetworkState, mockGetNetworkClientById, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: 'polygon', @@ -393,7 +424,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: '0x89', - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -405,17 +436,25 @@ describe('TokenDetectionController', () => { [sampleTokenA.address]: new BN(1), [sampleTokenB.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); const interval = 100; await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, interval, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { @@ -455,7 +494,7 @@ describe('TokenDetectionController', () => { [sampleTokenA, sampleTokenB], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -466,20 +505,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, + mockGetAccount, mockTokensGetState, mockTokenListGetState, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokensGetState({ ...getDefaultTokensState(), ignoredTokens: [sampleTokenA.address], @@ -521,9 +564,16 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + selectedAccountId: defaultSelectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(defaultSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -569,23 +619,27 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -606,9 +660,8 @@ describe('TokenDetectionController', () => { }, }); - triggerSelectedAccountChange({ - address: secondSelectedAddress, - } as InternalAccount); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).toHaveBeenCalledWith( @@ -616,7 +669,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress: secondSelectedAddress, + selectedAddress: secondSelectedAccount.address, }, ); }, @@ -627,13 +680,15 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ @@ -662,7 +717,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: selectedAddress, + address: selectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -678,16 +733,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, isKeyringUnlocked: false, }, @@ -717,7 +774,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: secondSelectedAddress, + address: secondSelectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -735,16 +792,18 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ @@ -773,7 +832,7 @@ describe('TokenDetectionController', () => { }); triggerSelectedAccountChange({ - address: secondSelectedAddress, + address: secondSelectedAccount.address, } as InternalAccount); await advanceTime({ clock, duration: 1 }); @@ -801,23 +860,28 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -840,17 +904,18 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); - expect(callActionSpy).toHaveBeenCalledWith( + expect(callActionSpy).toHaveBeenLastCalledWith( 'TokensController:addDetectedTokens', [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress: secondSelectedAddress, + selectedAddress: secondSelectedAccount.address, }, ); }, @@ -861,20 +926,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -897,14 +966,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -914,7 +981,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -925,23 +992,28 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, + triggerSelectedAccountChange, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -959,9 +1031,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: false, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -975,20 +1048,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1006,7 +1083,6 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1023,24 +1099,29 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1058,9 +1139,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -1074,21 +1156,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1106,14 +1192,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1132,23 +1216,28 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const firstSelectedAddress = - '0x0000000000000000000000000000000000000001'; - const secondSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const firstSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + const secondSelectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000002', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress: firstSelectedAddress, + selectedAccountId: firstSelectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, + triggerSelectedAccountChange, callActionSpy, }) => { + mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1166,9 +1255,10 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress: secondSelectedAddress, useTokenDetection: true, }); + mockGetAccount(secondSelectedAccount); + triggerSelectedAccountChange(secondSelectedAccount); await advanceTime({ clock, duration: 1 }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -1182,20 +1272,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1213,14 +1307,12 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: false, }); await advanceTime({ clock, duration: 1 }); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), - selectedAddress, useTokenDetection: true, }); await advanceTime({ clock, duration: 1 }); @@ -1249,20 +1341,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1294,7 +1390,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: '0x89', - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1305,20 +1401,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1356,20 +1456,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1403,21 +1507,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1453,20 +1561,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerNetworkDidChange, }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1512,20 +1624,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenList = { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1557,7 +1673,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1568,20 +1684,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: {}, @@ -1603,21 +1723,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, isKeyringUnlocked: false, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1651,20 +1775,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ + mockGetAccount, mockTokenListGetState, callActionSpy, triggerTokenListStateChange, }) => { + mockGetAccount(selectedAccount); const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1707,13 +1835,15 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, mockTokenListGetState }) => { @@ -1773,13 +1903,15 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ @@ -1787,7 +1919,9 @@ describe('TokenDetectionController', () => { mockNetworkState, triggerPreferencesStateChange, callActionSpy, + mockGetAccount, }) => { + mockGetAccount(selectedAccount); mockNetworkState({ ...defaultNetworkState, selectedNetworkClientId: NetworkType.goerli, @@ -1798,7 +1932,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ networkClientId: NetworkType.goerli, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).not.toHaveBeenCalledWith( 'TokensController:addDetectedTokens', @@ -1817,27 +1951,31 @@ describe('TokenDetectionController', () => { {}, ), ); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, + mockGetAccount, triggerPreferencesStateChange, callActionSpy, }) => { + mockGetAccount(selectedAccount); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), useTokenDetection: false, }); await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenLastCalledWith( 'TokensController:addDetectedTokens', @@ -1850,7 +1988,7 @@ describe('TokenDetectionController', () => { }; }), { - selectedAddress, + selectedAddress: selectedAccount.address, chainId: ChainId.mainnet, }, ); @@ -1862,16 +2000,24 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + controller, + mockGetAccount, + mockTokenListGetState, + callActionSpy, + }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1894,7 +2040,7 @@ describe('TokenDetectionController', () => { await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenCalledWith( @@ -1902,7 +2048,7 @@ describe('TokenDetectionController', () => { [sampleTokenA], { chainId: ChainId.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }, ); }, @@ -1913,7 +2059,9 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const selectedAddress = '0x0000000000000000000000000000000000000001'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); const mockTrackMetaMetricsEvent = jest.fn(); await withController( @@ -1922,10 +2070,11 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ controller, mockGetAccount, mockTokenListGetState }) => { + mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -1948,7 +2097,7 @@ describe('TokenDetectionController', () => { await controller.detectTokens({ networkClientId: NetworkType.mainnet, - selectedAddress, + selectedAddress: selectedAccount.address, }); expect(mockTrackMetaMetricsEvent).toHaveBeenCalledWith({ @@ -1980,6 +2129,7 @@ function getTokensPath(chainId: Hex) { type WithControllerCallback = ({ controller, + mockGetAccount, mockGetSelectedAccount, mockKeyringGetState, mockTokensGetState, @@ -1997,6 +2147,7 @@ type WithControllerCallback = ({ triggerNetworkDidChange, }: { controller: TokenDetectionController; + mockGetAccount: (internalAccount: InternalAccount) => void; mockGetSelectedAccount: (address: string) => void; mockKeyringGetState: (state: KeyringControllerState) => void; mockTokensGetState: (state: TokensControllerState) => void; @@ -2047,6 +2198,12 @@ async function withController( const controllerMessenger = messenger ?? new ControllerMessenger(); + const mockGetAccount = jest.fn(); + controllerMessenger.registerActionHandler( + 'AccountsController:getAccount', + mockGetAccount, + ); + const mockGetSelectedAccount = jest.fn(); controllerMessenger.registerActionHandler( 'AccountsController:getSelectedAccount', @@ -2130,6 +2287,9 @@ async function withController( try { return await fn({ controller, + mockGetAccount: (internalAccount: InternalAccount) => { + mockGetAccount.mockReturnValue(internalAccount); + }, mockGetSelectedAccount: (address: string) => { mockGetSelectedAccount.mockReturnValue({ address } as InternalAccount); }, @@ -2185,7 +2345,7 @@ async function withController( }, triggerSelectedAccountChange: (account: InternalAccount) => { controllerMessenger.publish( - 'AccountsController:selectedAccountChange', + 'AccountsController:selectedEvmAccountChange', account, ); }, diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index bbebdca4807..51ed3be0bdc 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -1,6 +1,7 @@ import type { AccountsControllerGetSelectedAccountAction, - AccountsControllerSelectedAccountChangeEvent, + AccountsControllerGetAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, } from '@metamask/accounts-controller'; import type { RestrictedControllerMessenger, @@ -105,6 +106,7 @@ export type TokenDetectionControllerActions = export type AllowedActions = | AccountsControllerGetSelectedAccountAction + | AccountsControllerGetAccountAction | NetworkControllerGetNetworkClientByIdAction | NetworkControllerGetNetworkConfigurationByNetworkClientId | NetworkControllerGetStateAction @@ -121,7 +123,7 @@ export type TokenDetectionControllerEvents = TokenDetectionControllerStateChangeEvent; export type AllowedEvents = - | AccountsControllerSelectedAccountChangeEvent + | AccountsControllerSelectedEvmAccountChangeEvent | NetworkControllerNetworkDidChangeEvent | TokenListStateChange | KeyringControllerLockEvent @@ -153,7 +155,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< > { #intervalId?: ReturnType; - #selectedAddress: string; + #selectedAccountId: string; #networkClientId: NetworkClientId; @@ -186,19 +188,19 @@ export class TokenDetectionController extends StaticIntervalPollingController< * @param options.messenger - The controller messaging system. * @param options.disabled - If set to true, all network requests are blocked. * @param options.interval - Polling interval used to fetch new token rates - * @param options.selectedAddress - Vault selected address + * @param options.selectedAccountId - Vault selected address * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address. * @param options.trackMetaMetricsEvent - Sets options for MetaMetrics event tracking. */ constructor({ - selectedAddress, + selectedAccountId, interval = DEFAULT_INTERVAL, disabled = true, getBalancesInSingleCall, trackMetaMetricsEvent, messenger, }: { - selectedAddress?: string; + selectedAccountId?: string; interval?: number; disabled?: boolean; getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; @@ -223,10 +225,14 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.#disabled = disabled; this.setIntervalLength(interval); - this.#selectedAddress = - selectedAddress ?? - this.messagingSystem.call('AccountsController:getSelectedAccount') - .address; + if (selectedAccountId) { + this.#selectedAccountId = selectedAccountId; + } else { + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', + ); + this.#selectedAccountId = selectedInternalAccount.id; + } const { chainId, networkClientId } = this.#getCorrectChainIdAndNetworkClientId(); @@ -277,32 +283,32 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.messagingSystem.subscribe( 'PreferencesController:stateChange', - async ({ selectedAddress: newSelectedAddress, useTokenDetection }) => { - const isSelectedAddressChanged = - this.#selectedAddress !== newSelectedAddress; + async ({ useTokenDetection }) => { + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getSelectedAccount', + ); const isDetectionChangedFromPreferences = this.#isDetectionEnabledFromPreferences !== useTokenDetection; - this.#selectedAddress = newSelectedAddress; this.#isDetectionEnabledFromPreferences = useTokenDetection; - if (isSelectedAddressChanged || isDetectionChangedFromPreferences) { + if (isDetectionChangedFromPreferences) { await this.#restartTokenDetection({ - selectedAddress: this.#selectedAddress, + selectedAccountId: selectedAccount.id, }); } }, ); this.messagingSystem.subscribe( - 'AccountsController:selectedAccountChange', - async ({ address: newSelectedAddress }) => { - const isSelectedAddressChanged = - this.#selectedAddress !== newSelectedAddress; - if (isSelectedAddressChanged) { - this.#selectedAddress = newSelectedAddress; + 'AccountsController:selectedEvmAccountChange', + async (internalAccount) => { + const didSelectedAccountIdChanged = + this.#selectedAccountId !== internalAccount.id; + if (didSelectedAccountIdChanged) { + this.#selectedAccountId = internalAccount.id; await this.#restartTokenDetection({ - selectedAddress: this.#selectedAddress, + selectedAccountId: this.#selectedAccountId, }); } }, @@ -436,16 +442,23 @@ export class TokenDetectionController extends StaticIntervalPollingController< * in case of address change or user session initialization. * * @param options - Options for restart token detection. - * @param options.selectedAddress - the selectedAddress against which to detect for token balances + * @param options.selectedAccountId - the id of the InternalAccount against which to detect for token balances * @param options.networkClientId - The ID of the network client to use. */ async #restartTokenDetection({ - selectedAddress, + selectedAccountId, networkClientId, }: { - selectedAddress?: string; + selectedAccountId?: string; networkClientId?: NetworkClientId; } = {}): Promise { + const internalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + selectedAccountId ?? this.#selectedAccountId, + ); + + const selectedAddress = internalAccount?.address || ''; + await this.detectTokens({ networkClientId, selectedAddress, @@ -472,8 +485,13 @@ export class TokenDetectionController extends StaticIntervalPollingController< return; } + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + const addressAgainstWhichToDetect = - selectedAddress ?? this.#selectedAddress; + selectedAddress ?? selectedInternalAccount?.address ?? ''; const { chainId, networkClientId: selectedNetworkClientId } = this.#getCorrectChainIdAndNetworkClientId(networkClientId); const chainIdAgainstWhichToDetect = chainId; diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 78fcb2a051a..5cf53a0baec 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -5,13 +5,13 @@ import { toChecksumHexAddress, toHex, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientConfiguration, NetworkClientId, NetworkState, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import { add0x } from '@metamask/utils'; import assert from 'assert'; @@ -19,6 +19,7 @@ import nock from 'nock'; import { useFakeTimers } from 'sinon'; import { advanceTime, flushPromises } from '../../../tests/helpers'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import { buildCustomNetworkClientConfiguration, buildMockGetNetworkClientById, @@ -37,7 +38,9 @@ import type { } from './TokenRatesController'; import type { TokensControllerState } from './TokensController'; -const defaultSelectedAddress = '0x0000000000000000000000000000000000000001'; +const defaultMockInternalAccount = createMockInternalAccount({ + address: '0xA', +}); const mockTokenAddress = '0x0000000000000000000000000000000000000010'; describe('TokenRatesController', () => { @@ -59,10 +62,11 @@ describe('TokenRatesController', () => { it('should set default state', () => { const controller = new TokenRatesController({ getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -75,10 +79,11 @@ describe('TokenRatesController', () => { it('should initialize with the default config', () => { const controller = new TokenRatesController({ getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -91,7 +96,7 @@ describe('TokenRatesController', () => { disabled: false, nativeCurrency: NetworksTicker.mainnet, chainId: '0x1', - selectedAddress: defaultSelectedAddress, + selectedAccountId: defaultMockInternalAccount.id, }); }); @@ -100,10 +105,11 @@ describe('TokenRatesController', () => { new TokenRatesController({ interval: 100, getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -129,18 +135,22 @@ describe('TokenRatesController', () => { describe('when legacy polling is active', () => { it('should update exchange rates when any of the addresses in the "all tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = defaultMockInternalAccount; const tokenAddresses = ['0xE1', '0xE2']; + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -163,7 +173,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -184,19 +194,23 @@ describe('TokenRatesController', () => { it('should update exchange rates when any of the addresses in the "all detected tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); const tokenAddresses = ['0xE1', '0xE2']; + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -219,7 +233,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -239,11 +253,14 @@ describe('TokenRatesController', () => { it('should not update exchange rates if both the "all tokens" or "all detected tokens" are exactly the same', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); const tokensState = { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -259,7 +276,8 @@ describe('TokenRatesController', () => { { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: tokensState, }, @@ -280,10 +298,10 @@ describe('TokenRatesController', () => { it('should not update exchange rates if all of the tokens in "all tokens" just move to "all detected tokens"', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); const tokens = { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -297,11 +315,14 @@ describe('TokenRatesController', () => { { options: { chainId, - selectedAddress, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), }, config: { allTokens: tokens, allDetectedTokens: {}, + selectedAccountId: selectedAccount.id, }, }, async ({ controller, controllerEvents }) => { @@ -324,17 +345,21 @@ describe('TokenRatesController', () => { it('should not update exchange rates if a new token is added to "all detected tokens" but is already present in "all tokens"', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -357,7 +382,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -369,7 +394,7 @@ describe('TokenRatesController', () => { }, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -389,18 +414,22 @@ describe('TokenRatesController', () => { it('should not update exchange rates if a new token is added to "all tokens" but is already present in "all detected tokens"', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -422,7 +451,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -434,7 +463,7 @@ describe('TokenRatesController', () => { }, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -454,18 +483,22 @@ describe('TokenRatesController', () => { it('should not update exchange rates if none of the addresses in "all tokens" or "all detected tokens" change, even if other parts of the token change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 3, @@ -488,7 +521,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: mockTokenAddress, decimals: 7, @@ -508,18 +541,24 @@ describe('TokenRatesController', () => { it('should not update exchange rates if none of the addresses in "all tokens" or "all detected tokens" change, when normalized to checksum addresses', async () => { const chainId = '0xC'; - const selectedAddress = '0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'; + const selectedAccount = createMockInternalAccount({ + address: '0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA', + }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0x0EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE2', decimals: 3, @@ -542,7 +581,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0x0eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee2', decimals: 7, @@ -562,18 +601,22 @@ describe('TokenRatesController', () => { it('should not update exchange rates if any of the addresses in "all tokens" or "all detected tokens" merely change order', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0xE1', decimals: 0, @@ -602,7 +645,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: '0xE2', decimals: 0, @@ -630,18 +673,22 @@ describe('TokenRatesController', () => { describe('when legacy polling is inactive', () => { it('should not update exchange rates when any of the addresses in the "all tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); const tokenAddresses = ['0xE1', '0xE2']; await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -663,7 +710,7 @@ describe('TokenRatesController', () => { controllerEvents.tokensStateChange({ allTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -683,19 +730,23 @@ describe('TokenRatesController', () => { it('should not update exchange rates when any of the addresses in the "all detected tokens" collection change', async () => { const chainId = '0xC'; - const selectedAddress = '0xA'; + const selectedAccount = createMockInternalAccount({ address: '0xA' }); + const mockGetInternalAccount = jest + .fn() + .mockReturnValue(selectedAccount); const tokenAddresses = ['0xE1', '0xE2']; await withController( { options: { chainId, - selectedAddress, + selectedAccountId: selectedAccount.id, + getInternalAccount: mockGetInternalAccount, }, config: { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[0], decimals: 0, @@ -717,7 +768,7 @@ describe('TokenRatesController', () => { allTokens: {}, allDetectedTokens: { [chainId]: { - [selectedAddress]: [ + [selectedAccount.address]: [ { address: tokenAddresses[1], decimals: 0, @@ -763,11 +814,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -801,11 +853,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -839,11 +892,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -875,11 +929,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -911,11 +966,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -951,11 +1007,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -988,11 +1045,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1025,11 +1083,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1060,11 +1119,12 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange, tokenPricesService: buildMockTokenPricesService(), @@ -1082,37 +1142,41 @@ describe('TokenRatesController', () => { }); }); - describe('PreferencesController::stateChange', () => { + describe('onSelectedAccountChange', () => { let clock: sinon.SinonFakeTimers; - beforeEach(() => { clock = useFakeTimers({ now: Date.now() }); }); - afterEach(() => { clock.restore(); }); describe('when polling is active', () => { it('should update exchange rates when selected address changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let preferencesStateChangeListener: (state: any) => Promise; - const onPreferencesStateChange = jest + const alternateSelectedAddress = + '0x0000000000000000000000000000000000000002'; + const alternativeAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); + + let selectedAccountChangeListener: ( + interalAccount: InternalAccount, + ) => Promise; + const onSelectedAccountChange = jest .fn() .mockImplementation((listener) => { - preferencesStateChangeListener = listener; + selectedAccountChangeListener = listener; }); - const alternateSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const controller = new TokenRatesController( { interval: 100, getNetworkClientById: jest.fn(), + getInternalAccount: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange, + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1134,30 +1198,31 @@ describe('TokenRatesController', () => { .mockResolvedValue(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await preferencesStateChangeListener!({ - selectedAddress: alternateSelectedAddress, - }); + await selectedAccountChangeListener!(alternativeAccount); - expect(updateExchangeRatesSpy).toHaveBeenCalledTimes(1); + expect(updateExchangeRatesSpy).toHaveBeenCalled(); }); it('should not update exchange rates when preferences state changes without selected address changing', async () => { // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let preferencesStateChangeListener: (state: any) => Promise; - const onPreferencesStateChange = jest + + let selectedAccountChangeListener: ( + interalAccount: InternalAccount, + ) => Promise; + const onSelectedAccountChange = jest .fn() .mockImplementation((listener) => { - preferencesStateChangeListener = listener; + selectedAccountChangeListener = listener; }); const controller = new TokenRatesController( { interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange, + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1165,7 +1230,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, symbol: '', aggregators: [] }, { address: '0x03', decimals: 0, symbol: '', aggregators: [] }, ], @@ -1179,10 +1244,7 @@ describe('TokenRatesController', () => { .mockResolvedValue(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await preferencesStateChangeListener!({ - selectedAddress: defaultSelectedAddress, - exampleConfig: 'exampleValue', - }); + await selectedAccountChangeListener!(defaultMockInternalAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); @@ -1190,24 +1252,29 @@ describe('TokenRatesController', () => { describe('when polling is inactive', () => { it('should not update exchange rates when selected address changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let preferencesStateChangeListener: (state: any) => Promise; - const onPreferencesStateChange = jest + const alternateSelectedAddress = + '0x0000000000000000000000000000000000000002'; + const alternateAccount = createMockInternalAccount({ + address: alternateSelectedAddress, + }); + let selectedAccountChangeListener: ( + interalAccount: InternalAccount, + ) => Promise; + const onSelectedAccountChange = jest .fn() .mockImplementation((listener) => { - preferencesStateChangeListener = listener; + selectedAccountChangeListener = listener; }); - const alternateSelectedAddress = - '0x0000000000000000000000000000000000000002'; + const controller = new TokenRatesController( { interval: 100, + getInternalAccount: jest.fn(), getNetworkClientById: jest.fn(), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange, + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange, onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService: buildMockTokenPricesService(), @@ -1215,7 +1282,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [alternateSelectedAddress]: [ + [alternateAccount.address]: [ { address: '0x02', decimals: 0, symbol: '', aggregators: [] }, { address: '0x03', decimals: 0, symbol: '', aggregators: [] }, ], @@ -1228,9 +1295,7 @@ describe('TokenRatesController', () => { .mockResolvedValue(); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await preferencesStateChangeListener!({ - selectedAddress: alternateSelectedAddress, - }); + await selectedAccountChangeListener!(alternateAccount); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); @@ -1257,10 +1322,13 @@ describe('TokenRatesController', () => { { interval, getNetworkClientById: jest.fn(), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService, @@ -1268,7 +1336,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1301,10 +1369,13 @@ describe('TokenRatesController', () => { { interval, getNetworkClientById: jest.fn(), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), chainId: '0x1', ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), tokenPricesService, @@ -1312,7 +1383,7 @@ describe('TokenRatesController', () => { { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1356,8 +1427,8 @@ describe('TokenRatesController', () => { interval, chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1366,12 +1437,15 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1408,8 +1482,8 @@ describe('TokenRatesController', () => { { chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1418,12 +1492,15 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, @@ -1513,8 +1590,8 @@ describe('TokenRatesController', () => { { chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1523,12 +1600,15 @@ describe('TokenRatesController', () => { ticker: 'LOL', }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, @@ -1617,8 +1697,8 @@ describe('TokenRatesController', () => { { chainId: '0x2', ticker: 'ETH', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1627,12 +1707,15 @@ describe('TokenRatesController', () => { ticker: 'LOL', }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: '0x02', decimals: 0, @@ -1674,8 +1757,8 @@ describe('TokenRatesController', () => { interval, chainId: '0x2', ticker: 'ticker', - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, + onSelectedAccountChange: jest.fn(), onTokensStateChange: jest.fn(), onNetworkStateChange: jest.fn(), getNetworkClientById: jest.fn().mockReturnValue({ @@ -1684,12 +1767,15 @@ describe('TokenRatesController', () => { ticker: NetworksTicker.mainnet, }, }), + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), tokenPricesService, }, { allTokens: { '0x1': { - [defaultSelectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: mockTokenAddress, decimals: 0, @@ -1728,14 +1814,24 @@ describe('TokenRatesController', () => { ])('%s', (method) => { it('does not update state when disabled', async () => { await withController( - { config: { disabled: true } }, + { + options: { + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + disabled: true, + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { const tokenAddress = '0x0000000000000000000000000000000000000001'; await callUpdateExchangeRatesMethod({ allTokens: { [ChainId.mainnet]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddress, decimals: 18, @@ -1759,51 +1855,65 @@ describe('TokenRatesController', () => { }); it('does not update state if there are no tokens for the given chain and address', async () => { - await withController(async ({ controller, controllerEvents }) => { - const tokenAddress = '0x0000000000000000000000000000000000000001'; - const differentAccount = '0x1000000000000000000000000000000000000000'; + await withController( + { + options: { + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, + async ({ controller, controllerEvents }) => { + const tokenAddress = '0x0000000000000000000000000000000000000001'; + const differentAccount = '0x1000000000000000000000000000000000000000'; - await callUpdateExchangeRatesMethod({ - allTokens: { - // These tokens are for the right chain but wrong account - [ChainId.mainnet]: { - [differentAccount]: [ - { - address: tokenAddress, - decimals: 18, - symbol: 'TST', - aggregators: [], - }, - ], - }, - // These tokens are for the right account but wrong chain - [toHex(2)]: { - [controller.config.selectedAddress]: [ - { - address: tokenAddress, - decimals: 18, - symbol: 'TST', - aggregators: [], - }, - ], + await callUpdateExchangeRatesMethod({ + allTokens: { + // These tokens are for the right chain but wrong account + [ChainId.mainnet]: { + [differentAccount]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], + }, + ], + }, + // These tokens are for the right account but wrong chain + [toHex(2)]: { + [defaultMockInternalAccount.address]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], + }, + ], + }, }, - }, - chainId: ChainId.mainnet, - controller, - controllerEvents, - method, - nativeCurrency: 'ETH', - selectedNetworkClientId: InfuraNetworkType.mainnet, - }); + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + selectedNetworkClientId: InfuraNetworkType.mainnet, + }); - expect(controller.state).toStrictEqual({ - marketData: { - '0x1': { - '0x0000000000000000000000000000000000000000': { currency: 'ETH' }, + expect(controller.state).toStrictEqual({ + marketData: { + '0x1': { + '0x0000000000000000000000000000000000000000': { + currency: 'ETH', + }, + }, }, - }, - }); - }); + }); + }, + ); }); it('does not update state if the price update fails', async () => { @@ -1813,7 +1923,17 @@ describe('TokenRatesController', () => { .mockRejectedValue(new Error('Failed to fetch')), }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { const tokenAddress = '0x0000000000000000000000000000000000000001'; @@ -1822,7 +1942,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [ChainId.mainnet]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddress, decimals: 18, @@ -1866,13 +1986,19 @@ describe('TokenRatesController', () => { options: { ticker, tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { [chainId]: { - [controller.config.selectedAddress]: tokens, + [defaultMockInternalAccount.address]: tokens, }, }, chainId, @@ -1922,12 +2048,22 @@ describe('TokenRatesController', () => { }), }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { [ChainId.mainnet]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -1994,12 +2130,20 @@ describe('TokenRatesController', () => { }), }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { selectedAccountId: defaultMockInternalAccount.id }, + }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { [toHex(2)]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2088,6 +2232,12 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2097,7 +2247,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [selectedNetworkClientConfiguration.chainId]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2178,6 +2328,12 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2187,7 +2343,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [selectedNetworkClientConfiguration.chainId]: { - [controller.config.selectedAddress]: tokens, + [defaultMockInternalAccount.address]: tokens, }, }, chainId: selectedNetworkClientConfiguration.chainId, @@ -2251,6 +2407,12 @@ describe('TokenRatesController', () => { { options: { tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, }, mockNetworkClientConfigurationsByNetworkClientId: { [selectedNetworkClientId]: selectedNetworkClientConfiguration, @@ -2260,7 +2422,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { [selectedNetworkClientConfiguration.chainId]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2319,13 +2481,23 @@ describe('TokenRatesController', () => { fetchTokenPrices: fetchTokenPricesMock, }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + getInternalAccount: jest + .fn() + .mockReturnValue(defaultMockInternalAccount), + }, + config: { + selectedAccountId: defaultMockInternalAccount.id, + }, + }, async ({ controller, controllerEvents }) => { const updateExchangeRates = async () => await callUpdateExchangeRatesMethod({ allTokens: { [toHex(1)]: { - [controller.config.selectedAddress]: [ + [defaultMockInternalAccount.address]: [ { address: tokenAddresses[0], decimals: 18, @@ -2382,7 +2554,7 @@ describe('TokenRatesController', () => { */ type ControllerEvents = { networkStateChange: (state: NetworkState) => void; - preferencesStateChange: (state: PreferencesState) => void; + seletedAccountChange: (internalAccount: InternalAccount) => void; tokensStateChange: (state: TokensControllerState) => void; }; @@ -2454,13 +2626,14 @@ async function withController( onNetworkStateChange: (listener) => { controllerEvents.networkStateChange = listener; }, - onPreferencesStateChange: (listener) => { - controllerEvents.preferencesStateChange = listener; + onSelectedAccountChange: (listener) => { + controllerEvents.seletedAccountChange = listener; }, onTokensStateChange: (listener) => { controllerEvents.tokensStateChange = listener; }, - selectedAddress: defaultSelectedAddress, + getInternalAccount: jest.fn(), + selectedAccountId: defaultMockInternalAccount.id, ticker: NetworksTicker.mainnet, tokenPricesService: buildMockTokenPricesService(), ...options, diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index 0f0fa9cd327..734999d2b6f 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -5,13 +5,13 @@ import { FALL_BACK_VS_CURRENCY, toHex, } from '@metamask/controller-utils'; +import { type InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientId, NetworkController, NetworkState, } from '@metamask/network-controller'; import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; import { createDeferredPromise, type Hex } from '@metamask/utils'; import { isEqual } from 'lodash'; @@ -59,7 +59,7 @@ export interface TokenRatesConfig extends BaseConfig { interval: number; nativeCurrency: string; chainId: Hex; - selectedAddress: string; + selectedAccountId: string; allTokens: { [chainId: Hex]: { [key: string]: Token[] } }; allDetectedTokens: { [chainId: Hex]: { [key: string]: Token[] } }; threshold: number; @@ -175,6 +175,8 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< private readonly getNetworkClientById: NetworkController['getNetworkClientById']; + private readonly getInternalAccount: (accountId: string) => InternalAccount; + /** * Creates a TokenRatesController instance. * @@ -184,8 +186,9 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< * @param options.getNetworkClientById - Gets the network client with the given id from the NetworkController. * @param options.chainId - The chain ID of the current network. * @param options.ticker - The ticker for the current network. - * @param options.selectedAddress - The current selected address. - * @param options.onPreferencesStateChange - Allows subscribing to preference controller state changes. + * @param options.getInternalAccount - A callback to get an InternalAccount by id. + * @param options.selectedAccountId - The current selected address. + * @param options.onSelectedAccountChange - Allows subscribing to changes of selected account. * @param options.onTokensStateChange - Allows subscribing to token controller state changes. * @param options.onNetworkStateChange - Allows subscribing to network state changes. * @param options.tokenPricesService - An object in charge of retrieving token prices. @@ -199,8 +202,9 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< getNetworkClientById, chainId: initialChainId, ticker: initialTicker, - selectedAddress: initialSelectedAddress, - onPreferencesStateChange, + selectedAccountId, + getInternalAccount, + onSelectedAccountChange, onTokensStateChange, onNetworkStateChange, tokenPricesService, @@ -210,9 +214,10 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< getNetworkClientById: NetworkController['getNetworkClientById']; chainId: Hex; ticker: string; - selectedAddress: string; - onPreferencesStateChange: ( - listener: (preferencesState: PreferencesState) => void, + selectedAccountId: string; + getInternalAccount: (accountId: string) => InternalAccount; + onSelectedAccountChange: ( + listener: (internalAccount: InternalAccount) => void, ) => void; onTokensStateChange: ( listener: (tokensState: TokensControllerState) => void, @@ -232,7 +237,7 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< disabled: false, nativeCurrency: initialTicker, chainId: initialChainId, - selectedAddress: initialSelectedAddress, + selectedAccountId, allTokens: {}, // TODO: initialize these correctly, maybe as part of BaseControllerV2 migration allDetectedTokens: {}, }; @@ -243,15 +248,16 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< this.initialize(); this.setIntervalLength(interval); this.getNetworkClientById = getNetworkClientById; + this.getInternalAccount = getInternalAccount; this.#tokenPricesService = tokenPricesService; if (config?.disabled) { this.configure({ disabled: true }, false, false); } - onPreferencesStateChange(async ({ selectedAddress }) => { - if (this.config.selectedAddress !== selectedAddress) { - this.configure({ selectedAddress }); + onSelectedAccountChange(async (internalAccount) => { + if (this.config.selectedAccountId !== internalAccount.id) { + this.configure({ selectedAccountId: internalAccount.id }); if (this.#pollState === PollState.Active) { await this.updateExchangeRates(); } @@ -298,10 +304,11 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< * @returns The list of tokens addresses for the current chain */ #getTokenAddresses(chainId: Hex): Hex[] { - const { allTokens, allDetectedTokens } = this.config; - const tokens = allTokens[chainId]?.[this.config.selectedAddress] || []; + const { allTokens, allDetectedTokens, selectedAccountId } = this.config; + const internalAccount = this.getInternalAccount(selectedAccountId); + const tokens = allTokens[chainId]?.[internalAccount.address] || []; const detectedTokens = - allDetectedTokens[chainId]?.[this.config.selectedAddress] || []; + allDetectedTokens[chainId]?.[internalAccount.address] || []; return [ ...new Set( @@ -356,6 +363,7 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< */ async updateExchangeRates() { const { chainId, nativeCurrency } = this.config; + await this.updateExchangeRatesByChainId({ chainId, nativeCurrency, diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 970310d99b2..442ff883741 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -1,4 +1,5 @@ import { Contract } from '@ethersproject/contracts'; +import type { AccountsController } from '@metamask/accounts-controller'; import type { ApprovalStateChange } from '@metamask/approval-controller'; import { ApprovalController, @@ -13,18 +14,18 @@ import { convertHexToDecimal, InfuraNetworkType, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { NetworkClientConfiguration, NetworkClientId, } from '@metamask/network-controller'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; -import { getDefaultPreferencesState } from '@metamask/preferences-controller'; import nock from 'nock'; import * as sinon from 'sinon'; import { v1 as uuidV1 } from 'uuid'; import { FakeProvider } from '../../../tests/fake-provider'; +import { createMockInternalAccount } from '../../accounts-controller/src/tests/mocks'; import type { ExtractAvailableAction, ExtractAvailableEvent, @@ -58,6 +59,10 @@ const uuidV1Mock = jest.mocked(uuidV1); const ERC20StandardMock = jest.mocked(ERC20Standard); const ERC1155StandardMock = jest.mocked(ERC1155Standard); +const defaultMockInternalAccount = createMockInternalAccount({ + address: '0x1', +}); + describe('TokensController', () => { beforeEach(() => { uuidV1Mock.mockReturnValue('9b1deb4d-3b7d-4bad-9bdd-2b0d7b3dcb6d'); @@ -266,32 +271,34 @@ describe('TokensController', () => { it('should add token by selected address', async () => { await withController( - async ({ controller, triggerPreferencesStateChange }) => { + async ({ + controller, + triggerSelectedAccountChange, + getAccountHandler, + }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); const secondAddress = '0x321'; - - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, + const secondAccount = createMockInternalAccount({ + address: secondAddress, }); + + getAccountHandler.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x01', symbol: 'bar', decimals: 2, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: secondAddress, - }); + triggerSelectedAccountChange(secondAccount); expect(controller.state.tokens).toHaveLength(0); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x01', decimals: 2, @@ -408,25 +415,32 @@ describe('TokensController', () => { it('should remove token by selected address', async () => { await withController( - async ({ controller, triggerPreferencesStateChange }) => { + async ({ + controller, + triggerSelectedAccountChange, + getAccountHandler, + }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); const firstAddress = '0x123'; + const firstAccount = createMockInternalAccount({ + address: firstAddress, + }); const secondAddress = '0x321'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, + const secondAccount = createMockInternalAccount({ + address: secondAddress, }); + + getAccountHandler.mockReturnValue(firstAccount); + triggerSelectedAccountChange(firstAccount); await controller.addToken({ address: '0x02', symbol: 'baz', decimals: 2, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: secondAddress, - }); + getAccountHandler.mockReturnValue(secondAccount); + triggerSelectedAccountChange(secondAccount); await controller.addToken({ address: '0x01', symbol: 'bar', @@ -436,10 +450,7 @@ describe('TokensController', () => { controller.ignoreTokens(['0x01']); expect(controller.state.tokens).toHaveLength(0); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: firstAddress, - }); + triggerSelectedAccountChange(firstAccount); expect(controller.state.tokens[0]).toStrictEqual({ address: '0x02', decimals: 2, @@ -522,14 +533,16 @@ describe('TokensController', () => { await withController( async ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, changeNetwork, + getAccountHandler, }) => { const selectedAddress = '0x0001'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, }); + getAccountHandler.mockReturnValue(selectedAccount); + triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -569,14 +582,16 @@ describe('TokensController', () => { await withController( async ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, changeNetwork, + getAccountHandler, }) => { const selectedAddress = '0x0001'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress, + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, }); + getAccountHandler.mockReturnValue(selectedAccount); + triggerSelectedAccountChange(selectedAccount); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -606,15 +621,20 @@ describe('TokensController', () => { await withController( async ({ controller, - triggerPreferencesStateChange, + triggerSelectedAccountChange, changeNetwork, + getAccountHandler, }) => { const selectedAddress1 = '0x0001'; + const selectedAccount1 = createMockInternalAccount({ + address: selectedAddress1, + }); const selectedAddress2 = '0x0002'; - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: selectedAddress1, + const selectedAccount2 = createMockInternalAccount({ + address: selectedAddress2, }); + getAccountHandler.mockReturnValue(selectedAccount1); + triggerSelectedAccountChange(selectedAccount1); changeNetwork({ selectedNetworkClientId: InfuraNetworkType.sepolia }); await controller.addToken({ address: '0x01', @@ -638,10 +658,8 @@ describe('TokensController', () => { controller.ignoreTokens(['0x02']); expect(controller.state.ignoredTokens).toStrictEqual(['0x02']); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: selectedAddress2, - }); + getAccountHandler.mockReturnValue(selectedAccount2); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.ignoredTokens).toHaveLength(0); await controller.addToken({ @@ -780,7 +798,8 @@ describe('TokensController', () => { describe('addToken method', () => { it('should add isERC721 = true when token is an NFT and is in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); const contractAddresses = Object.keys(contractMaps); const erc721ContractAddresses = contractAddresses.filter( (contractAddress) => contractMaps[contractAddress].erc721 === true, @@ -802,7 +821,8 @@ describe('TokensController', () => { }); it('should add isERC721 = true when the token is an NFT but not in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: true }), ); @@ -830,7 +850,8 @@ describe('TokensController', () => { }); it('should add isERC721 = false to token object already in state when token is not an NFT and in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); const contractAddresses = Object.keys(contractMaps); const erc20ContractAddresses = contractAddresses.filter( (contractAddress) => contractMaps[contractAddress].erc20 === true, @@ -852,7 +873,8 @@ describe('TokensController', () => { }); it('should add isERC721 = false when the token is not an NFT and not in our contract-metadata repo', async () => { - await withController(async ({ controller }) => { + await withController(async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); @@ -880,21 +902,26 @@ describe('TokensController', () => { }); it('should throw error if switching networks while adding token', async () => { - await withController(async ({ controller, changeNetwork }) => { - const dummyTokenAddress = - '0x514910771AF9Ca656af840dff83E8264EcF986CA'; + await withController( + async ({ controller, changeNetwork, getAccountHandler }) => { + getAccountHandler.mockReturnValue(defaultMockInternalAccount); + const dummyTokenAddress = + '0x514910771AF9Ca656af840dff83E8264EcF986CA'; - const addTokenPromise = controller.addToken({ - address: dummyTokenAddress, - symbol: 'LINK', - decimals: 18, - }); - changeNetwork({ selectedNetworkClientId: InfuraNetworkType.goerli }); + const addTokenPromise = controller.addToken({ + address: dummyTokenAddress, + symbol: 'LINK', + decimals: 18, + }); + changeNetwork({ + selectedNetworkClientId: InfuraNetworkType.goerli, + }); - await expect(addTokenPromise).rejects.toThrow( - 'TokensController Error: Switched networks while adding token', - ); - }); + await expect(addTokenPromise).rejects.toThrow( + 'TokensController Error: Switched networks while adding token', + ); + }, + ); }); }); @@ -971,7 +998,8 @@ describe('TokensController', () => { async ({ controller, changeNetwork, - triggerPreferencesStateChange, + triggerSelectedAccountChange, + getAccountHandler, }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), @@ -981,13 +1009,13 @@ describe('TokensController', () => { const CONFIGURED_CHAIN = ChainId.sepolia; const CONFIGURED_NETWORK_CLIENT_ID = InfuraNetworkType.sepolia; const CONFIGURED_ADDRESS = '0xConfiguredAddress'; + const configuredAccount = createMockInternalAccount({ + address: CONFIGURED_ADDRESS, + }); changeNetwork({ selectedNetworkClientId: CONFIGURED_NETWORK_CLIENT_ID, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: CONFIGURED_ADDRESS, - }); + triggerSelectedAccountChange(configuredAccount); // A different chain + address const OTHER_CHAIN = '0xOtherChainId'; @@ -1011,6 +1039,8 @@ describe('TokensController', () => { detectedTokenOtherAccount, ] = generateTokens(3); + getAccountHandler.mockReturnValue(configuredAccount); + // Run twice to ensure idempotency for (let i = 0; i < 2; i++) { // Add and detect some tokens on the configured chain + account @@ -1570,7 +1600,6 @@ describe('TokensController', () => { buildMockEthersERC721Contract({ supportsInterface: false }), ); uuidV1Mock.mockReturnValue(requestId); - await controller.watchAsset({ asset, type: 'ERC20' }); expect(controller.state.tokens).toHaveLength(1); @@ -1721,7 +1750,6 @@ describe('TokensController', () => { buildMockEthersERC721Contract({ supportsInterface: false }), ); uuidV1Mock.mockReturnValue(requestId); - await expect( controller.watchAsset({ asset, type: 'ERC20' }), ).rejects.toThrow(errorMessage); @@ -1844,14 +1872,20 @@ describe('TokensController', () => { describe('when PreferencesController:stateChange is published', () => { it('should update tokens list when set address changes', async () => { await withController( - async ({ controller, triggerPreferencesStateChange }) => { + async ({ + controller, + triggerSelectedAccountChange, + getAccountHandler, + }) => { ContractMock.mockReturnValue( buildMockEthersERC721Contract({ supportsInterface: false }), ); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x1', + const selectedAccount = createMockInternalAccount({ address: '0x1' }); + const selectedAccount2 = createMockInternalAccount({ + address: '0x2', }); + getAccountHandler.mockReturnValue(selectedAccount); + triggerSelectedAccountChange(selectedAccount); await controller.addToken({ address: '0x01', symbol: 'A', @@ -1862,10 +1896,8 @@ describe('TokensController', () => { symbol: 'B', decimals: 5, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x2', - }); + getAccountHandler.mockReturnValue(selectedAccount2); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([]); await controller.addToken({ @@ -1873,10 +1905,7 @@ describe('TokensController', () => { symbol: 'C', decimals: 6, }); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x1', - }); + triggerSelectedAccountChange(selectedAccount); expect(controller.state.tokens).toStrictEqual([ { address: '0x01', @@ -1900,10 +1929,7 @@ describe('TokensController', () => { }, ]); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - selectedAddress: '0x2', - }); + triggerSelectedAccountChange(selectedAccount2); expect(controller.state.tokens).toStrictEqual([ { address: '0x03', @@ -2010,6 +2036,9 @@ describe('TokensController', () => { describe('Clearing nested lists', () => { it('should clear nest allTokens under chain ID and selected address when an added token is ignored', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2025,10 +2054,11 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller }) => { + async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(selectedAccount); await controller.addTokens(dummyTokens); controller.ignoreTokens([tokenAddress]); @@ -2041,6 +2071,9 @@ describe('TokensController', () => { it('should clear nest allIgnoredTokens under chain ID and selected address when an ignored token is re-added', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2056,10 +2089,11 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller }) => { + async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(selectedAccount); await controller.addTokens(dummyTokens); controller.ignoreTokens([tokenAddress]); await controller.addTokens(dummyTokens); @@ -2073,6 +2107,9 @@ describe('TokensController', () => { it('should clear nest allDetectedTokens under chain ID and selected address when an detected token is added to tokens list', async () => { const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); const tokenAddress = '0x01'; const dummyTokens = [ { @@ -2088,10 +2125,11 @@ describe('TokensController', () => { { options: { chainId: ChainId.mainnet, - selectedAddress, + selectedAccountId: selectedAccount.id, }, }, - async ({ controller }) => { + async ({ controller, getAccountHandler }) => { + getAccountHandler.mockReturnValue(selectedAccount); await controller.addDetectedTokens(dummyTokens); await controller.addTokens(dummyTokens); @@ -2165,7 +2203,7 @@ type WithControllerCallback = ({ changeNetwork, messenger, approvalController, - triggerPreferencesStateChange, + triggerSelectedAccountChange, }: { controller: TokensController; changeNetwork: (networkControllerState: { @@ -2173,7 +2211,11 @@ type WithControllerCallback = ({ }) => void; messenger: UnrestrictedMessenger; approvalController: ApprovalController; - triggerPreferencesStateChange: (state: PreferencesState) => void; + triggerSelectedAccountChange: (internalAccount: InternalAccount) => void; + getAccountHandler: jest.Mock< + ReturnType, + Parameters + >; }) => Promise | ReturnValue; type WithControllerArgs = @@ -2227,16 +2269,17 @@ async function withController( allowedActions: [ 'ApprovalController:addRequest', 'NetworkController:getNetworkClientById', + 'AccountsController:getAccount', ], allowedEvents: [ 'NetworkController:networkDidChange', - 'PreferencesController:stateChange', + 'AccountsController:selectedEvmAccountChange', 'TokenListController:stateChange', ], }); const controller = new TokensController({ chainId: ChainId.mainnet, - selectedAddress: '0x1', + selectedAccountId: defaultMockInternalAccount.id, // The tests assume that this is set, but they shouldn't make that // assumption. But we have to do this due to a bug in TokensController // where the provider can possibly be `undefined` if `networkClientId` is @@ -2246,10 +2289,20 @@ async function withController( ...options, }); - const triggerPreferencesStateChange = (state: PreferencesState) => { - messenger.publish('PreferencesController:stateChange', state, []); + const triggerSelectedAccountChange = (internalAccount: InternalAccount) => { + messenger.publish( + 'AccountsController:selectedEvmAccountChange', + internalAccount, + ); }; + const getAccountHandler = jest.fn(); + + messenger.registerActionHandler( + `AccountsController:getAccount`, + getAccountHandler.mockReturnValue(defaultMockInternalAccount), + ); + const changeNetwork = ({ selectedNetworkClientId, }: { @@ -2274,7 +2327,8 @@ async function withController( changeNetwork, messenger, approvalController, - triggerPreferencesStateChange, + triggerSelectedAccountChange, + getAccountHandler, }); } diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index ce7cb493deb..4b843d29062 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -1,5 +1,9 @@ import { Contract } from '@ethersproject/contracts'; import { Web3Provider } from '@ethersproject/providers'; +import type { + AccountsControllerGetAccountAction, + AccountsControllerSelectedEvmAccountChangeEvent, +} from '@metamask/accounts-controller'; import type { AddApprovalRequest } from '@metamask/approval-controller'; import type { RestrictedControllerMessenger, @@ -19,6 +23,7 @@ import { isValidHexAddress, safelyExecute, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; import { abiERC721 } from '@metamask/metamask-eth-abis'; import type { NetworkClientId, @@ -27,10 +32,6 @@ import type { NetworkState, Provider, } from '@metamask/network-controller'; -import type { - PreferencesControllerStateChangeEvent, - PreferencesState, -} from '@metamask/preferences-controller'; import { rpcErrors } from '@metamask/rpc-errors'; import type { Hex } from '@metamask/utils'; import { Mutex } from 'async-mutex'; @@ -136,7 +137,8 @@ export type TokensControllerAddDetectedTokensAction = { */ export type AllowedActions = | AddApprovalRequest - | NetworkControllerGetNetworkClientByIdAction; + | NetworkControllerGetNetworkClientByIdAction + | AccountsControllerGetAccountAction; export type TokensControllerStateChangeEvent = ControllerStateChangeEvent< typeof controllerName, @@ -147,8 +149,8 @@ export type TokensControllerEvents = TokensControllerStateChangeEvent; export type AllowedEvents = | NetworkControllerNetworkDidChangeEvent - | PreferencesControllerStateChangeEvent - | TokenListStateChange; + | TokenListStateChange + | AccountsControllerSelectedEvmAccountChangeEvent; /** * The messenger of the {@link TokensController}. @@ -184,7 +186,7 @@ export class TokensController extends BaseController< #chainId: Hex; - #selectedAddress: string; + #selectedAccountId: string; #provider: Provider | undefined; @@ -194,20 +196,20 @@ export class TokensController extends BaseController< * Tokens controller options * @param options - Constructor options. * @param options.chainId - The chain ID of the current network. - * @param options.selectedAddress - Vault selected address + * @param options.selectedAccountId - Vault selected account id * @param options.provider - Network provider. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller messenger. */ constructor({ chainId: initialChainId, - selectedAddress, + selectedAccountId, provider, state, messenger, }: { chainId: Hex; - selectedAddress: string; + selectedAccountId: string; provider: Provider | undefined; state?: Partial; messenger: TokensControllerMessenger; @@ -226,7 +228,7 @@ export class TokensController extends BaseController< this.#provider = provider; - this.#selectedAddress = selectedAddress; + this.#selectedAccountId = selectedAccountId; this.#abortController = new AbortController(); @@ -236,8 +238,8 @@ export class TokensController extends BaseController< ); this.messagingSystem.subscribe( - 'PreferencesController:stateChange', - this.#onPreferenceControllerStateChange.bind(this), + 'AccountsController:selectedEvmAccountChange', + this.#onSelectedAccountChange.bind(this), ); this.messagingSystem.subscribe( @@ -273,29 +275,32 @@ export class TokensController extends BaseController< this.#abortController.abort(); this.#abortController = new AbortController(); this.#chainId = chainId; + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); this.update((state) => { - state.tokens = allTokens[chainId]?.[this.#selectedAddress] || []; + state.tokens = allTokens[chainId]?.[selectedAccount?.address || ''] || []; state.ignoredTokens = - allIgnoredTokens[chainId]?.[this.#selectedAddress] || []; + allIgnoredTokens[chainId]?.[selectedAccount?.address || ''] || []; state.detectedTokens = - allDetectedTokens[chainId]?.[this.#selectedAddress] || []; + allDetectedTokens[chainId]?.[selectedAccount?.address || ''] || []; }); } /** - * Handles the state change of the preference controller. - * @param preferencesState - The new state of the preference controller. - * @param preferencesState.selectedAddress - The current selected address of the preference controller. + * Handles the selected account change in the accounts controller. + * @param selectedAccount - The new selected account */ - #onPreferenceControllerStateChange({ selectedAddress }: PreferencesState) { + #onSelectedAccountChange(selectedAccount: InternalAccount) { const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - this.#selectedAddress = selectedAddress; + this.#selectedAccountId = selectedAccount.id; this.update((state) => { - state.tokens = allTokens[this.#chainId]?.[selectedAddress] ?? []; + state.tokens = allTokens[this.#chainId]?.[selectedAccount.address] ?? []; state.ignoredTokens = - allIgnoredTokens[this.#chainId]?.[selectedAddress] ?? []; + allIgnoredTokens[this.#chainId]?.[selectedAccount.address] ?? []; state.detectedTokens = - allDetectedTokens[this.#chainId]?.[selectedAddress] ?? []; + allDetectedTokens[this.#chainId]?.[selectedAccount.address] ?? []; }); } @@ -357,7 +362,6 @@ export class TokensController extends BaseController< networkClientId?: NetworkClientId; }): Promise { const chainId = this.#chainId; - const selectedAddress = this.#selectedAddress; const releaseLock = await this.#mutex.acquire(); const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; let currentChainId = chainId; @@ -368,8 +372,15 @@ export class TokensController extends BaseController< ).configuration.chainId; } - const accountAddress = interactingAddress || selectedAddress; - const isInteractingWithWalletAccount = accountAddress === selectedAddress; + const internalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const accountAddress = interactingAddress || internalAccount?.address || ''; + const isInteractingWithWalletAccount = + accountAddress === internalAccount?.address; try { address = toChecksumHexAddress(address); @@ -578,10 +589,15 @@ export class TokensController extends BaseController< ) { const releaseLock = await this.#mutex.acquire(); - // Get existing tokens for the chain + account + const internalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + const chainId = detectionDetails?.chainId ?? this.#chainId; + // Previously selectedAddress could be an empty string. This is to preserve the behaviour const accountAddress = - detectionDetails?.selectedAddress ?? this.#selectedAddress; + detectionDetails?.selectedAddress ?? internalAccount?.address ?? ''; const { allTokens, allDetectedTokens, allIgnoredTokens } = this.state; let newTokens = [...(allTokens?.[chainId]?.[accountAddress] ?? [])]; @@ -648,9 +664,17 @@ export class TokensController extends BaseController< // We may be detecting tokens on a different chain/account pair than are currently configured. // Re-point `tokens` and `detectedTokens` to keep them referencing the current chain/account. - newTokens = newAllTokens?.[this.#chainId]?.[this.#selectedAddress] || []; + const currentInternalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const currentAddress = currentInternalAccount?.address || ''; + + newTokens = newAllTokens?.[this.#chainId]?.[currentAddress] || []; newDetectedTokens = - newAllDetectedTokens?.[this.#chainId]?.[this.#selectedAddress] || []; + newAllDetectedTokens?.[this.#chainId]?.[currentAddress] || []; this.update((state) => { state.tokens = newTokens; @@ -806,6 +830,12 @@ export class TokensController extends BaseController< throw rpcErrors.invalidParams(`Invalid address "${asset.address}"`); } + // Validate if account is an evm account + const selectedAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + // Validate contract if (await this.#detectIsERC721(asset.address, networkClientId)) { @@ -896,7 +926,8 @@ export class TokensController extends BaseController< id: this.#generateRandomId(), time: Date.now(), type, - interactingAddress: interactingAddress || this.#selectedAddress, + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + interactingAddress: interactingAddress || selectedAccount?.address || '', }; await this.#requestApproval(suggestedAssetMeta); @@ -940,8 +971,14 @@ export class TokensController extends BaseController< interactingChainId, } = params; const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; + const selectedInternalAccount = this.messagingSystem.call( + 'AccountsController:getAccount', + this.#selectedAccountId, + ); + // Previously selectedAddress could be an empty string. This is to preserve the behaviour + const userAddressToAddTokens = + interactingAddress ?? selectedInternalAccount?.address ?? ''; - const userAddressToAddTokens = interactingAddress ?? this.#selectedAddress; const chainIdToAddTokens = interactingChainId ?? this.#chainId; let newAllTokens = allTokens; diff --git a/packages/transaction-controller/package.json b/packages/transaction-controller/package.json index 3bc8e971216..a83eb445be2 100644 --- a/packages/transaction-controller/package.json +++ b/packages/transaction-controller/package.json @@ -52,6 +52,7 @@ "@metamask/controller-utils": "^11.0.0", "@metamask/eth-query": "^4.0.0", "@metamask/gas-fee-controller": "^17.0.0", + "@metamask/keyring-api": "6.4.0", "@metamask/metamask-eth-abis": "^3.1.1", "@metamask/network-controller": "^19.0.0", "@metamask/nonce-tracker": "^5.0.0", diff --git a/packages/transaction-controller/src/TransactionController.test.ts b/packages/transaction-controller/src/TransactionController.test.ts index 5cfbdb33fb2..5cc0ede74ae 100644 --- a/packages/transaction-controller/src/TransactionController.test.ts +++ b/packages/transaction-controller/src/TransactionController.test.ts @@ -16,6 +16,7 @@ import { } from '@metamask/controller-utils'; import EthQuery from '@metamask/eth-query'; import HttpProvider from '@metamask/ethjs-provider-http'; +import { EthAccountType } from '@metamask/keyring-api'; import type { BlockTracker, NetworkController, @@ -434,6 +435,20 @@ const MOCK_CUSTOM_NETWORK: MockNetwork = { }; const ACCOUNT_MOCK = '0x6bf137f335ea1b8f193b8f6ea92561a60d23a207'; +const INTERNAL_ACCOUNT_MOCK = { + id: '58def058-d35f-49a1-a7ab-e2580565f6f5', + address: ACCOUNT_MOCK, + type: EthAccountType.Eoa, + options: {}, + methods: [], + metadata: { + name: 'Account 1', + keyring: { type: 'HD Key Tree' }, + importTime: 1631619180000, + lastSelected: 1631619180000, + }, +}; + const ACCOUNT_2_MOCK = '0x08f137f335ea1b8f193b8f6ea92561a60d23a211'; const NONCE_MOCK = 12; const ACTION_ID_MOCK = '123456'; @@ -582,7 +597,7 @@ describe('TransactionController', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any getNetworkClientRegistry: () => ({} as any), getPermittedAccounts: async () => [ACCOUNT_MOCK], - getSelectedAddress: () => ACCOUNT_MOCK, + getSelectedAccount: () => INTERNAL_ACCOUNT_MOCK, isMultichainEnabled: false, hooks: {}, onNetworkStateChange: network.subscribe, diff --git a/packages/transaction-controller/src/TransactionController.ts b/packages/transaction-controller/src/TransactionController.ts index 06010fc3148..6abdb08df1f 100644 --- a/packages/transaction-controller/src/TransactionController.ts +++ b/packages/transaction-controller/src/TransactionController.ts @@ -25,6 +25,7 @@ import type { FetchGasFeeEstimateOptions, GasFeeState, } from '@metamask/gas-fee-controller'; +import type { InternalAccount } from '@metamask/keyring-api'; import type { BlockTracker, NetworkClientId, @@ -42,7 +43,7 @@ import type { Transaction as NonceTrackerTransaction, } from '@metamask/nonce-tracker'; import { errorCodes, rpcErrors, providerErrors } from '@metamask/rpc-errors'; -import type { Hex } from '@metamask/utils'; +import type { CaipChainId, Hex } from '@metamask/utils'; import { add0x } from '@metamask/utils'; import { Mutex } from 'async-mutex'; import { MethodRegistry } from 'eth-method-registry'; @@ -297,7 +298,7 @@ export type TransactionControllerOptions = { getNetworkState: () => NetworkState; getPermittedAccounts: (origin?: string) => Promise; getSavedGasFees?: (chainId: Hex) => SavedGasFees | undefined; - getSelectedAddress: () => string; + getSelectedAccount: () => InternalAccount; incomingTransactions?: IncomingTransactionOptions; isMultichainEnabled: boolean; isSimulationEnabled?: () => boolean; @@ -614,7 +615,9 @@ export class TransactionController extends BaseController< private readonly getPermittedAccounts: (origin?: string) => Promise; - private readonly getSelectedAddress: () => string; + private readonly getSelectedAccount: ( + chainId: CaipChainId, + ) => InternalAccount; private readonly getExternalPendingTransactions: ( address: string, @@ -733,7 +736,7 @@ export class TransactionController extends BaseController< * @param options.getNetworkState - Gets the state of the network controller. * @param options.getPermittedAccounts - Get accounts that a given origin has permissions for. * @param options.getSavedGasFees - Gets the saved gas fee config. - * @param options.getSelectedAddress - Gets the address of the currently selected account. + * @param options.getSelectedAccount - Gets the address of the currently selected account. * @param options.incomingTransactions - Configuration options for incoming transaction support. * @param options.isMultichainEnabled - Enable multichain support. * @param options.isSimulationEnabled - Whether new transactions will be automatically simulated. @@ -761,7 +764,7 @@ export class TransactionController extends BaseController< getNetworkState, getPermittedAccounts, getSavedGasFees, - getSelectedAddress, + getSelectedAccount, incomingTransactions = {}, isMultichainEnabled = false, isSimulationEnabled, @@ -802,7 +805,7 @@ export class TransactionController extends BaseController< this.getGasFeeEstimates = getGasFeeEstimates || (() => Promise.resolve({} as GasFeeState)); this.getPermittedAccounts = getPermittedAccounts; - this.getSelectedAddress = getSelectedAddress; + this.getSelectedAccount = getSelectedAccount; this.getExternalPendingTransactions = getExternalPendingTransactions ?? (() => []); this.securityProviderRequest = securityProviderRequest; @@ -1035,7 +1038,7 @@ export class TransactionController extends BaseController< if (origin) { await validateTransactionOrigin( await this.getPermittedAccounts(origin), - this.getSelectedAddress(), + this.getSelectedAccount('eip:155:*').address, txParams.from, origin, ); @@ -3430,7 +3433,7 @@ export class TransactionController extends BaseController< }): IncomingTransactionHelper { const incomingTransactionHelper = new IncomingTransactionHelper({ blockTracker, - getCurrentAccount: this.getSelectedAddress, + getCurrentAccount: this.getSelectedAccount, getLastFetchedBlockNumbers: () => this.state.lastFetchedBlockNumbers, getChainId: chainId ? () => chainId : this.getChainId.bind(this), isEnabled: this.#incomingTransactionOptions.isEnabled, diff --git a/packages/transaction-controller/src/TransactionControllerIntegration.test.ts b/packages/transaction-controller/src/TransactionControllerIntegration.test.ts index 979f88c4525..91349913c88 100644 --- a/packages/transaction-controller/src/TransactionControllerIntegration.test.ts +++ b/packages/transaction-controller/src/TransactionControllerIntegration.test.ts @@ -11,6 +11,8 @@ import { InfuraNetworkType, NetworkType, } from '@metamask/controller-utils'; +import type { InternalAccount } from '@metamask/keyring-api'; +import { EthAccountType, EthMethod } from '@metamask/keyring-api'; import { NetworkController, NetworkClientType, @@ -25,6 +27,7 @@ import assert from 'assert'; import nock from 'nock'; import type { SinonFakeTimers } from 'sinon'; import { useFakeTimers } from 'sinon'; +import { v4 } from 'uuid'; import { advanceTime } from '../../../tests/helpers'; import { mockNetwork } from '../../../tests/mock-network'; @@ -64,7 +67,46 @@ type UnrestrictedControllerMessenger = ControllerMessenger< | TransactionControllerEvents >; +const createMockInternalAccount = ({ + id = v4(), + address = '0x2990079bcdee240329a520d2444386fc119da21a', + name = 'Account 1', + importTime = Date.now(), + lastSelected = Date.now(), +}: { + id?: string; + address?: string; + name?: string; + importTime?: number; + lastSelected?: number; +} = {}): InternalAccount => { + return { + id, + address, + options: {}, + methods: [ + EthMethod.PersonalSign, + EthMethod.Sign, + EthMethod.SignTransaction, + EthMethod.SignTypedDataV1, + EthMethod.SignTypedDataV3, + EthMethod.SignTypedDataV4, + ], + type: EthAccountType.Eoa, + metadata: { + name, + keyring: { type: 'HD Key Tree' }, + importTime, + lastSelected, + }, + } as InternalAccount; +}; + const ACCOUNT_MOCK = '0x6bf137f335ea1b8f193b8f6ea92561a60d23a207'; +const INTERNAL_ACCOUNT_MOCK = createMockInternalAccount({ + address: ACCOUNT_MOCK, +}); + const ACCOUNT_2_MOCK = '0x08f137f335ea1b8f193b8f6ea92561a60d23a211'; const ACCOUNT_3_MOCK = '0xe688b84b23f322a994a53dbf8e15fa82cdb71127'; const infuraProjectId = 'fake-infura-project-id'; @@ -167,7 +209,8 @@ const setupController = async ( getNetworkClientRegistry: networkController.getNetworkClientRegistry.bind(networkController), getPermittedAccounts: async () => [ACCOUNT_MOCK], - getSelectedAddress: () => '0xdeadbeef', + getSelectedAccount: () => + createMockInternalAccount({ address: '0xdeadbeef' }), hooks: {}, isMultichainEnabled: false, messenger, @@ -802,7 +845,7 @@ describe('TransactionController Integration', () => { await setupController({ isMultichainEnabled: true, getPermittedAccounts: async () => [ACCOUNT_MOCK], - getSelectedAddress: () => ACCOUNT_MOCK, + getSelectedAccount: () => INTERNAL_ACCOUNT_MOCK, }); const otherNetworkClientIdOnGoerli = await networkController.upsertNetworkConfiguration( @@ -883,7 +926,7 @@ describe('TransactionController Integration', () => { await setupController({ isMultichainEnabled: true, getPermittedAccounts: async () => [ACCOUNT_MOCK], - getSelectedAddress: () => ACCOUNT_MOCK, + getSelectedAccount: () => INTERNAL_ACCOUNT_MOCK, }); const addTx1 = await transactionController.addTransaction( @@ -1140,10 +1183,13 @@ describe('TransactionController Integration', () => { }); const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); const { networkController, transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, isMultichainEnabled: true, }); @@ -1209,6 +1255,9 @@ describe('TransactionController Integration', () => { it('should start the global incoming transaction helper when no networkClientIds provided', async () => { const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); mockNetwork({ networkClientConfiguration: buildInfuraNetworkClientConfiguration( @@ -1226,7 +1275,7 @@ describe('TransactionController Integration', () => { .reply(200, ETHERSCAN_TRANSACTION_RESPONSE_MOCK); const { transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, }); transactionController.startIncomingTransactionPolling(); @@ -1314,10 +1363,13 @@ describe('TransactionController Integration', () => { }); const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); const { networkController, transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, isMultichainEnabled: true, }); @@ -1410,10 +1462,13 @@ describe('TransactionController Integration', () => { describe('stopIncomingTransactionPolling', () => { it('should not poll for new incoming transactions for the given networkClientId', async () => { const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); const { networkController, transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, }); const networkClients = networkController.getNetworkClientRegistry(); @@ -1454,9 +1509,12 @@ describe('TransactionController Integration', () => { it('should stop the global incoming transaction helper when no networkClientIds provided', async () => { const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); const { transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, }); mockNetwork({ @@ -1490,10 +1548,13 @@ describe('TransactionController Integration', () => { describe('stopAllIncomingTransactionPolling', () => { it('should not poll for incoming transactions on any network client', async () => { const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); const { networkController, transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, }); const networkClients = networkController.getNetworkClientRegistry(); @@ -1534,10 +1595,13 @@ describe('TransactionController Integration', () => { describe('updateIncomingTransactions', () => { it('should add incoming transactions to state with the correct chainId for the given networkClientId without waiting for the next block', async () => { const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); const { networkController, transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, isMultichainEnabled: true, }); @@ -1600,9 +1664,12 @@ describe('TransactionController Integration', () => { it('should update the incoming transactions for the gloablly selected network when no networkClientIds provided', async () => { const selectedAddress = ETHERSCAN_TRANSACTION_BASE_MOCK.to; + const selectedAccountMock = createMockInternalAccount({ + address: selectedAddress, + }); const { transactionController } = await setupController({ - getSelectedAddress: () => selectedAddress, + getSelectedAccount: () => selectedAccountMock, }); mockNetwork({ diff --git a/packages/transaction-controller/src/helpers/IncomingTransactionHelper.test.ts b/packages/transaction-controller/src/helpers/IncomingTransactionHelper.test.ts index 49b39c4effc..6e65f7de1c3 100644 --- a/packages/transaction-controller/src/helpers/IncomingTransactionHelper.test.ts +++ b/packages/transaction-controller/src/helpers/IncomingTransactionHelper.test.ts @@ -32,7 +32,21 @@ const BLOCK_TRACKER_MOCK = { const CONTROLLER_ARGS_MOCK = { blockTracker: BLOCK_TRACKER_MOCK, - getCurrentAccount: () => ADDRESS_MOCK, + getCurrentAccount: () => { + return { + id: '58def058-d35f-49a1-a7ab-e2580565f6f5', + address: ADDRESS_MOCK, + type: 'eip155:eoa' as const, + options: {}, + methods: [], + metadata: { + name: 'Account 1', + keyring: { type: 'HD Key Tree' }, + importTime: 1631619180000, + lastSelected: 1631619180000, + }, + }; + }, getLastFetchedBlockNumbers: () => ({}), getChainId: () => CHAIN_ID_MOCK, remoteTransactionSource: {} as RemoteTransactionSource, @@ -546,7 +560,8 @@ describe('IncomingTransactionHelper', () => { remoteTransactionSource: createRemoteTransactionSourceMock([ TRANSACTION_MOCK_2, ]), - getCurrentAccount: () => undefined as unknown as string, + // @ts-expect-error testing undefined + getCurrentAccount: () => undefined, }); const { blockNumberListener } = await emitBlockTrackerLatestEvent( diff --git a/packages/transaction-controller/src/helpers/IncomingTransactionHelper.ts b/packages/transaction-controller/src/helpers/IncomingTransactionHelper.ts index c6600b48931..b96a777f12d 100644 --- a/packages/transaction-controller/src/helpers/IncomingTransactionHelper.ts +++ b/packages/transaction-controller/src/helpers/IncomingTransactionHelper.ts @@ -1,5 +1,6 @@ +import type { InternalAccount } from '@metamask/keyring-api'; import type { BlockTracker } from '@metamask/network-controller'; -import type { Hex } from '@metamask/utils'; +import type { CaipChainId, Hex } from '@metamask/utils'; import { Mutex } from 'async-mutex'; import EventEmitter from 'events'; @@ -7,6 +8,7 @@ import { incomingTransactionsLogger as log } from '../logger'; import type { RemoteTransactionSource, TransactionMeta } from '../types'; const RECENT_HISTORY_BLOCK_RANGE = 10; +const EVM_WILDCARD_CHAIN_ID = 'eip155:*'; // TODO: Replace `any` with type // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -35,7 +37,7 @@ export class IncomingTransactionHelper { #blockTracker: BlockTracker; - #getCurrentAccount: () => string; + #getCurrentAccount: (chainId: CaipChainId) => InternalAccount; #getLastFetchedBlockNumbers: () => Record; @@ -72,7 +74,7 @@ export class IncomingTransactionHelper { updateTransactions, }: { blockTracker: BlockTracker; - getCurrentAccount: () => string; + getCurrentAccount: (chainId: CaipChainId) => InternalAccount; getLastFetchedBlockNumbers: () => Record; getLocalTransactions?: () => TransactionMeta[]; getChainId: () => Hex; @@ -144,7 +146,7 @@ export class IncomingTransactionHelper { this.#remoteTransactionSource.getLastBlockVariations?.() ?? []; const fromBlock = this.#getFromBlock(latestBlockNumber); - const address = this.#getCurrentAccount(); + const account = this.#getCurrentAccount(EVM_WILDCARD_CHAIN_ID); const currentChainId = this.#getChainId(); let remoteTransactions = []; @@ -152,7 +154,7 @@ export class IncomingTransactionHelper { try { remoteTransactions = await this.#remoteTransactionSource.fetchTransactions({ - address, + address: account.address, currentChainId, fromBlock, limit: this.#transactionLimit, @@ -165,7 +167,8 @@ export class IncomingTransactionHelper { } if (!this.#updateTransactions) { remoteTransactions = remoteTransactions.filter( - (tx) => tx.txParams.to?.toLowerCase() === address.toLowerCase(), + (tx) => + tx.txParams.to?.toLowerCase() === account.address.toLowerCase(), ); } @@ -301,7 +304,9 @@ export class IncomingTransactionHelper { #getBlockNumberKey(additionalKeys: string[]): string { const currentChainId = this.#getChainId(); - const currentAccount = this.#getCurrentAccount()?.toLowerCase(); + const currentAccount = this.#getCurrentAccount( + EVM_WILDCARD_CHAIN_ID, + )?.address.toLowerCase(); return [currentChainId, currentAccount, ...additionalKeys].join('#'); } diff --git a/yarn.lock b/yarn.lock index 52d067820f5..ffc7ddedb27 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1742,6 +1742,8 @@ __metadata: "@metamask/polling-controller": ^8.0.0 "@metamask/preferences-controller": ^13.0.0 "@metamask/rpc-errors": ^6.2.1 + "@metamask/snaps-sdk": ^4.2.0 + "@metamask/snaps-utils": ^7.4.0 "@metamask/utils": ^8.3.0 "@types/bn.js": ^5.1.5 "@types/jest": ^27.4.1 @@ -1752,6 +1754,7 @@ __metadata: bn.js: ^5.2.1 cockatiel: ^3.1.2 deepmerge: ^4.2.2 + immer: ^9.0.6 jest: ^27.5.1 jest-environment-jsdom: ^27.5.1 lodash: ^4.17.21 @@ -1770,6 +1773,7 @@ __metadata: "@metamask/keyring-controller": ^17.0.0 "@metamask/network-controller": ^19.0.0 "@metamask/preferences-controller": ^13.0.0 + "@metamask/snaps-controllers": ^8.1.1 languageName: unknown linkType: soft @@ -2463,7 +2467,7 @@ __metadata: languageName: node linkType: hard -"@metamask/keyring-api@npm:^6.3.1, @metamask/keyring-api@npm:^6.4.0": +"@metamask/keyring-api@npm:6.4.0, @metamask/keyring-api@npm:^6.3.1, @metamask/keyring-api@npm:^6.4.0": version: 6.4.0 resolution: "@metamask/keyring-api@npm:6.4.0" dependencies: @@ -3149,6 +3153,7 @@ __metadata: "@metamask/eth-query": ^4.0.0 "@metamask/ethjs-provider-http": ^0.3.0 "@metamask/gas-fee-controller": ^17.0.0 + "@metamask/keyring-api": 6.4.0 "@metamask/metamask-eth-abis": ^3.1.1 "@metamask/network-controller": ^19.0.0 "@metamask/nonce-tracker": ^5.0.0