diff --git a/src/connect.js b/src/connect.js index 76ecaad..20598fe 100644 --- a/src/connect.js +++ b/src/connect.js @@ -1,5 +1,5 @@ /* eslint-disable import/no-unresolved, import/extensions */ -import React from 'react'; +import React, { forwardRef } from 'react'; import { connect } from 'react-redux'; import hoistStatics from 'hoist-non-react-statics'; /* eslint-enable import/no-unresolved, import/extensions */ @@ -39,7 +39,11 @@ const easyConnect = injectProps => ( } : mapDispatchToProps; - return connect(mapStateToProps, modifiedMapDispatchToProps, ...otherArgs); + return connect( + mapStateToProps, + modifiedMapDispatchToProps, + ...otherArgs, + ); }; export default (...args) => (WrappedComponent) => { @@ -73,30 +77,34 @@ export default (...args) => (WrappedComponent) => { ); }; - getWrappedInstance = () => - this.innerRef && this.innerRef.getWrappedInstance - ? this.innerRef.getWrappedInstance() - : null; - render() { const { ConnectedComponent } = this; + // eslint-disable-next-line react/prop-types + const { forwardedRef, ...otherProps } = this.props; const passedProps = { - ...this.props, + ...otherProps, ...this.state, - ref: (ref) => { - this.innerRef = ref; - }, + ref: forwardedRef, }; - return ( - - ); + return ; } } EasyConnect.displayName = getDisplayName(wrappedComponentName); + EasyConnect.WrappedComponent = WrappedComponent; + + if (args[3] && args[3].forwardRef) { + // eslint-disable-next-line react/no-multi-comp + const forwarded = forwardRef((props, ref) => ( + + )); + + forwarded.displayName = wrappedComponentName; + forwarded.WrappedComponent = WrappedComponent; + + return hoistStatics(forwarded, WrappedComponent); + } return hoistStatics(EasyConnect, WrappedComponent); };