Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/query-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export {
noop,
partialMatchKey,
replaceEqualDeep,
resolveEnabled,
shouldThrowError,
skipToken,
} from './utils'
Expand Down
29 changes: 26 additions & 3 deletions packages/query-core/src/queryObserver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ export class QueryObserver<
isRefetchError: isError && hasData,
isStale: isStale(query, options),
refetch: this.refetch,
promise: this.#currentThenable,
promise: tagThenable(this.#currentThenable, query.queryHash),
isEnabled: resolveEnabled(options.enabled, query) !== false,
}

Expand All @@ -608,7 +608,7 @@ export class QueryObserver<
const pending =
(this.#currentThenable =
nextResult.promise =
pendingThenable())
tagThenable(pendingThenable<TData>(), query.queryHash))

finalizeThenableIfPossible(pending)
}
Expand All @@ -628,7 +628,11 @@ export class QueryObserver<
}
break
case 'rejected':
if (!isErrorWithoutData || nextResult.error !== prevThenable.reason) {
if (
!isErrorWithoutData ||
nextResult.error !== prevThenable.reason ||
nextResult.fetchStatus === 'fetching'
) {
recreateThenable()
}
break
Expand Down Expand Up @@ -826,3 +830,22 @@ function shouldAssignObserverCurrentProperties<
// basically, just keep previous properties if nothing changed
return false
}

function tagThenable<TThenable extends Thenable<any>>(
thenable: TThenable,
queryHash: string,
): TThenable {
if (!Object.prototype.hasOwnProperty.call(thenable, 'queryHash')) {
Object.defineProperty(thenable, 'queryHash', {
value: queryHash,
enumerable: false,
configurable: true,
})
}
return thenable
}

/**
* @internal
*/
export type PromiseWithHash<T> = Promise<T> & { queryHash?: string }
55 changes: 54 additions & 1 deletion packages/react-query/src/QueryErrorResetBoundary.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
'use client'
import * as React from 'react'

import { useQueryClient } from './QueryClientProvider'

// CONTEXT
export type QueryErrorResetFunction = () => void
export type QueryErrorIsResetFunction = () => boolean
Expand All @@ -10,6 +12,7 @@ export interface QueryErrorResetBoundaryValue {
clearReset: QueryErrorClearResetFunction
isReset: QueryErrorIsResetFunction
reset: QueryErrorResetFunction
register: (queryHash: string) => void
}

function createValue(): QueryErrorResetBoundaryValue {
Expand All @@ -24,6 +27,7 @@ function createValue(): QueryErrorResetBoundaryValue {
isReset: () => {
return isReset
},
register: () => {},
}
}

Expand All @@ -47,10 +51,59 @@ export interface QueryErrorResetBoundaryProps {
export const QueryErrorResetBoundary = ({
children,
}: QueryErrorResetBoundaryProps) => {
const [value] = React.useState(() => createValue())
const client = useQueryClient()
const registeredQueries = React.useRef(new Set<string>())
const [value] = React.useState(() => {
const boundary = createValue()
return {
...boundary,
reset: () => {
boundary.reset()
const queryHashes = new Set(registeredQueries.current)
registeredQueries.current.clear()

void client.refetchQueries({
predicate: (query) =>
queryHashes.has(query.queryHash) && query.state.status === 'error',
type: 'active',
})
},
register: (queryHash: string) => {
registeredQueries.current.add(queryHash)
},
}
})
return (
<QueryErrorResetBoundaryContext.Provider value={value}>
{typeof children === 'function' ? children(value) : children}
</QueryErrorResetBoundaryContext.Provider>
)
}

/**
* @internal
*/
export function getQueryHash(query: any): string | undefined {
if (typeof query === 'object' && query !== null) {
if ('queryHash' in query) {
return query.queryHash
}
if (
'promise' in query &&
query.promise &&
typeof query.promise === 'object' &&
'queryHash' in query.promise
) {
return query.promise.queryHash
}
}
return undefined
}

export function useTrackQueryHash(query: any) {
const { register } = useQueryErrorResetBoundary()
const hash = getQueryHash(query)
if (hash) {
register(hash)
}
}
237 changes: 237 additions & 0 deletions packages/react-query/src/__tests__/QueryResetErrorBoundary.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
useQuery,
useSuspenseQueries,
useSuspenseQuery,
useTrackQueryHash,
} from '..'
import { renderWithClient } from './utils'

Expand Down Expand Up @@ -863,4 +864,240 @@ describe('QueryErrorResetBoundary', () => {
consoleMock.mockRestore()
})
})

describe('Scoped Registry', () => {
it('should isolate resets between different boundaries', async () => {
const consoleMock = vi
.spyOn(console, 'error')
.mockImplementation(() => undefined)
const key1 = queryKey()
const key2 = queryKey()
let count1 = 0
let count2 = 0

function Comp1() {
useQuery({
queryKey: key1,
queryFn: async () => {
await sleep(10)
count1++
throw new Error('fail1')
},
retry: false,
throwOnError: true,
})
return null
}

function Comp2() {
useQuery({
queryKey: key2,
queryFn: async () => {
await sleep(10)
count2++
throw new Error('fail2')
},
retry: false,
throwOnError: true,
})
return null
}

const rendered = renderWithClient(
queryClient,
<>
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<div>
<button onClick={resetErrorBoundary}>reset1</button>
</div>
)}
>
<React.Suspense fallback="loading1">
<Comp1 />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<div>
<button onClick={resetErrorBoundary}>reset2</button>
</div>
)}
>
<React.Suspense fallback="loading2">
<Comp2 />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>
</>,
)

await vi.advanceTimersByTimeAsync(11)
expect(rendered.getByText('reset1')).toBeInTheDocument()
expect(rendered.getByText('reset2')).toBeInTheDocument()
expect(count1).toBe(1)
expect(count2).toBe(1)

fireEvent.click(rendered.getByText('reset1'))

await vi.advanceTimersByTimeAsync(11)
expect(count1).toBe(2)
expect(count2).toBe(1)

consoleMock.mockRestore()
})

it('should clear registry after reset', async () => {
const consoleMock = vi
.spyOn(console, 'error')
.mockImplementation(() => undefined)
const key = queryKey()
let count = 0

function Comp() {
useQuery({
queryKey: key,
queryFn: async () => {
await sleep(10)
count++
throw new Error('fail')
},
retry: false,
throwOnError: true,
})
return null
}

const rendered = renderWithClient(
queryClient,
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<div>
<button onClick={resetErrorBoundary}>reset</button>
</div>
)}
>
<React.Suspense fallback="loading">
<Comp />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>,
)

await vi.advanceTimersByTimeAsync(11)
expect(rendered.getByText('reset')).toBeInTheDocument()
expect(count).toBe(1)

fireEvent.click(rendered.getByText('reset'))
await vi.advanceTimersByTimeAsync(11)
expect(count).toBe(2)

consoleMock.mockRestore()
})

it('should handle StrictMode double registration gracefully', async () => {
const key = queryKey()
let count = 0

function Comp() {
useQuery({
queryKey: key,
queryFn: async () => {
await sleep(10)
count++
return 'ok'
},
})
return null
}

renderWithClient(
queryClient,
<React.StrictMode>
<QueryErrorResetBoundary>
<Comp />
</QueryErrorResetBoundary>
</React.StrictMode>,
)

await vi.advanceTimersByTimeAsync(11)
expect(count).toBeGreaterThanOrEqual(1)
})

it('should support tracking queries outside the boundary via useTrackQueryHash', async () => {
const consoleMock = vi
.spyOn(console, 'error')
.mockImplementation(() => undefined)
const key = queryKey()
let count = 0

function Child() {
const { data } = useSuspenseQuery({
queryKey: key,
queryFn: async () => {
await sleep(10)
count++
if (count === 1) {
throw new Error('fail')
}
return 'ok'
},
retry: false,
})
return <div>{data}</div>
}

function TrackedChild() {
const hash = queryClient
.getQueryCache()
.build(queryClient, { queryKey: key }).queryHash
useTrackQueryHash({ queryHash: hash })
return null
}
}

const rendered = renderWithClient(
queryClient,
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<button onClick={resetErrorBoundary}>retry</button>
)}
>
<React.Suspense fallback="loading">
<TrackedChild />
<Child />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>,
)

await act(() => vi.advanceTimersByTimeAsync(11))
expect(rendered.getByText('retry')).toBeInTheDocument()
expect(count).toBe(1)

fireEvent.click(rendered.getByText('retry'))
await act(() => vi.advanceTimersByTimeAsync(11))
expect(count).toBe(2)
expect(rendered.getByText('ok')).toBeInTheDocument()

consoleMock.mockRestore()
})
})
})
Loading
Loading