Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
Caitlyn Chen committed Dec 19, 2020
1 parent 7aa9732 commit 941432a
Show file tree
Hide file tree
Showing 19 changed files with 36 additions and 123 deletions.
7 changes: 1 addition & 6 deletions lux/action/generalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,7 @@ def generalize(ldf):
for clause in filters:
# new_spec = ldf._intent.copy()
# new_spec.remove_column_from_spec(new_spec.attribute)
temp_vis = Vis(
ldf.current_vis[0]._inferred_intent.copy(),
source=ldf,
title="Overall",
score=0,
)
temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(), source=ldf, title="Overall", score=0,)
temp_vis.remove_filter_from_spec(clause.value)
output.append(temp_vis)

Expand Down
1 change: 0 additions & 1 deletion lux/action/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,3 @@ def preprocess(vis):
# aggregate(vis)
# interpolate(vis, 100)
normalize(vis)

12 changes: 4 additions & 8 deletions lux/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ def get_SQL_cardinality(self):
cardinality = {}
for attr in list(self.columns):
card_query = pd.read_sql(
f"SELECT Count(Distinct({attr})) FROM {self.table_name}",
lux.config.SQLconnection,
f"SELECT Count(Distinct({attr})) FROM {self.table_name}", lux.config.SQLconnection,
)
cardinality[attr] = list(card_query["count"])[0]
self.cardinality = cardinality
Expand All @@ -342,8 +341,7 @@ def get_SQL_unique_values(self):
unique_vals = {}
for attr in list(self.columns):
unique_query = pd.read_sql(
f"SELECT Distinct({attr}) FROM {self.table_name}",
lux.config.SQLconnection,
f"SELECT Distinct({attr}) FROM {self.table_name}", lux.config.SQLconnection,
)
unique_vals[attr] = list(unique_query[attr])
self.unique_values = unique_vals
Expand Down Expand Up @@ -564,8 +562,7 @@ def exported(self) -> Union[Dict[str, VisList], VisList]:
exported_vis = VisList(
list(
map(
self._recommendation[export_action].__getitem__,
exported_vis_lst[export_action],
self._recommendation[export_action].__getitem__, exported_vis_lst[export_action],
)
)
)
Expand Down Expand Up @@ -655,8 +652,7 @@ def _repr_html_(self):
if len(self._recommendation) > 0:
# box = widgets.Box(layout=widgets.Layout(display='inline'))
button = widgets.Button(
description="Toggle Pandas/Lux",
layout=widgets.Layout(width="140px", top="5px"),
description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
)
self.output = widgets.Output()
# box.children = [button,output]
Expand Down
3 changes: 1 addition & 2 deletions lux/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def __repr__(self):
if len(ldf.recommendation) > 0:
# box = widgets.Box(layout=widgets.Layout(display='inline'))
button = widgets.Button(
description="Toggle Pandas/Lux",
layout=widgets.Layout(width="140px", top="5px"),
description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
)
ldf.output = widgets.Output()
# box.children = [button,output]
Expand Down
10 changes: 2 additions & 8 deletions lux/executor/PandasExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,7 @@ def execute_aggregate(vis: Vis, isFiltered=True):
}
)
vis._vis_data = vis.data.merge(
df,
on=[columns[0], columns[1]],
how="right",
suffixes=["", "_right"],
df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
)
for col in columns[2:]:
vis.data[col] = vis.data[col].fillna(0) # Triggers __setitem__
Expand Down Expand Up @@ -329,10 +326,7 @@ def execute_2D_binning(vis: Vis):
if color_attr.data_type == "nominal":
# Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0])
result = groups.agg(
[
("count", "count"),
(color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),
]
[("count", "count"), (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),]
).reset_index()
elif color_attr.data_type == "quantitative":
# Compute the average of all values in the bin
Expand Down
6 changes: 2 additions & 4 deletions lux/executor/SQLExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,11 @@ def execute_binning(vis: Vis, ldf: LuxDataFrame):
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
bin_centers,
np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
bin_centers, np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
if attr_type == int:
bin_centers = np.append(
bin_centers,
math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
bin_centers, math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)
Expand Down
5 changes: 1 addition & 4 deletions lux/interestingness/interestingness.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ def interestingness(vis: Vis, ldf: LuxDataFrame) -> int:
vis_attrs_specs = utils.get_attrs_specs(vis._inferred_intent)

record_attrs = list(
filter(
lambda x: x.attribute == "Record" and x.data_model == "measure",
vis_attrs_specs,
)
filter(lambda x: x.attribute == "Record" and x.data_model == "measure", vis_attrs_specs,)
)
n_record = len(record_attrs)
for clause in vis_attrs_specs:
Expand Down
9 changes: 2 additions & 7 deletions lux/processor/Compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,7 @@ def line_or_bar(ldf, dimension: Clause, measure: Clause):
# ShowMe logic + additional heuristics
# count_col = Clause( attribute="count()", data_model="measure")
count_col = Clause(
attribute="Record",
aggregation="count",
data_model="measure",
data_type="quantitative",
attribute="Record", aggregation="count", data_model="measure", data_type="quantitative",
)
auto_channel = {}
if ndim == 0 and nmsr == 1:
Expand Down Expand Up @@ -454,9 +451,7 @@ def populate_wildcard_options(_inferred_intent: List[Clause], ldf: LuxDataFrame)
options = ldf.unique_values[attr]
specInd = _inferred_intent.index(clause)
_inferred_intent[specInd] = Clause(
attribute=clause.attribute,
filter_op="=",
value=list(options),
attribute=clause.attribute, filter_op="=", value=list(options),
)
else:
options.extend(convert_to_list(clause.value))
Expand Down
5 changes: 1 addition & 4 deletions lux/vislib/altair/AltairChart.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ def encode_color(self):
timeUnit = compute_date_granularity(self.vis.data[color_attr_name])
self.chart = self.chart.encode(
color=alt.Color(
color_attr_name,
type=color_attr_type,
timeUnit=timeUnit,
title=color_attr_name,
color_attr_name, type=color_attr_type, timeUnit=timeUnit, title=color_attr_name,
)
)
self.code += f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}',timeUnit='{timeUnit}',title='{color_attr_name}'))"
Expand Down
3 changes: 1 addition & 2 deletions lux/vislib/altair/AltairRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ def create_vis(self, vis, standalone=True):
found_variable = "df"
if standalone:
chart.code = chart.code.replace(
"placeholder_variable",
f"pd.DataFrame({str(vis.data.to_dict())})",
"placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})",
)
else:
# TODO: Placeholder (need to read dynamically via locals())
Expand Down
8 changes: 2 additions & 6 deletions lux/vislib/altair/BarChart.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def initialize_chart(self):
measure_attr = x_attr.attribute
bar_attr = y_attr.attribute
y_attr_field = alt.Y(
y_attr.attribute,
type=y_attr.data_type,
axis=alt.Axis(labelOverlap=True),
y_attr.attribute, type=y_attr.data_type, axis=alt.Axis(labelOverlap=True),
)
x_attr_field = alt.X(x_attr.attribute, type=x_attr.data_type, title=agg_title)
y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(labelOverlap=True))"
Expand All @@ -63,9 +61,7 @@ def initialize_chart(self):
measure_attr = y_attr.attribute
bar_attr = x_attr.attribute
x_attr_field = alt.X(
x_attr.attribute,
type=x_attr.data_type,
axis=alt.Axis(labelOverlap=True),
x_attr.attribute, type=x_attr.data_type, axis=alt.Axis(labelOverlap=True),
)
x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', axis=alt.Axis(labelOverlap=True))"
y_attr_field = alt.Y(y_attr.attribute, type=y_attr.data_type, title=agg_title)
Expand Down
5 changes: 1 addition & 4 deletions lux/vislib/altair/Heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ def initialize_chart(self):
),
y2=alt.Y2("yBinEnd"),
opacity=alt.Opacity(
"count",
type="quantitative",
scale=alt.Scale(type="log"),
legend=None,
"count", type="quantitative", scale=alt.Scale(type="log"), legend=None,
),
)
)
Expand Down
8 changes: 2 additions & 6 deletions lux/vislib/altair/ScatterChart.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,10 @@ def initialize_chart(self):
.mark_circle()
.encode(
x=alt.X(
x_attr.attribute,
scale=alt.Scale(domain=(x_min, x_max)),
type=x_attr.data_type,
x_attr.attribute, scale=alt.Scale(domain=(x_min, x_max)), type=x_attr.data_type,
),
y=alt.Y(
y_attr.attribute,
scale=alt.Scale(domain=(y_min, y_max)),
type=y_attr.data_type,
y_attr.attribute, scale=alt.Scale(domain=(y_min, y_max)), type=y_attr.data_type,
),
)
)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ def test_year_filter_value(global_var):
lambda vis: len(
list(
filter(
lambda clause: clause.value != "" and clause.attribute == "Year",
vis._intent,
lambda clause: clause.value != "" and clause.attribute == "Year", vis._intent,
)
)
)
Expand All @@ -198,6 +197,7 @@ def test_year_filter_value(global_var):
), "Year filter title contains extraneous string, not displayed as summarized string"
df.clear_intent()


def test_similarity(global_var):
df = pytest.car_df
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
Expand All @@ -216,4 +216,3 @@ def test_similarity(global_var):
> df.recommendation["Similarity"][2].score
)
df.clear_intent()

49 changes: 9 additions & 40 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,7 @@ def test_underspecified_vis_collection_zval(global_var):
# check if the number of charts is correct
df = pytest.car_df
vlst = VisList(
[
lux.Clause(attribute="Origin", filter_op="=", value="?"),
lux.Clause(attribute="MilesPerGal"),
],
[lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),],
df,
)
assert len(vlst) == 3
Expand Down Expand Up @@ -186,10 +183,7 @@ def test_specified_channel_enforced_vis_collection(global_var):
df = pytest.car_df
# change pandas dtype for the column "Year" to datetype
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
visList = VisList(
[lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")],
df,
)
visList = VisList([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], df,)
for vis in visList:
check_attribute_on_channel(vis, "MilesPerGal", "x")

Expand All @@ -204,22 +198,13 @@ def test_autoencoding_scatter(global_var):
check_attribute_on_channel(vis, "Weight", "y")

# Partial channel specified
vis = Vis(
[
lux.Clause(attribute="MilesPerGal", channel="y"),
lux.Clause(attribute="Weight"),
],
df,
)
vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,)
check_attribute_on_channel(vis, "MilesPerGal", "y")
check_attribute_on_channel(vis, "Weight", "x")

# Full channel specified
vis = Vis(
[
lux.Clause(attribute="MilesPerGal", channel="y"),
lux.Clause(attribute="Weight", channel="x"),
],
[lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),],
df,
)
check_attribute_on_channel(vis, "MilesPerGal", "y")
Expand Down Expand Up @@ -257,22 +242,13 @@ def test_autoencoding_line_chart(global_var):
check_attribute_on_channel(vis, "Acceleration", "y")

# Partial channel specified
vis = Vis(
[
lux.Clause(attribute="Year", channel="y"),
lux.Clause(attribute="Acceleration"),
],
df,
)
vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration"),], df,)
check_attribute_on_channel(vis, "Year", "y")
check_attribute_on_channel(vis, "Acceleration", "x")

# Full channel specified
vis = Vis(
[
lux.Clause(attribute="Year", channel="y"),
lux.Clause(attribute="Acceleration", channel="x"),
],
[lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration", channel="x"),],
df,
)
check_attribute_on_channel(vis, "Year", "y")
Expand Down Expand Up @@ -340,19 +316,15 @@ def test_populate_options(global_var):
assert list_equal(list(col_set), list(df.columns))

df.set_intent(
[
lux.Clause(attribute="?", data_model="measure"),
lux.Clause(attribute="MilesPerGal"),
]
[lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="MilesPerGal"),]
)
df._repr_html_()
col_set = set()
for specOptions in Compiler.populate_wildcard_options(df._intent, df)["attributes"]:
for clause in specOptions:
col_set.add(clause.attribute)
assert list_equal(
list(col_set),
["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
list(col_set), ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
)


Expand All @@ -361,10 +333,7 @@ def test_remove_all_invalid(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
# with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
df.set_intent(
[
lux.Clause(attribute="Origin", filter_op="=", value="USA"),
lux.Clause(attribute="Origin"),
]
[lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Origin"),]
)
df._repr_html_()
assert len(df.current_vis) == 0
Expand Down
5 changes: 1 addition & 4 deletions tests/test_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ def test_period_selection(global_var):
ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A")

ldf.set_intent(
[
lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]),
lux.Clause(attribute="Year"),
]
[lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute="Year"),]
)

lux.config.executor.execute(ldf.current_vis, ldf)
Expand Down
10 changes: 2 additions & 8 deletions tests/test_interestingness.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,7 @@ def test_interestingness_1_0_1(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")

df.set_intent(
[
lux.Clause(attribute="Origin", filter_op="=", value="USA"),
lux.Clause(attribute="Cylinders"),
]
[lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Cylinders"),]
)
df._repr_html_()
assert df.current_vis[0].score == 0
Expand Down Expand Up @@ -121,10 +118,7 @@ def test_interestingness_0_1_1(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")

df.set_intent(
[
lux.Clause(attribute="Origin", filter_op="=", value="?"),
lux.Clause(attribute="MilesPerGal"),
]
[lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),]
)
df._repr_html_()
assert interestingness(df.recommendation["Current Vis"][0], df) != None
Expand Down
3 changes: 1 addition & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def test_validator_invalid_attribute(global_var):
df = pytest.college_df
with pytest.raises(KeyError, match="'blah'"):
with pytest.warns(
UserWarning,
match="The input attribute 'blah' does not exist in the DataFrame.",
UserWarning, match="The input attribute 'blah' does not exist in the DataFrame.",
):
df.intent = ["blah"]
Loading

0 comments on commit 941432a

Please sign in to comment.