diff --git a/examples/react-quickstart/src/components/App.tsx b/examples/react-quickstart/src/components/App.tsx index ca727202..0e9499c0 100644 --- a/examples/react-quickstart/src/components/App.tsx +++ b/examples/react-quickstart/src/components/App.tsx @@ -10,7 +10,7 @@ const App = () => { // disconnect XMTP client when the wallet changes useEffect(() => { - disconnect(); + void disconnect(); // eslint-disable-next-line react-hooks/exhaustive-deps }, [signer]); diff --git a/packages/react-sdk/src/contexts/XMTPContext.tsx b/packages/react-sdk/src/contexts/XMTPContext.tsx index eababc49..27f6f3e5 100644 --- a/packages/react-sdk/src/contexts/XMTPContext.tsx +++ b/packages/react-sdk/src/contexts/XMTPContext.tsx @@ -1,5 +1,5 @@ -import { createContext, useMemo } from "react"; -import type { ContentCodec } from "@xmtp/xmtp-js"; +import { createContext, useMemo, useState } from "react"; +import type { Client, ContentCodec, Signer } from "@xmtp/xmtp-js"; import Dexie from "dexie"; import type { CacheConfiguration, @@ -11,6 +11,10 @@ import { combineMessageProcessors } from "@/helpers/combineMessageProcessors"; import { combineCodecs } from "@/helpers/combineCodecs"; export type XMTPContextValue = { + /** + * The XMTP client instance + */ + client?: Client; /** * Content codecs used by the XMTP client */ @@ -27,6 +31,12 @@ export type XMTPContextValue = { * Message processors for caching */ processors: CachedMessageProcessors; + setClient: React.Dispatch>; + setClientSigner: React.Dispatch>; + /** + * The signer (wallet) to associate with the XMTP client + */ + signer?: Signer | null; }; const initialDb = new Dexie("__XMTP__"); @@ -36,9 +46,15 @@ export const XMTPContext = createContext({ db: initialDb, namespaces: {}, processors: {}, + setClient: () => {}, + setClientSigner: () => {}, }); export type XMTPProviderProps = React.PropsWithChildren & { + /** + * Initial XMTP client instance + */ + client?: Client; /** * An array of cache configurations to support the caching of messages */ @@ -54,9 +70,15 @@ export type XMTPProviderProps = React.PropsWithChildren & { export const XMTPProvider: React.FC = ({ children, + client: initialClient, cacheConfig, dbVersion, }) => { + const [client, setClient] = useState(initialClient); + const [clientSigner, setClientSigner] = useState( + undefined, + ); + // combine all processors into a single object const processors = useMemo( () => combineMessageProcessors(cacheConfig ?? []), @@ -86,12 +108,16 @@ export const XMTPProvider: React.FC = ({ // memo-ize the context value to prevent unnecessary re-renders const value = useMemo( () => ({ + client, codecs, db, namespaces, processors, + setClient, + setClientSigner, + signer: clientSigner, }), - [codecs, db, namespaces, processors], + [client, clientSigner, codecs, db, namespaces, processors], ); return {children}; diff --git a/packages/react-sdk/src/hooks/useClient.test.tsx b/packages/react-sdk/src/hooks/useClient.test.tsx new file mode 100644 index 00000000..1c513291 --- /dev/null +++ b/packages/react-sdk/src/hooks/useClient.test.tsx @@ -0,0 +1,156 @@ +import { it, expect, describe, vi, beforeEach } from "vitest"; +import { act, renderHook, waitFor } from "@testing-library/react"; +import { Client } from "@xmtp/xmtp-js"; +import { Wallet } from "ethers"; +import type { PropsWithChildren } from "react"; +import { useClient } from "@/hooks/useClient"; +import { XMTPProvider } from "@/contexts/XMTPContext"; + +const processUnprocessedMessagesMock = vi.hoisted(() => vi.fn()); + +const TestWrapper: React.FC = ({ + children, + client, +}) => {children}; + +vi.mock("@/helpers/caching/messages", async () => { + const actual = await import("@/helpers/caching/messages"); + return { + ...actual, + processUnprocessedMessages: processUnprocessedMessagesMock, + }; +}); + +describe("useClient", () => { + beforeEach(() => { + processUnprocessedMessagesMock.mockReset(); + }); + + it("should disconnect an active client", async () => { + const disconnectClientMock = vi.fn(); + const mockClient = { + close: disconnectClientMock, + }; + const { result } = renderHook(() => useClient(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(result.current.client).toBeDefined(); + + await act(async () => { + await result.current.disconnect(); + }); + + expect(disconnectClientMock).toHaveBeenCalledTimes(1); + expect(result.current.client).toBeUndefined(); + }); + + it("should not initialize a client if one is already active", async () => { + const mockClient = { + address: "testWalletAddress", + }; + const clientCreateSpy = vi.spyOn(Client, "create"); + const testWallet = Wallet.createRandom(); + + const { result } = renderHook(() => useClient(), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + await act(async () => { + await result.current.initialize({ signer: testWallet }); + }); + + expect(clientCreateSpy).not.toHaveBeenCalled(); + + await waitFor(() => { + expect(processUnprocessedMessagesMock).toBeCalledTimes(1); + }); + }); + + it("should initialize a client if one is not active", async () => { + const testWallet = Wallet.createRandom(); + const mockClient = { + address: "testWalletAddress", + } as unknown as Client; + const clientCreateSpy = vi + .spyOn(Client, "create") + .mockResolvedValue(mockClient); + + const { result } = renderHook(() => useClient(), { + wrapper: ({ children }) => {children}, + }); + + await act(async () => { + await result.current.initialize({ signer: testWallet }); + }); + + expect(clientCreateSpy).toHaveBeenCalledWith(testWallet, { + codecs: [], + privateKeyOverride: undefined, + }); + expect(result.current.client).toBe(mockClient); + expect(result.current.signer).toBe(testWallet); + + await waitFor(() => { + expect(processUnprocessedMessagesMock).toHaveBeenCalledTimes(1); + }); + }); + + it("should throw an error if client initialization fails", async () => { + const testWallet = Wallet.createRandom(); + const testError = new Error("testError"); + vi.spyOn(Client, "create").mockRejectedValue(testError); + const onErrorMock = vi.fn(); + + const { result } = renderHook(() => useClient(onErrorMock)); + + await act(async () => { + await expect( + result.current.initialize({ signer: testWallet }), + ).rejects.toThrow(testError); + }); + + expect(onErrorMock).toBeCalledTimes(1); + expect(onErrorMock).toHaveBeenCalledWith(testError); + expect(result.current.client).toBeUndefined(); + expect(result.current.signer).toBeUndefined(); + expect(result.current.error).toEqual(testError); + }); + + it("should should call the onError callback if processing unprocessed messages fails", async () => { + const testWallet = Wallet.createRandom(); + const testError = new Error("testError"); + const mockClient = { + address: "testWalletAddress", + } as unknown as Client; + const onErrorMock = vi.fn(); + vi.spyOn(Client, "create").mockResolvedValue(mockClient); + processUnprocessedMessagesMock.mockRejectedValue(testError); + + const { result } = renderHook(() => useClient(onErrorMock), { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + await act(async () => { + await result.current.initialize({ signer: testWallet }); + }); + + await waitFor(() => { + expect(onErrorMock).toHaveBeenCalledTimes(1); + expect(onErrorMock).toHaveBeenCalledWith(testError); + expect(result.current.error).toBe(null); + }); + }); +}); diff --git a/packages/react-sdk/src/hooks/useClient.ts b/packages/react-sdk/src/hooks/useClient.ts index 3e1372b0..279eef0e 100644 --- a/packages/react-sdk/src/hooks/useClient.ts +++ b/packages/react-sdk/src/hooks/useClient.ts @@ -1,21 +1,147 @@ -/* c8 ignore start */ -import { useContext } from "react"; +import { useCallback, useContext, useEffect, useRef, useState } from "react"; +import type { ClientOptions, Signer } from "@xmtp/xmtp-js"; +import { Client } from "@xmtp/xmtp-js"; import { XMTPContext } from "../contexts/XMTPContext"; +import type { OnError } from "@/sharedTypes"; +import { processUnprocessedMessages } from "@/helpers/caching/messages"; + +export type InitClientArgs = { + /** + * Provide a XMTP PrivateKeyBundle encoded as a Uint8Array for signing + * + * This is required if `signer` is not specified + */ + keys?: Uint8Array; + /** + * XMTP client options + */ + options?: Partial; + /** + * The signer (wallet) to associate with the XMTP client + */ + signer?: Signer | null; +}; /** * This hook allows you to initialize, disconnect, and access the XMTP client * instance. It also exposes the error and loading states of the client. */ -export const useClient = () => { - const xmtpContext = useContext(XMTPContext); +export const useClient = (onError?: OnError["onError"]) => { + const [isLoading, setIsLoading] = useState(false); + const [error, setError] = useState(null); + // client is initializing + const initializingRef = useRef(false); + // unprocessed messages are being processed + const processingRef = useRef(false); + // unprocessed messages have been processed + const processedRef = useRef(false); + + const { + client, + setClient, + setClientSigner, + signer: clientSigner, + codecs, + db, + processors, + namespaces, + } = useContext(XMTPContext); + + /** + * Initialize an XMTP client + */ + const initialize = useCallback( + async ({ keys, options, signer }: InitClientArgs) => { + // only initialize a client if one doesn't already exist + if (!client && signer) { + // if the client is already initializing, don't do anything + if (initializingRef.current) { + return undefined; + } + + // flag the client as initializing + initializingRef.current = true; + + // reset error state + setError(null); + // reset loading state + setIsLoading(true); + + let xmtpClient: Client; + + try { + // create a new XMTP client with the provided keys, or a wallet + xmtpClient = await Client.create(keys ? null : signer, { + ...options, + codecs: [...(options?.codecs ?? []), ...codecs], + privateKeyOverride: keys, + }); + setClient(xmtpClient); + setClientSigner(signer); + } catch (e) { + setClient(undefined); + setClientSigner(undefined); + setError(e); + onError?.(e); + // re-throw error for upstream consumption + throw e; + } + + setIsLoading(false); + initializingRef.current = false; + + return xmtpClient; + } + return client; + }, + [client, codecs, onError, setClient, setClientSigner], + ); + + /** + * Disconnect the XMTP client + */ + const disconnect = useCallback(async () => { + if (client) { + await client.close(); + setClient(undefined); + setClientSigner(undefined); + } + }, [client, setClient, setClientSigner]); + + /** + * Process unprocessed messages when there's an available client, but only + * do it once + */ + useEffect(() => { + if (client && !processingRef.current && !processedRef.current) { + processingRef.current = true; + setIsLoading(true); + const reprocess = async () => { + try { + await processUnprocessedMessages({ + client, + db, + processors, + namespaces, + }); + processedRef.current = true; + } catch (e) { + onError?.(e); + } finally { + processingRef.current = false; + setIsLoading(false); + } + }; + void reprocess(); + } + }, [client, db, namespaces, onError, processors]); return { - client: xmtpContext.client, - disconnect: xmtpContext.closeClient, - error: xmtpContext.error, - initialize: xmtpContext.initClient, - isLoading: xmtpContext.isLoading, - signer: xmtpContext.signer, + client, + disconnect, + error, + initialize, + isLoading, + signer: clientSigner, }; }; -/* c8 ignore stop */