diff --git a/.changeset/funny-turkeys-sleep.md b/.changeset/funny-turkeys-sleep.md
new file mode 100644
index 00000000..5cb4cb23
--- /dev/null
+++ b/.changeset/funny-turkeys-sleep.md
@@ -0,0 +1,5 @@
+---
+'focus-trap-react': patch
+---
+
+Fix setReturnFocus option as function not being passed node focused prior to activation.
diff --git a/demo/js/demo-setReturnFocus.js b/demo/js/demo-setReturnFocus.js
index 25f8f40e..1f56c09d 100644
--- a/demo/js/demo-setReturnFocus.js
+++ b/demo/js/demo-setReturnFocus.js
@@ -10,7 +10,14 @@ const DemoSetReturnFocusDialog = () => {
const focusTrapOptions = useMemo(
() => ({
- setReturnFocus: '#AlternateReturnFocusElement',
+ setReturnFocus: (prevSelNode) => {
+ // contrived code to prove during tests that the setReturnFocus() option
+ // can be a function that is given a reference to the node that was focused
+ // prior to trap activation
+ return prevSelNode.parentNode.querySelector(
+ '#AlternateReturnFocusElement'
+ );
+ },
onDeactivate: () => setIsTrapActive(false),
}),
[]
diff --git a/src/focus-trap-react.js b/src/focus-trap-react.js
index 722ea1eb..20eb34ce 100644
--- a/src/focus-trap-react.js
+++ b/src/focus-trap-react.js
@@ -42,11 +42,11 @@ class FocusTrap extends React.Component {
// original options provided by the consumer
this.originalOptions = {
- // because of the above `tailoredFocusTrapOptions`, we maintain our own flag for
+ // because of the above `internalOptions`, we maintain our own flag for
// this option, and default it to `true` because that's focus-trap's default
returnFocusOnDeactivate: true,
- // because of the above `tailoredFocusTrapOptions`, we keep these separate since
+ // because of the above `internalOptions`, we keep these separate since
// they're part of the deactivation process which we configure (internally) to
// be shared between focus-trap and focus-trap-react
onDeactivate: null,
@@ -71,7 +71,7 @@ class FocusTrap extends React.Component {
optionName === 'clickOutsideDeactivates'
) {
this.originalOptions[optionName] = focusTrapOptions[optionName];
- continue; // exclude from tailoredFocusTrapOptions
+ continue; // exclude from internalOptions
}
this.internalOptions[optionName] = focusTrapOptions[optionName];
@@ -106,36 +106,63 @@ class FocusTrap extends React.Component {
);
}
- // TODO: Need more test coverage for this function
- getNodeForOption(optionName) {
- const optionValue = this.internalOptions[optionName];
- if (!optionValue) {
- return null;
+ /**
+ * Gets the node for the given option, which is expected to be an option that
+ * can be either a DOM node, a string that is a selector to get a node, `false`
+ * (if a node is explicitly NOT given), or a function that returns any of these
+ * values.
+ * @param {string} optionName
+ * @returns {undefined | false | HTMLElement | SVGElement} Returns
+ * `undefined` if the option is not specified; `false` if the option
+ * resolved to `false` (node explicitly not given); otherwise, the resolved
+ * DOM node.
+ * @throws {Error} If the option is set, not `false`, and is not, or does not
+ * resolve to a node.
+ */
+ getNodeForOption = function (optionName, ...params) {
+ // use internal options first, falling back to original options
+ let optionValue =
+ this.internalOptions[optionName] ?? this.originalOptions[optionName];
+
+ if (typeof optionValue === 'function') {
+ optionValue = optionValue(...params);
}
- let node = optionValue;
+ if (optionValue === true) {
+ optionValue = undefined; // use default value
+ }
- if (typeof optionValue === 'string') {
- node = this.getDocument()?.querySelector(optionValue);
- if (!node) {
- throw new Error(`\`${optionName}\` refers to no known node`);
+ if (!optionValue) {
+ if (optionValue === undefined || optionValue === false) {
+ return optionValue;
}
+ // else, empty string (invalid), null (invalid), 0 (invalid)
+
+ throw new Error(
+ `\`${optionName}\` was specified but was not a node, or did not return a node`
+ );
}
- if (typeof optionValue === 'function') {
- node = optionValue();
+ let node = optionValue; // could be HTMLElement, SVGElement, or non-empty string at this point
+
+ if (typeof optionValue === 'string') {
+ node = this.getDocument()?.querySelector(optionValue); // resolve to node, or null if fails
if (!node) {
- throw new Error(`\`${optionName}\` did not return a node`);
+ throw new Error(
+ `\`${optionName}\` as selector refers to no known node`
+ );
}
}
return node;
- }
+ };
getReturnFocusNode() {
- const node = this.getNodeForOption('setReturnFocus');
-
- return node ? node : this.previouslyFocusedElement;
+ const node = this.getNodeForOption(
+ 'setReturnFocus',
+ this.previouslyFocusedElement
+ );
+ return node ? node : node === false ? false : this.previouslyFocusedElement;
}
/** Update the previously focused element with the currently focused element. */
@@ -395,12 +422,13 @@ FocusTrap.propTypes = {
initialFocus: PropTypes.oneOfType([
PropTypes.instanceOf(ElementType),
PropTypes.string,
- PropTypes.func,
PropTypes.bool,
+ PropTypes.func,
]),
fallbackFocus: PropTypes.oneOfType([
PropTypes.instanceOf(ElementType),
PropTypes.string,
+ // NOTE: does not support `false` as value (or return value from function)
PropTypes.func,
]),
escapeDeactivates: PropTypes.oneOfType([PropTypes.bool, PropTypes.func]),
@@ -412,6 +440,7 @@ FocusTrap.propTypes = {
setReturnFocus: PropTypes.oneOfType([
PropTypes.instanceOf(ElementType),
PropTypes.string,
+ PropTypes.bool,
PropTypes.func,
]),
allowOutsideClick: PropTypes.oneOfType([PropTypes.bool, PropTypes.func]),
diff --git a/test/focus-trap-react.test.js b/test/focus-trap-react.test.js
index d5883474..408a371b 100644
--- a/test/focus-trap-react.test.js
+++ b/test/focus-trap-react.test.js
@@ -76,6 +76,7 @@ FocusTrapExample.propTypes = {
describe('FocusTrap', () => {
let TestFocusTrap;
+ let user;
beforeEach(() => {
// This surpresses React error boundary logs for testing intentionally
@@ -84,6 +85,8 @@ describe('FocusTrap', () => {
jest.spyOn(console, 'error');
global.console.error.mockImplementation(() => {});
TestFocusTrap = mkTestFocusTrap();
+
+ user = userEvent.setup();
});
afterEach(() => {
@@ -221,29 +224,29 @@ describe('FocusTrap', () => {
const deactivateTrapButton = screen.getByText('deactivate trap');
// Tabbing forward through the focus trap and wrapping back to the beginning
- await userEvent.tab();
+ await user.tab();
expect(link2).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link3).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(deactivateTrapButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link1).toHaveFocus();
// Tabbing backward through the focus trap and wrapping back to the beginning
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(deactivateTrapButton).toHaveFocus();
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(link3).toHaveFocus();
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(link2).toHaveFocus();
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(link1).toHaveFocus();
});
@@ -567,6 +570,145 @@ describe('FocusTrap', () => {
});
});
});
+
+ describe('#setReturnFocus', () => {
+ const altTargetTestId = 'alt-target';
+ let preExistingTargetEl;
+
+ beforeEach(() => {
+ preExistingTargetEl = document.createElement('button');
+ preExistingTargetEl.id = altTargetTestId;
+ preExistingTargetEl.setAttribute('data-testid', altTargetTestId);
+ document.body.append(preExistingTargetEl);
+ });
+
+ afterEach(() => {
+ document.body.removeChild(preExistingTargetEl);
+ });
+
+ it('Can be a selector string', async () => {
+ render(
+
+ );
+
+ const activateEl = screen.getByText('activate trap');
+ await user.click(activateEl);
+
+ const deactivateEl = screen.getByText('deactivate trap');
+ await user.click(deactivateEl);
+
+ expect(screen.getByTestId(altTargetTestId)).toHaveFocus();
+ });
+
+ it('Can be an element', async () => {
+ render(
+
+ );
+
+ const activateEl = screen.getByText('activate trap');
+ await user.click(activateEl);
+
+ const deactivateEl = screen.getByText('deactivate trap');
+ await user.click(deactivateEl);
+
+ expect(preExistingTargetEl).toHaveFocus();
+ });
+
+ it('Can be false for no focus return', async () => {
+ render(
+
+ );
+
+ const activateEl = screen.getByText('activate trap');
+ await user.click(activateEl);
+
+ const deactivateEl = screen.getByText('deactivate trap');
+ await user.click(deactivateEl);
+
+ expect(activateEl).not.toHaveFocus();
+ expect(screen.getByTestId(altTargetTestId)).not.toHaveFocus();
+ expect(document.activeElement === document.body).toBeTruthy();
+ });
+
+ it('Can be function that returns selector string', async () => {
+ const handler = jest.fn(() => `#${altTargetTestId}`);
+ render(
+
+ );
+
+ const activateEl = screen.getByText('activate trap');
+ await user.click(activateEl);
+
+ const deactivateEl = screen.getByText('deactivate trap');
+ await user.click(deactivateEl);
+
+ expect(handler).toHaveBeenCalledTimes(1);
+ expect(handler.mock.calls[0][0] === activateEl).toBeTruthy();
+ expect(screen.getByTestId(altTargetTestId)).toHaveFocus();
+ });
+
+ it('Can be function that returns an element', async () => {
+ const handler = jest.fn(() =>
+ document.querySelector(`#${altTargetTestId}`)
+ );
+ render(
+
+ );
+
+ const activateEl = screen.getByText('activate trap');
+ await user.click(activateEl);
+
+ const deactivateEl = screen.getByText('deactivate trap');
+ await user.click(deactivateEl);
+
+ expect(handler).toHaveBeenCalledTimes(1);
+ expect(handler.mock.calls[0][0] === activateEl).toBeTruthy();
+ expect(screen.getByTestId(altTargetTestId)).toHaveFocus();
+ });
+
+ it('Can be function that returns false for no focus return', async () => {
+ const handler = jest.fn(() => false);
+ render(
+
+ );
+
+ const activateEl = screen.getByText('activate trap');
+ await user.click(activateEl);
+
+ const deactivateEl = screen.getByText('deactivate trap');
+ await user.click(deactivateEl);
+
+ expect(handler).toHaveBeenCalledTimes(1);
+ expect(handler.mock.calls[0][0] === activateEl).toBeTruthy();
+ expect(activateEl).not.toHaveFocus();
+ expect(screen.getByTestId(altTargetTestId)).not.toHaveFocus();
+ expect(document.activeElement === document.body).toBeTruthy();
+ });
+ });
});
describe('containerElements prop', () => {
@@ -600,17 +742,17 @@ describe('FocusTrap', () => {
});
// Tabbing forward through the focus trap and wrapping back to the beginning
- await userEvent.tab();
+ await user.tab();
expect(anchor2).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(anchor1).toHaveFocus();
// Tabbing backward through the focus trap and wrapping back to the beginning
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(anchor2).toHaveFocus();
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(anchor1).toHaveFocus();
// DOM cleanup
@@ -634,20 +776,20 @@ describe('FocusTrap', () => {
// Does not activate the focus trap or change which element is currently focused
expect(activateTrapButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('Link 1')).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('Link 2')).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('Link 3')).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('deactivate trap')).toHaveFocus();
// Because the focus trap is not activated, the tab order continues past the trap content
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('after trap content')).toHaveFocus();
});
@@ -711,19 +853,19 @@ describe('FocusTrap', () => {
// Does not activate the focus trap or change which element is currently focused
expect(activateTrapButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('Link 1')).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('Link 2')).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('Link 3')).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(screen.getByText('deactivate trap')).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
const useTwoContainerElementsButton = screen.getByText(
'use two container elements'
);
@@ -743,17 +885,17 @@ describe('FocusTrap', () => {
});
// Tabbing forward through the focus trap and wrapping back to the beginning
- await userEvent.tab();
+ await user.tab();
expect(anchor2).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(anchor1).toHaveFocus();
// Tabbing backward through the focus trap and wrapping back to the beginning
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(anchor2).toHaveFocus();
- await userEvent.tab({ shift: true });
+ await user.tab({ shift: true });
expect(anchor1).toHaveFocus();
// Updates the containerElements prop value to contain zero elements,
@@ -829,28 +971,28 @@ describe('FocusTrap', () => {
const afterTrapContentButton = screen.getByText('after trap content');
// Tab through the page while the trap is deactivated
- await userEvent.tab();
+ await user.tab();
expect(activateTrapButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(beforeTrapContentButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link1).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link2).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link3).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(deactivateTrapButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(afterTrapContentButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(document.body).toHaveFocus();
// Activate the focus trap
@@ -863,16 +1005,16 @@ describe('FocusTrap', () => {
});
// Tab through the page while the trap is activated
- await userEvent.tab();
+ await user.tab();
expect(link2).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link3).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(deactivateTrapButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link1).toHaveFocus();
// Deactivate the focus trap
@@ -955,25 +1097,25 @@ describe('FocusTrap', () => {
});
// Tab through the page while the trap is activated
- await userEvent.tab();
+ await user.tab();
expect(link2).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(link3).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(deactivateTrapButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(pauseTrapButton).toHaveFocus();
// Pause the focus trap
fireEvent.click(pauseTrapButton);
- await userEvent.tab();
+ await user.tab();
expect(afterTrapContentButton).toHaveFocus();
- await userEvent.tab();
+ await user.tab();
expect(unpauseTrapButton).toHaveFocus();
// Unpause the focus trap