Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add color_words and color_array to plot_text_comparison #6

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions cluestar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,28 @@ def plot_text(X, texts, color_array=None, color_words=None, disable_warning=True

return (p1 | p2).configure_axis(grid=False).configure_view(strokeWidth=0)


def _single_scatter_chart(df_, idx, brush, title="embedding space", color_words=None, color_array=None):
    """
    Build one of the two brushable scatter charts used by `plot_text_comparison`.

    - `df_`: dataframe holding the plotted columns (`x1`/`y1` or `x2`/`y2`, `text`,
      and `color` when one of the coloring options is used)
    - `idx`: `1` plots the `x1`/`y1` columns, any other value plots `x2`/`y2`
    - `brush`: the altair interval selection that is attached to the chart
    - `title`: title shown above the chart
    - `color_words`: list of words to highlight; takes priority over `color_array`
    - `color_array`: an array that represents color for the plot
    """
    cols = ("x1:Q", "y1:Q") if idx == 1 else ("x2:Q", "y2:Q")

    # Explicit None/len checks instead of truthiness: `bool(array)` raises for
    # numpy arrays with more than one element.
    if color_words is not None and len(color_words) > 0:
        # Keep the "none" bucket first in the legend, followed by the words.
        color = alt.Color("color", sort=["none"] + list(color_words))
    elif color_array is not None:
        color = alt.Color("color")
    else:
        # No explicit coloring: points inside the brush are colored by `id`,
        # everything else is greyed out.
        color = alt.condition(brush, "id:O", alt.value("lightgray"), legend=None)

    return (
        alt.Chart(df_)
        .mark_circle(opacity=0.6, size=20)
        .encode(
            x=alt.X(cols[0], axis=None, scale=alt.Scale(zero=False)),
            y=alt.Y(cols[1], axis=None, scale=alt.Scale(zero=False)),
            color=color,
            tooltip=["text"],
        )
        .properties(width=350, height=350, title=title)
        .add_params(brush)
    )

def plot_text_comparison(X1, X2, texts, disable_warning=True):
def plot_text_comparison(X1, X2, texts, disable_warning=True, color_array=None, color_words=None):
"""
Make a visualisation to help find clues in text data.

Expand All @@ -124,6 +129,8 @@ def plot_text_comparison(X1, X2, texts, disable_warning=True):
- `X2`: the numeric features, should be a 2D numpy array
- `texts`: list of text data
- `disable_warning`: disable the standard altair max rows warning
- `color_words`: list of words to highlight
- `color_array`: an array that represents color for the plot
"""
if disable_warning:
alt.data_transformers.disable_max_rows()
Expand All @@ -136,10 +143,24 @@ def plot_text_comparison(X1, X2, texts, disable_warning=True):
df_ = pd.DataFrame({"x1": X1[:, 0], "y1": X1[:, 1], "x2": X2[:, 0], "y2": X2[:, 1], "text": texts}).assign(
trunc_text=lambda d: d["text"].str[:120], r=0
)

if color_array is not None:
    # One color value is required per plotted row.
    if len(color_array) != X1.shape[0]:
        raise ValueError(
            f"The number of color array ({len(color_array)}) should match X array ({X1.shape[0]})."
        )
    df_ = df_.assign(color=color_array)

if color_words is not None:
    # Word highlighting overrides any `color_array` coloring: every row starts
    # as "none" and is relabelled with each word it contains, so a later word
    # in `color_words` wins when several match the same text.
    df_ = df_.assign(color="none")

    # Lower-case the texts once instead of once per word.
    lowered_text = df_["text"].str.lower()
    for w in color_words:
        predicate = lowered_text.str.contains(w)
        df_ = df_.assign(color=np.where(predicate, w, df_["color"]))

brush = alt.selection_interval()
p1 = _single_scatter_chart(df_, 1, brush, title="embedding space X1")
p2 = _single_scatter_chart(df_, 2, brush, title="embedding space X2")
p1 = _single_scatter_chart(df_, 1, brush, title="embedding space X1", color_words=color_words, color_array=color_array)
p2 = _single_scatter_chart(df_, 2, brush, title="embedding space X2", color_words=color_words, color_array=color_array)

p3 = (
alt.Chart(df_)
Expand All @@ -156,4 +177,4 @@ def plot_text_comparison(X1, X2, texts, disable_warning=True):
.properties(title="text")
)

return (p1 | p2 | p3).configure_axis(grid=False).configure_view(strokeWidth=0)
return (p1 | p2 | p3).configure_axis(grid=False).configure_view(strokeWidth=0)
Loading