diff --git a/copulas/multivariate/gaussian.py b/copulas/multivariate/gaussian.py index a929d7ba..1dd7d657 100644 --- a/copulas/multivariate/gaussian.py +++ b/copulas/multivariate/gaussian.py @@ -70,26 +70,6 @@ def _transform_to_normal(self, X): return stats.norm.ppf(np.column_stack(U)) - def _get_correlation(self, X): - """Compute correlation matrix with transformed data. - - Args: - X (numpy.ndarray): - Data for which the correlation needs to be computed. - - Returns: - numpy.ndarray: - computed correlation matrix. - """ - result = self._transform_to_normal(X) - correlation = pd.DataFrame(data=result).corr().to_numpy() - correlation = np.nan_to_num(correlation, nan=0.0) - # If singular, add some noise to the diagonal - if np.linalg.cond(correlation) > 1.0 / sys.float_info.epsilon: - correlation = correlation + np.identity(correlation.shape[0]) * EPSILON - - return pd.DataFrame(correlation, index=self.columns, columns=self.columns) - @check_valid_values def fit(self, X): """Compute the distribution for each variable and then its correlation matrix. @@ -100,42 +80,88 @@ def fit(self, X): """ LOGGER.info('Fitting %s', self) + # Validate the input data + X = self._validate_input(X) + columns, univariates = self._fit_columns(X) + + self.columns = columns + self.univariates = univariates + + LOGGER.debug('Computing correlation.') + self.correlation = self._get_correlation(X) + self.fitted = True + LOGGER.debug('GaussianMultivariate fitted successfully') + + def _validate_input(self, X): + """Validate the input data.""" if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X) + return X + + def _fit_columns(self, X): + """Fit each column to its distribution.""" columns = [] univariates = [] for column_name, column in X.items(): - if isinstance(self.distribution, dict): - distribution = self.distribution.get(column_name, DEFAULT_DISTRIBUTION) - else: - distribution = self.distribution - + distribution = self._get_distribution_for_column(column_name) LOGGER.debug('Fitting column %s to %s', column_name, distribution) - univariate = get_instance(distribution) - try: - univariate.fit(column) - except BaseException: - log_message = ( - f'Unable to fit to a {distribution} distribution for column {column_name}. ' - 'Using a Gaussian distribution instead.' - ) - LOGGER.info(log_message) - univariate = GaussianUnivariate() - univariate.fit(column) - + univariate = self._fit_column(column, distribution, column_name) columns.append(column_name) univariates.append(univariate) - self.columns = columns - self.univariates = univariates + return columns, univariates + + def _get_distribution_for_column(self, column_name): + """Retrieve the distribution for a given column name.""" + if isinstance(self.distribution, dict): + return self.distribution.get(column_name, DEFAULT_DISTRIBUTION) + + return self.distribution + + def _fit_column(self, column, distribution, column_name): + """Fit a single column to its distribution with exception handling.""" + univariate = get_instance(distribution) + try: + univariate.fit(column) + except Exception as error: + univariate = self._fit_with_fallback_distribution( + column, distribution, column_name, error + ) + + return univariate + + def _fit_with_fallback_distribution(self, column, distribution, column_name, error): + """Fall back to fitting a Gaussian distribution and log the error.""" + log_message = ( + f'Unable to fit to a {distribution} distribution for column {column_name}. ' + 'Using a Gaussian distribution instead.' + ) + LOGGER.info(log_message) + univariate = GaussianUnivariate() + univariate.fit(column) + return univariate - LOGGER.debug('Computing correlation') - self.correlation = self._get_correlation(X) - self.fitted = True + def _get_correlation(self, X): + """Compute correlation matrix with transformed data. - LOGGER.debug('GaussianMultivariate fitted successfully') + Args: + X (numpy.ndarray): + Data for which the correlation needs to be computed. + + Returns: + numpy.ndarray: + computed correlation matrix. + """ + result = self._transform_to_normal(X) + correlation = pd.DataFrame(data=result).corr().to_numpy() + correlation = np.nan_to_num(correlation, nan=0.0) + # If singular, add some noise to the diagonal + if np.linalg.cond(correlation) > 1.0 / sys.float_info.epsilon: + correlation = correlation + np.identity(correlation.shape[0]) * EPSILON + + return pd.DataFrame(correlation, index=self.columns, columns=self.columns) def probability_density(self, X): """Compute the probability density for each point in X. diff --git a/tests/unit/multivariate/test_gaussian.py b/tests/unit/multivariate/test_gaussian.py index 545edcc4..c0c2c616 100644 --- a/tests/unit/multivariate/test_gaussian.py +++ b/tests/unit/multivariate/test_gaussian.py @@ -350,6 +350,111 @@ def test_fit_broken_distribution(self, logger_mock, truncated_mock): assert isinstance(copula.univariates[0], GaussianUnivariate) assert copula.univariates[0]._params == {'loc': np.mean(data), 'scale': np.std(data)} + def test__validate_input_with_dataframe(self): + """Test that `_validate_input` returns the same DataFrame.""" + # Setup + instance = GaussianMultivariate() + input_df = pd.DataFrame({'A': [1, 2, 3]}) + + # Run + result = instance._validate_input(input_df) + + # Assert + pd.testing.assert_frame_equal(result, input_df) + + def test__validate_input_with_non_dataframe(self): + """Test that `_validate_input` converts non-DataFrame input into a DataFrame.""" + # Setup + instance = GaussianMultivariate() + input_data = [[1, 2, 3], [4, 5, 6]] + + # Run + result = instance._validate_input(input_data) + + # Assert + expected_df = pd.DataFrame(input_data) + pd.testing.assert_frame_equal(result, expected_df) + + @patch('copulas.multivariate.gaussian.LOGGER') + def test__fit_columns(self, mock_logger): + """Test that `_fit_columns` fits each column to its distribution.""" + # Setup + instance = GaussianMultivariate() + instance._get_distribution_for_column = Mock(return_value='normal') + instance._fit_column = Mock(return_value='fitted_univariate') + + X = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}) + + # Run + columns, univariates = instance._fit_columns(X) + + # Assert + assert columns == ['A', 'B'] + assert univariates == ['fitted_univariate', 'fitted_univariate'] + instance._get_distribution_for_column.assert_any_call('A') + instance._get_distribution_for_column.assert_any_call('B') + mock_logger.debug.assert_any_call('Fitting column %s to %s', 'A', 'normal') + mock_logger.debug.assert_any_call('Fitting column %s to %s', 'B', 'normal') + + @patch('copulas.multivariate.gaussian.DEFAULT_DISTRIBUTION', new='default_distribution') + def test__get_distribution_for_column_with_dict(self): + """Test that `_get_distribution_for_column` retrieves correct distribution from dict.""" + # Setup + instance = GaussianMultivariate() + instance.distribution = {'A': 'normal', 'B': 'uniform'} + + # Run + result_A = instance._get_distribution_for_column('A') + result_B = instance._get_distribution_for_column('B') + result_C = instance._get_distribution_for_column('C') + + # Assert + assert result_A == 'normal' + assert result_B == 'uniform' + assert result_C == 'default_distribution' + + @patch('copulas.multivariate.gaussian.get_instance') + @patch('copulas.multivariate.gaussian.GaussianUnivariate') + def test__fit_column_with_exception(self, mock_gaussian_univariate, mock_get_instance): + """Test that `_fit_column` falls back to a Gaussian distribution on exception.""" + # Setup + instance = GaussianMultivariate() + column = pd.Series([1, 2, 3]) + distribution = 'normal' + column_name = 'A' + instance._fit_with_fallback_distribution = Mock(return_value='fallback_univariate') + + mock_univariate = Mock() + mock_univariate.fit.side_effect = Exception('Fit error') + mock_get_instance.return_value = mock_univariate + + # Run + result = instance._fit_column(column, distribution, column_name) + + # Assert + instance._fit_with_fallback_distribution.assert_called_once_with( + column, distribution, column_name, mock_univariate.fit.side_effect + ) + assert result == 'fallback_univariate' + + @patch('copulas.multivariate.gaussian.GaussianUnivariate') + def test__fit_with_fallback_distribution(self, mock_gaussian_univariate): + """Test that `_fit_with_fallback_distribution` fits a Gaussian distribution.""" + # Setup + instance = GaussianMultivariate() + column = pd.Series([1, 2, 3]) + distribution = 'normal' + column_name = 'A' + error = Exception('Test error') + mock_gaussian_univariate.return_value = Mock(fit=Mock()) + + # Run + result = instance._fit_with_fallback_distribution(column, distribution, column_name, error) + + # Assert + mock_gaussian_univariate.return_value.fit.assert_called_once_with(column) + assert result == mock_gaussian_univariate.return_value + def test_probability_density(self): """Probability_density computes probability for the given values.""" # Setup