Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to coerce string input for subsample_n #34

Merged
merged 2 commits into from
Aug 20, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions pycytominer/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ def __init__(
0 < subsample_frac and 1 >= subsample_frac
), "subsample_frac must be between 0 and 1"

# Check that the user didn't specify both subsample_frac and subsample_n
assert (
subsample_frac == 1 or subsample_n == "all"
), "Do not set both subsample_frac and subsample_n"

self.sql_file = sql_file
self.strata = strata
self.features = features
Expand All @@ -72,10 +67,18 @@ def __init__(
self.subset_data = "none"
self.subsampling_random_state = subsampling_random_state
self.is_aggregated = False
self.is_subset_computed = False

if self.subsample_n != "all":
try:
self.subsample_n = int(self.subsample_n)
except ValueError:
print("subsample n must be an integer or coercable")

# Connect to sqlite engine
self.engine = create_engine(self.sql_file)
self.conn = self.engine.connect()
self._check_subsampling()

if load_image_data:
self.load_image()
Expand All @@ -90,16 +93,22 @@ def _check_compartments(self, compartments):
elif isinstance(compartments, str):
assert compartments in valid_compartments, error_str

def _check_subsampling(self):
# Check that the user didn't specify both subsample_frac and subsample_n:
# the assertion only holds while subsample_frac is 1 or subsample_n is "all".
assert (
self.subsample_frac == 1 or self.subsample_n == "all"
), "Do not set both subsample_frac and subsample_n"

def set_output_file(self, output_file):
# Store the given output_file on the instance as self.output_file.
self.output_file = output_file

def set_subsample_frac(self, subsample_frac):
self.subsample_n = "all"
self.subsample_frac = subsample_frac
self._check_subsampling()

def set_subsample_n(self, subsample_n):
self.subsample_frac = 1
self.subsample_n = subsample_n
self._check_subsampling()

def set_subsample_random_state(self, random_state):
self.subsampling_random_state = random_state
Expand All @@ -109,9 +118,7 @@ def load_image(self):
Load image table from sqlite file
"""
# Extract image metadata
image_cols = (
"TableNumber, ImageNumber, {}".format(", ".join(self.strata))
)
image_cols = "TableNumber, ImageNumber, {}".format(", ".join(self.strata))
image_query = "select {} from image".format(image_cols)
self.image_df = pd.read_sql(sql=image_query, con=self.conn)

Expand All @@ -127,6 +134,7 @@ def count_cells(self, compartment="cells", count_subset=False):

if count_subset:
assert self.is_aggregated, "Make sure to aggregate_profiles() first!"
assert self.is_subset_computed, "Make sure to get_subsample() first!"
count_df = pd.crosstab(
self.subset_data.loc[:, self.strata[1]],
self.subset_data.loc[:, self.strata[0]],
Expand Down Expand Up @@ -166,7 +174,7 @@ def subsample_profiles(self, x, random_state="none"):
x, frac=self.subsample_frac, random_state=self.subsampling_random_state
)

def get_subsample(self, compartment="cells", subsample_frac=1, subsample_n="all"):
def get_subsample(self, compartment="cells"):
"""
Extract subsample from sqlite file

Expand All @@ -175,12 +183,6 @@ def get_subsample(self, compartment="cells", subsample_frac=1, subsample_n="all"
"""
self._check_compartments(compartment)

if subsample_frac < 1:
self.set_subsample_frac(subsample_frac)

if isinstance(subsample_n, int):
self.set_subsample_n(subsample_n)

query_cols = "TableNumber, ImageNumber, ObjectNumber"
query = "select {} from {}".format(query_cols, compartment)

Expand All @@ -194,6 +196,7 @@ def get_subsample(self, compartment="cells", subsample_frac=1, subsample_n="all"
.apply(lambda x: self.subsample_profiles(x))
.reset_index(drop=True)
)
self.is_subset_computed = True

def aggregate_compartment(self, compartment, compute_subsample=False):
"""
Expand Down Expand Up @@ -301,15 +304,9 @@ def aggregate(
population_df = population_df.groupby(strata)

if operation == "median":
population_df = (
population_df.median()
.reset_index()
)
population_df = population_df.median().reset_index()
else:
population_df = (
population_df.mean()
.reset_index()
)
population_df = population_df.mean().reset_index()

# Aggregated image number and object number do not make sense
for col in ["ImageNumber", "ObjectNumber"]:
Expand Down