Skip to content

Commit

Permalink
Merge pull request #141 from alexheat/dev
Browse files Browse the repository at this point in the history
Add tqdm progress bar and resolve #139, resolve #140
  • Loading branch information
alexheat authored Nov 26, 2023
2 parents 2e26fdb + 163cceb commit 4691418
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions pylabel/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit as sklearnGroupShuffleSplit
from pylabel.shared import schema
from tqdm import tqdm


class Split:
Expand Down Expand Up @@ -136,6 +137,9 @@ def calc_mse_loss(df):
b = 0 # counter for the batches
batch_df = df_main[0:0]

# Use tqdm in the for loop to show progress bar and iterate through the groups
pbar = tqdm(total=subject_grouped_df_main.ngroups, desc="Splitting dataset")

for _, group in subject_grouped_df_main:
if i < 3:
if i == 0:
Expand All @@ -155,7 +159,7 @@ def calc_mse_loss(df):
i += 1
continue

# Add groups to the
# Add groups to the batch
batch_df = pd.concat([batch_df, group])
b += 1
if b < batch_size and i < subject_grouped_df_main.ngroups - 3:
Expand Down Expand Up @@ -199,6 +203,10 @@ def calc_mse_loss(df):

# print ("Group " + str(i) + ". loss_train: " + str(loss_train) + " | " + "loss_val: " + str(loss_val) + " | " + "loss_test: " + str(loss_test) + " | ")
i += 1

# update the progress bar
pbar.update(b)

# Reset the batch
b = 0
batch_df = df_main[0:0]
Expand All @@ -209,14 +217,16 @@ def calc_mse_loss(df):
# Sometimes the algo will put some rows in the val set even if the split percent was set to zero
# In those cases move the rows from val to test
if round(val_pct, 1) == round(0, 1):
pd.concat([df_test, df_val])
df_val = df_val[0:0] # remove the values from
# Move the rows from val to test and remove the rows from val
df_test = pd.concat([df_test, df_val])
df_val = df_val[0:0] # remove the values from val

# Apply train, test, val labels to the split column
df_train["split"] = "train"
df_test["split"] = "test"
df_val["split"] = "val"

# Combine the train, test, and val dataframes
df = pd.concat([df_train, pd.concat([df_test, df_val])])

assert (
Expand All @@ -226,3 +236,7 @@ def calc_mse_loss(df):
self.dataset.df = df
self.dataset.df = self.dataset.df.reset_index(drop=True)
self.dataset.df = self.dataset.df[schema]

# set progress bar to 100%
pbar.update(subject_grouped_df_main.ngroups)
pbar.close()

0 comments on commit 4691418

Please sign in to comment.