Skip to content

Commit

Permalink
Updated basic_usage and black formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
rmj3197 committed May 6, 2024
1 parent e7a2678 commit 666a302
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 75 deletions.
49 changes: 17 additions & 32 deletions QuadratiK/kernel_test/_h_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ def _objective_one_sample(
random_state=random_state,
)

statistic = stat_normality_test(
xnew, h, np.array([mean_dat]), np.diag(s_dat))
statistic = stat_normality_test(xnew, h, np.array([mean_dat]), np.diag(s_dat))
cv = cv_normality(
n,
h,
Expand Down Expand Up @@ -233,7 +232,7 @@ def _objective_two_sample(
skew_tilde = skew_data + dk
s_tilde = s_dat

if isinstance(random_state, (int,np.int_)):
if isinstance(random_state, (int, np.int_)):
random_state = random_state + int(rep_values)

xnew = skewnorm.rvs(
Expand All @@ -251,8 +250,7 @@ def _objective_two_sample(
random_state=np.random.default_rng(random_state),
)

statistic = stat_two_sample(
xnew, ynew, h, np.array([[0]]), np.array([[1]]))
statistic = stat_two_sample(xnew, ynew, h, np.array([[0]]), np.array([[1]]))
cv = cv_twosample(
num_iter,
quantile,
Expand Down Expand Up @@ -370,7 +368,7 @@ def _objective_k_sample(
skew_tilde = skew_data + dk
s_tilde = s_dat

if isinstance(random_state, (int,np.int_)):
if isinstance(random_state, (int, np.int_)):
random_state = random_state + int(rep_values)

nk = round(n / k)
Expand Down Expand Up @@ -587,8 +585,7 @@ def select_h(
)

mean_dat = np.mean(x, axis=0)
s_dat = np.diag(np.cov(x, rowvar=False).reshape(
x.shape[1], x.shape[1]))
s_dat = np.diag(np.cov(x, rowvar=False).reshape(x.shape[1], x.shape[1]))
skew_data = skew(x)
all_parameters = np.array(np.meshgrid(h_values, delta, rep_values)).T.reshape(
-1, 3
Expand All @@ -614,26 +611,22 @@ def select_h(
)
for param in parameters
)
results = pd.DataFrame(
results, columns=["rep", "delta", "h", "score"])
results = pd.DataFrame(results, columns=["rep", "delta", "h", "score"])
results["score"] = 1 - results["score"]
results_mean = (
results.groupby(["h", "delta"]).agg(
{"score": "mean"}).reset_index()
results.groupby(["h", "delta"]).agg({"score": "mean"}).reset_index()
)
results_mean.columns = ["h", "delta", "power"]
results_mean = results_mean.sort_values(by=["delta", "h"])
all_results[delta_val] = results_mean
min_h_power_gt_05 = results_mean[results_mean["power"] >= 0.5]
if not min_h_power_gt_05.empty:
min_h = results_mean[results_mean["power"]
>= 0.50].iloc[0]["h"]
min_h = results_mean[results_mean["power"] >= 0.50].iloc[0]["h"]
break

elif k > k_threshold:
if x.shape[1] != y.shape[1]:
raise ValueError(
"'x' and 'y' must have the same number of columns")
raise ValueError("'x' and 'y' must have the same number of columns")

n = x.shape[0]
m = y.shape[0]
Expand All @@ -651,8 +644,7 @@ def select_h(

mean_dat = np.mean(pooled, axis=0)
s_dat = np.diag(
np.cov(pooled, rowvar=False).reshape(
pooled.shape[1], pooled.shape[1])
np.cov(pooled, rowvar=False).reshape(pooled.shape[1], pooled.shape[1])
)
skew_data = skew(pooled)
all_parameters = np.array(np.meshgrid(h_values, delta, rep_values)).T.reshape(
Expand Down Expand Up @@ -683,20 +675,17 @@ def select_h(
)
for param in parameters
)
results = pd.DataFrame(
results, columns=["rep", "delta", "h", "score"])
results = pd.DataFrame(results, columns=["rep", "delta", "h", "score"])
results["score"] = 1 - results["score"]
results_mean = (
results.groupby(["h", "delta"]).agg(
{"score": "mean"}).reset_index()
results.groupby(["h", "delta"]).agg({"score": "mean"}).reset_index()
)
results_mean.columns = ["h", "delta", "power"]
results_mean = results_mean.sort_values(by=["delta", "h"])
all_results[delta_val] = results_mean
min_h_power_gt_05 = results_mean[results_mean["power"] >= 0.5]
if not min_h_power_gt_05.empty:
min_h = results_mean[results_mean["power"]
>= 0.50].iloc[0]["h"]
min_h = results_mean[results_mean["power"] >= 0.50].iloc[0]["h"]
break
else:
n, d = x.shape
Expand All @@ -710,8 +699,7 @@ def select_h(
)

mean_dat = np.mean(x, axis=0)
s_dat = np.diag(np.cov(x, rowvar=False).reshape(
x.shape[1], x.shape[1]))
s_dat = np.diag(np.cov(x, rowvar=False).reshape(x.shape[1], x.shape[1]))
skew_data = skew(x)

all_parameters = np.array(np.meshgrid(h_values, delta, rep_values)).T.reshape(
Expand Down Expand Up @@ -742,20 +730,17 @@ def select_h(
for param in parameters
)

results_df = pd.DataFrame(
results, columns=["rep", "delta", "h", "score"])
results_df = pd.DataFrame(results, columns=["rep", "delta", "h", "score"])
results_df["score"] = 1 - results_df["score"]
results_mean = (
results_df.groupby(["h", "delta"]).agg(
{"score": "mean"}).reset_index()
results_df.groupby(["h", "delta"]).agg({"score": "mean"}).reset_index()
)
results_mean.columns = ["h", "delta", "power"]
results_mean = results_mean.sort_values(by=["delta", "h"])
all_results[delta_val] = results_mean
min_h_power_gt_05 = results_mean[results_mean["power"] >= 0.5]
if not min_h_power_gt_05.empty:
min_h = results_mean[results_mean["power"]
>= 0.50].iloc[0]["h"]
min_h = results_mean[results_mean["power"] >= 0.50].iloc[0]["h"]
break

all_results = pd.concat(all_results.values())
Expand Down
14 changes: 10 additions & 4 deletions QuadratiK/kernel_test/_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,16 @@ def summary(self, print_fmt="simple_grid"):
self.un_test_statistic_,
self.un_cv_,
self.un_h0_rejected_,
self.var_un_
self.var_un_,
]
res = res.set_axis(
["Test Type", "Un Test Statistic", "Un Critical Value", "Reject H0", "Var Un"]
[
"Test Type",
"Un Test Statistic",
"Un Critical Value",
"Reject H0",
"Var Un",
]
)
else:
res[""] = [
Expand All @@ -472,7 +478,7 @@ def summary(self, print_fmt="simple_grid"):
self.vn_test_statistic_,
self.vn_cv_,
self.vn_h0_rejected_,
self.var_un_
self.var_un_,
]
res = res.set_axis(
[
Expand All @@ -483,7 +489,7 @@ def summary(self, print_fmt="simple_grid"):
"Vn Test Statistic",
"Vn Critical Value",
"Vn Reject H0",
"Var Un"
"Var Un",
]
)

Expand Down
3 changes: 1 addition & 2 deletions QuadratiK/kernel_test/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ def variance_k_sample_test(k_cen, sizes, cum_size):
cum_size[i] : cum_size[i] + sizes[i],
cum_size[j] : cum_size[j] + sizes[j],
]

if j > i:
C2 += 8 * n_ij_factor * n_ij_factor * (k_ij**2).sum()
C3 -= 8 * n_ij_factor * ni_factor * (k_ii @ k_ij.T).sum()
Expand Down Expand Up @@ -602,7 +601,7 @@ def subsampling_helper_twosample(size_x, size_y, b, h, data_pool, n_rep, random_
"""
if random_state is None:
generator = check_random_state(random_state)
elif isinstance(random_state, (int,np.int_)):
elif isinstance(random_state, (int, np.int_)):
generator = check_random_state(random_state + n_rep)

ind_x = generator.choice(np.arange(size_x), size=round(size_x * b), replace=False)
Expand Down
4 changes: 2 additions & 2 deletions QuadratiK/ui/pages/1_Normality_Test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def run_normality_test(h_val, num_iter, b, x):
norm_test.vn_test_statistic_,
norm_test.vn_cv_,
norm_test.vn_h0_rejected_,
norm_test.var_un_
norm_test.var_un_,
]
res = res.set_axis(
[
Expand All @@ -93,7 +93,7 @@ def run_normality_test(h_val, num_iter, b, x):
"Vn Test Statistic",
"Vn Critical Value",
"Vn Reject H0",
"Var Un"
"Var Un",
]
)
st.table(res)
Expand Down
10 changes: 8 additions & 2 deletions QuadratiK/ui/pages/2_Two_Sample_Test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,16 @@ def run_twosample_test(h_val, num_iter, b, X, Y):
two_sample_test.un_test_statistic_,
two_sample_test.un_cv_,
two_sample_test.un_h0_rejected_,
two_sample_test.var_un_
two_sample_test.var_un_,
]
res = res.set_axis(
["Test Type", "Un Test Statistic", "Un Critical Value", "Un Reject H0","Var Un"]
[
"Test Type",
"Un Test Statistic",
"Un Critical Value",
"Un Reject H0",
"Var Un",
]
)

st.table(res)
Expand Down
4 changes: 2 additions & 2 deletions QuadratiK/ui/pages/3_K_Sample_Test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ def run_ksample_test(h_val, num_iter, b, X, y):
k_samp_test.un_test_statistic_,
k_samp_test.un_cv_,
k_samp_test.un_h0_rejected_,
k_samp_test.var_un_
k_samp_test.var_un_,
]
res = res.set_axis(
[
"Test Type",
"Un Test Statistic",
"Un Critical Value",
"Un Reject H0",
"Var Un"
"Var Un",
]
)
st.table(res)
Expand Down
4 changes: 2 additions & 2 deletions QuadratiK/ui/pages/7_Clustering_on_Sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,15 @@
k_samp_test.un_test_statistic_,
k_samp_test.un_cv_,
k_samp_test.un_h0_rejected_,
k_samp_test.var_un_
k_samp_test.var_un_,
]
res = res.set_axis(
[
"Test Type",
"Un Test Statistic",
"Un Critical Value",
"Un Reject H0",
"Var Un"
"Var Un",
]
)

Expand Down
Loading

0 comments on commit 666a302

Please sign in to comment.