diff --git a/cluestar/__init__.py b/cluestar/__init__.py index 7ecd1da..89d44cf 100644 --- a/cluestar/__init__.py +++ b/cluestar/__init__.py @@ -99,23 +99,28 @@ def plot_text(X, texts, color_array=None, color_words=None, disable_warning=True return (p1 | p2).configure_axis(grid=False).configure_view(strokeWidth=0) - -def _single_scatter_chart(df_, idx, brush, title="embedding space"): +def _single_scatter_chart(df_, idx, brush, title="embedding space", color_words=None, color_array=None): cols = ("x1:Q", "y1:Q") if idx == 1 else ("x2:Q", "y2:Q") + if color_words: + color=alt.Color("color", sort=["none"] + color_words) + elif color_array: + color=alt.Color("color") + else: + color=alt.condition(brush, 'id:O', alt.value('lightgray'), legend=None) return ( alt.Chart(df_) .mark_circle(opacity=0.6, size=20) .encode( x=alt.X(cols[0], axis=None, scale=alt.Scale(zero=False)), y=alt.Y(cols[1], axis=None, scale=alt.Scale(zero=False)), - color=alt.condition(brush, 'id:O', alt.value('lightgray'), legend=None), tooltip=["text"], + color=color, ) .properties(width=350, height=350, title=title) .add_params(brush) ) -def plot_text_comparison(X1, X2, texts, disable_warning=True): +def plot_text_comparison(X1, X2, texts, disable_warning=True, color_array=None, color_words=None): """ Make a visualisation to help find clues in text data. @@ -124,6 +129,8 @@ def plot_text_comparison(X1, X2, texts, disable_warning=True): - `X2`: the numeric features, should be a 2D numpy array - `texts`: list of text data - `disable_warning`: disable the standard altair max rows warning + - `color_words`: list of words to highlight + - `color_array`: an array that represents color for the plot """ if disable_warning: alt.data_transformers.disable_max_rows() @@ -136,10 +143,24 @@ def plot_text_comparison(X1, X2, texts, disable_warning=True): df_ = pd.DataFrame({"x1": X1[:, 0], "y1": X1[:, 1], "x2": X2[:, 0], "y2": X2[:, 1], "text": texts}).assign( trunc_text=lambda d: d["text"].str[:120], r=0 ) + + if color_array is not None: + if len(color_array) != X1.shape[0]: + raise ValueError( + f"The number of color array ({len(color_array)}) should match X array ({X.shape[0]})." + ) + df_ = df_.assign(color=color_array) + + if color_words is not None: + df_ = df_.assign(color="none") + + for w in color_words: + predicate = df_["text"].str.lower().str.contains(w) + df_ = df_.assign(color=lambda d: np.where(predicate, w, d["color"])) brush = alt.selection_interval() - p1 = _single_scatter_chart(df_, 1, brush, title="embedding space X1") - p2 = _single_scatter_chart(df_, 2, brush, title="embedding space X2") + p1 = _single_scatter_chart(df_, 1, brush, title="embedding space X1", color_words=color_words, color_array=color_array) + p2 = _single_scatter_chart(df_, 2, brush, title="embedding space X2", color_words=color_words, color_array=color_array) p3 = ( alt.Chart(df_) @@ -156,4 +177,4 @@ def plot_text_comparison(X1, X2, texts, disable_warning=True): .properties(title="text") ) - return (p1 | p2 | p3).configure_axis(grid=False).configure_view(strokeWidth=0) + return (p1 | p2 | p3).configure_axis(grid=False).configure_view(strokeWidth=0) \ No newline at end of file