diff --git a/cpp/include/cuml/tsa/arima_common.h b/cpp/include/cuml/tsa/arima_common.h
index 1f0358f7a6..2ed9da31e2 100644
--- a/cpp/include/cuml/tsa/arima_common.h
+++ b/cpp/include/cuml/tsa/arima_common.h
@@ -201,8 +201,8 @@ struct ARIMAMemory {
     *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams,
     *Z_dense, *R_dense, *T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense,
     *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *loglike,
-    *loglike_base, *loglike_pert, *x_pert, *sigma2_buffer, *I_m_AxA_dense, *I_m_AxA_inv_dense,
-    *Ts_dense, *RQRs_dense, *Ps_dense;
+    *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense,
+    *RQRs_dense, *Ps_dense;
   T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches, **P_batches,
     **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches,
     **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches,
@@ -279,7 +279,6 @@ struct ARIMAMemory {
     append_buffer(K_batches, batch_size);
     append_buffer(TP_dense, rd * rd * batch_size);
     append_buffer(TP_batches, batch_size);
-    append_buffer(sigma2_buffer, batch_size);
 
     append_buffer(pred, n_obs * batch_size);
     append_buffer(y_diff, n_obs * batch_size);
diff --git a/cpp/src/arima/batched_kalman.cu b/cpp/src/arima/batched_kalman.cu
index 2b043dc3a4..7eb0cd2efe 100644
--- a/cpp/src/arima/batched_kalman.cu
+++ b/cpp/src/arima/batched_kalman.cu
@@ -96,7 +96,6 @@ DI void MM_l(const double* A, const double* B, double* out)
  * @param[in]  batch_size  Batch size
  * @param[out] d_pred      Predictions (nobs)
  * @param[out] d_loglike   Log-likelihood (1)
- * @param[out] d_ll_sigma2 Sigma^2 term in the log-likelihood (1)
  * @param[in]  n_diff      d + s*D
  * @param[in]  fc_steps    Number of steps to forecast
  * @param[out] d_fc        Array to store the forecast
@@ -116,7 +115,6 @@ __global__ void batched_kalman_loop_kernel(const double* ys,
                                            int batch_size,
                                            double* d_pred,
                                            double* d_loglike,
-                                           double* d_ll_sigma2,
                                            int n_diff,
                                            int fc_steps = 0,
                                            double* d_fc = nullptr,
@@ -257,7 +255,6 @@ __global__ void batched_kalman_loop_kernel(const double* ys,
   {
     double n_obs_ll_f = static_cast<double>(n_obs_ll);
     b_ll_s2 /= n_obs_ll_f;
-    if (conf_int) d_ll_sigma2[bid] = b_ll_s2;
     d_loglike[bid] = -.5 * (b_sum_logFs + n_obs_ll_f * (b_ll_s2 + log(2 * M_PI)));
   }
 
@@ -342,7 +339,6 @@ union KalmanLoopSharedMemory {
  * @param[in]  rd          State vector dimension
  * @param[out] d_pred      Predictions (nobs)
  * @param[out] d_loglike   Log-likelihood (1)
- * @param[out] d_ll_sigma2 Sigma^2 term in the log-likelihood (1)
  * @param[in]  n_diff      d + s*D
  * @param[in]  fc_steps    Number of steps to forecast
  * @param[out] d_fc        Array to store the forecast
@@ -365,7 +361,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
                                                          int rd,
                                                          double* d_pred,
                                                          double* d_loglike,
-                                                         double* d_ll_sigma2,
                                                          int n_diff,
                                                          int fc_steps,
                                                          double* d_fc,
@@ -603,7 +598,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
   if (threadIdx.x == 0) {
     double n_obs_ll_f = static_cast<double>(n_obs_ll);
     ll_s2 /= n_obs_ll_f;
-    if (conf_int) d_ll_sigma2[bid] = ll_s2;
     d_loglike[bid] = -.5 * (sum_logFs + n_obs_ll_f * (ll_s2 + log(2 * M_PI)));
   }
 }
@@ -625,7 +619,6 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys,
  * @param[in]  rd          Dimension of the state vector
  * @param[out] d_pred      Predictions (nobs)
  * @param[out] d_loglike   Log-likelihood (1)
- * @param[out] d_ll_sigma2 Sigma^2 term in the log-likelihood (1)
  * @param[in]  n_diff      d + s*D
  * @param[in]  fc_steps    Number of steps to forecast
  * @param[out] d_fc        Array to store the forecast
@@ -646,7 +639,6 @@ void _batched_kalman_device_loop_large(const ARIMAMemory<double>& arima_mem,
                                        int rd,
                                        double* d_pred,
                                        double* d_loglike,
-                                       double* d_ll_sigma2,
                                        int n_diff,
                                        int fc_steps = 0,
                                        double* d_fc = nullptr,
@@ -690,7 +682,6 @@ void _batched_kalman_device_loop_large(const ARIMAMemory<double>& arima_mem,
     rd,
     d_pred,
     d_loglike,
-    d_ll_sigma2,
     n_diff,
     fc_steps,
     d_fc,
@@ -713,7 +704,6 @@ void batched_kalman_loop(raft::handle_t& handle,
                          const ARIMAOrder& order,
                          double* d_pred,
                          double* d_loglike,
-                         double* d_ll_sigma2,
                          int fc_steps = 0,
                          double* d_fc = nullptr,
                          bool conf_int = false,
@@ -741,7 +731,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -762,7 +751,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -783,7 +771,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -804,7 +791,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -825,7 +811,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -846,7 +831,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -867,7 +851,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -888,7 +871,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           batch_size,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -917,7 +899,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           rd,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -939,7 +920,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           rd,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -963,7 +943,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           rd,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -985,7 +964,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           rd,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -1008,7 +986,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           rd,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -1030,7 +1007,6 @@ void batched_kalman_loop(raft::handle_t& handle,
           rd,
           d_pred,
           d_loglike,
-          d_ll_sigma2,
           n_diff,
           fc_steps,
           d_fc,
@@ -1046,25 +1022,21 @@ void batched_kalman_loop(raft::handle_t& handle,
  * @note: One block per batch member, one thread per forecast time step
  *
  * @param[in]    d_fc       Mean forecasts
- * @param[in]    d_sigma2   sum(v_t * v_t / F_t) / n_obs_diff
  * @param[inout] d_lower    Input: F_{n+t}
  *                          Output: lower bound of the confidence intervals
  * @param[out]   d_upper    Upper bound of the confidence intervals
- * @param[in]    fc_steps   Number of forecast steps
+ * @param[in]    n_elem     Total number of elements (fc_steps * batch_size)
  * @param[in]    multiplier Coefficient associated with the confidence level
  */
-__global__ void confidence_intervals(const double* d_fc,
-                                     const double* d_sigma2,
-                                     double* d_lower,
-                                     double* d_upper,
-                                     int fc_steps,
-                                     double multiplier)
+__global__ void confidence_intervals(
+  const double* d_fc, double* d_lower, double* d_upper, int n_elem, double multiplier)
 {
-  int idx       = blockIdx.x * fc_steps + threadIdx.x;
-  double fc     = d_fc[idx];
-  double margin = multiplier * sqrt(d_lower[idx] * d_sigma2[blockIdx.x]);
-  d_lower[idx]  = fc - margin;
-  d_upper[idx]  = fc + margin;
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n_elem; idx += blockDim.x * gridDim.x) {
+    double fc     = d_fc[idx];
+    double margin = multiplier * sqrt(d_lower[idx]);
+    d_lower[idx]  = fc - margin;
+    d_upper[idx]  = fc + margin;
+  }
 }
 
 void _lyapunov_wrapper(raft::handle_t& handle,
@@ -1287,15 +1259,16 @@ void _batched_kalman_filter(raft::handle_t& handle,
                       order,
                       d_pred,
                       d_loglike,
-                      arima_mem.sigma2_buffer,
                       fc_steps,
                       d_fc,
                       level > 0,
                       d_lower);
 
   if (level > 0) {
-    confidence_intervals<<<batch_size, fc_steps, 0, stream>>>(
-      d_fc, arima_mem.sigma2_buffer, d_lower, d_upper, fc_steps, sqrt(2.0) * erfinv(level));
+    constexpr int TPB_conf = 256;
+    int n_blocks           = raft::ceildiv(fc_steps * batch_size, TPB_conf);
+    confidence_intervals<<<n_blocks, TPB_conf, 0, stream>>>(
+      d_fc, d_lower, d_upper, fc_steps * batch_size, sqrt(2.0) * erfinv(level));
     CUDA_CHECK(cudaPeekAtLastError());
   }
 }
diff --git a/python/cuml/test/test_arima.py b/python/cuml/test/test_arima.py
index d4823624a1..1a704abbc9 100644
--- a/python/cuml/test/test_arima.py
+++ b/python/cuml/test/test_arima.py
@@ -420,9 +420,9 @@ def _predict_common(key, data, dtype, start, end, num_steps=None, level=None,
         np.testing.assert_allclose(cuml_pred, ref_preds, rtol=0.001, atol=0.01)
     if level is not None:
         np.testing.assert_allclose(
-            cuml_lower, ref_lower, rtol=0.03, atol=0.01)
+            cuml_lower, ref_lower, rtol=0.005, atol=0.01)
         np.testing.assert_allclose(
-            cuml_upper, ref_upper, rtol=0.03, atol=0.01)
+            cuml_upper, ref_upper, rtol=0.005, atol=0.01)
 
 
 @pytest.mark.parametrize('key, data', test_data)
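
Note on the updated interval computation: with the separate sigma2 factor removed, the kernel now treats the value passed in through d_lower as the complete forecast error variance F_{n+t} (presumably already including the scaling the old kernel applied via d_sigma2), so the bounds reduce to lower = fc - multiplier * sqrt(F_{n+t}) and upper = fc + multiplier * sqrt(F_{n+t}), where multiplier = sqrt(2) * erfinv(level) is the two-sided normal quantile (about 1.96 for level = 0.95). The following is a minimal host-side C++ sketch of that formula, for illustration only and not part of the patch; erfinv_newton is a hypothetical stand-in for the erfinv() helper the CUDA code relies on.

#include <cmath>
#include <cstdio>

// Hypothetical helper for this sketch only: invert std::erf numerically with
// Newton's method. The patch itself uses an erfinv() provided elsewhere.
double erfinv_newton(double y)
{
  double two_over_sqrt_pi = 2.0 / std::sqrt(std::acos(-1.0));
  double x                = 0.0;
  for (int i = 0; i < 60; i++) {
    // Newton step for f(x) = erf(x) - y, with f'(x) = 2/sqrt(pi) * exp(-x^2)
    x -= (std::erf(x) - y) / (two_over_sqrt_pi * std::exp(-x * x));
  }
  return x;
}

int main()
{
  double fc    = 12.3;  // mean forecast for one step (made-up value)
  double F     = 0.25;  // forecast error variance F_{n+t} (made-up value)
  double level = 0.95;  // requested confidence level

  // Same formula as the updated confidence_intervals kernel:
  // margin = multiplier * sqrt(F), multiplier = sqrt(2) * erfinv(level).
  double multiplier = std::sqrt(2.0) * erfinv_newton(level);
  double margin     = multiplier * std::sqrt(F);
  std::printf("lower = %f, upper = %f\n", fc - margin, fc + margin);
  return 0;
}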