Skip to content

Commit

Permalink
fix dataset update (#8581)
Browse files Browse the repository at this point in the history
* fix dataset update

* revert'

* add changeset

* add test

* add changeset

* changes

* add template

* add changeset

* fix docstring

* test postprocessing

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot authored Jun 19, 2024
1 parent 2b0c157 commit a1c21cb
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 37 deletions.
7 changes: 7 additions & 0 deletions .changeset/soft-worms-remain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/dataset": patch
"gradio": patch
"website": patch
---

fix:fix dataset update
6 changes: 3 additions & 3 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,12 +1723,12 @@ async def postprocess_data(
) from err

if block.stateful:
if not utils.is_update(predictions[i]):
if not utils.is_prop_update(predictions[i]):
state[block._id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
if utils.is_update(
if utils.is_prop_update(
prediction_value
): # if update is passed directly (deprecated), remove Nones
prediction_value = utils.delete_none(
Expand All @@ -1738,7 +1738,7 @@ async def postprocess_data(
if isinstance(prediction_value, Block):
prediction_value = prediction_value.constructor_args.copy()
prediction_value["__type__"] = "update"
if utils.is_update(prediction_value):
if utils.is_prop_update(prediction_value):
kwargs = state[block._id].constructor_args.copy()
kwargs.update(prediction_value)
kwargs.pop("value", None)
Expand Down
34 changes: 23 additions & 11 deletions gradio/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from typing import Any, Literal

from gradio_client.documentation import document
Expand All @@ -17,7 +18,8 @@
@document()
class Dataset(Component):
"""
Creates a gallery or table to display data samples. This component is designed for internal use to display examples.
Creates a gallery or table to display data samples. This component is primarily designed for internal use to display examples.
However, it can also be used directly to display a dataset and let users select examples.
"""

EVENTS = [Events.click, Events.select]
Expand All @@ -26,7 +28,7 @@ def __init__(
self,
*,
label: str | None = None,
components: list[Component] | list[str],
components: list[Component] | list[str] | None = None,
component_props: list[dict[str, Any]] | None = None,
samples: list[list[Any]] | None = None,
headers: list[str] | None = None,
Expand Down Expand Up @@ -70,7 +72,7 @@ def __init__(
self.container = container
self.scale = scale
self.min_width = min_width
self._components = [get_component_instance(c) for c in components]
self._components = [get_component_instance(c) for c in components or []]
if component_props is None:
self.component_props = [
component.recover_kwargs(
Expand Down Expand Up @@ -131,29 +133,39 @@ def get_config(self):

return config

def preprocess(self, payload: int) -> int | list | None:
def preprocess(self, payload: int | None) -> int | list | None:
"""
Parameters:
payload: the index of the selected example in the dataset
Returns:
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
"""
if payload is None:
return None
if self.type == "index":
return payload
elif self.type == "values":
return self.samples[payload]

def postprocess(self, samples: list[list]) -> dict:
def postprocess(self, sample: int | list | None) -> int | None:
"""
Parameters:
samples: Expects a `list[list]` corresponding to the dataset data, can be used to update the dataset.
sample: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found.
Returns:
Returns the updated dataset data as a `dict` with the key "samples".
Returns the index of the sample in the dataset.
"""
return {
"samples": samples,
"__type__": "update",
}
if sample is None or isinstance(sample, int):
return sample
if isinstance(sample, list):
try:
index = self.samples.index(sample)
except ValueError:
index = None
warnings.warn(
"The `Dataset` component does not support updating the dataset data by providing "
"a set of list values. Instead, you should return a new Dataset(samples=...) object."
)
return index

def example_payload(self) -> Any:
return 0
Expand Down
2 changes: 1 addition & 1 deletion gradio/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def flag(
) / client_utils.strip_invalid_filename_characters(
getattr(component, "label", None) or f"component {idx}"
)
if utils.is_update(sample):
if utils.is_prop_update(sample):
csv_data.append(str(sample))
else:
data = (
Expand Down
2 changes: 1 addition & 1 deletion gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def load_from_cache(self, example_id: int) -> list[Any]:
component, components.File
):
value_to_use = value_as_dict
if not utils.is_update(value_as_dict):
if not utils.is_prop_update(value_as_dict):
raise TypeError("value wasn't an update") # caught below
output.append(value_as_dict)
except (ValueError, TypeError, SyntaxError):
Expand Down
2 changes: 1 addition & 1 deletion gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def validate_url(possible_url: str) -> bool:
return False


def is_update(val):
def is_prop_update(val):
return isinstance(val, dict) and "update" in val.get("__type__", "")


Expand Down
35 changes: 35 additions & 0 deletions js/_website/src/lib/templates/gradio/03_components/dataset.svx
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,40 @@ def predict(···) -> list[list]
<DemosSection demos={obj.demos} />
{/if}

### Examples

**Updating a Dataset**

In this example, we display a text dataset using `gr.Dataset` and then update it when the user clicks a button:

```py
import gradio as gr

philosophy_quotes = [
["I think therefore I am."],
["The unexamined life is not worth living."]
]

startup_quotes = [
["Ideas are easy. Implementation is hard"],
["Make mistakes faster."]
]

def show_startup_quotes():
return gr.Dataset(samples=startup_quotes)

with gr.Blocks() as demo:
textbox = gr.Textbox()
dataset = gr.Dataset(components=[textbox], samples=philosophy_quotes)
button = gr.Button()

button.click(show_startup_quotes, None, dataset)

demo.launch()
```



{#if obj.fns && obj.fns.length > 0}
<!--- Event Listeners -->
### Event Listeners
Expand All @@ -97,3 +131,4 @@ def predict(···) -> list[list]
### Guides
<GuidesSection guides={obj.guides}/>
{/if}

5 changes: 3 additions & 2 deletions js/dataset/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
>;
export let label = "Examples";
export let headers: string[];
export let samples: any[][];
export let samples: any[][] | null = null;
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
Expand All @@ -34,7 +34,7 @@
: `${root}/file=`;
let page = 0;
$: gallery = components.length < 2;
let paginate = samples.length > samples_per_page;
let paginate = samples ? samples.length > samples_per_page : false;
let selected_samples: any[][];
let page_count: number;
Expand All @@ -51,6 +51,7 @@
}
$: {
samples = samples || [];
paginate = samples.length > samples_per_page;
if (paginate) {
visible_pages = [];
Expand Down
19 changes: 1 addition & 18 deletions test/components/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,10 @@ def test_preprocessing(self):
assert dataset.samples == [["value 1"], ["value 2"]]

def test_postprocessing(self):
test_file_dir = Path(Path(__file__).parent, "test_files")
bus = Path(test_file_dir, "bus.png")

dataset = gr.Dataset(
components=["number", "textbox", "image", "html", "markdown"], type="index"
)

output = dataset.postprocess(
samples=[
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
],
)

assert output == {
"samples": [
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
],
"__type__": "update",
}
assert dataset.postprocess(1) == 1


@patch(
Expand Down
27 changes: 27 additions & 0 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,33 @@ def infer(a, b):
):
await demo.postprocess_data(demo.fns[0], predictions=(1, 2), state=None)

@pytest.mark.asyncio
async def test_dataset_is_updated(self):
def update(value):
return value, gr.Dataset(samples=[["New A"], ["New B"]])

with gr.Blocks() as demo:
with gr.Row():
textbox = gr.Textbox()
dataset = gr.Dataset(
components=["text"], samples=[["Original"]], label="Saved Prompts"
)
dataset.click(update, inputs=[dataset], outputs=[textbox, dataset])
app, _, _ = demo.launch(prevent_thread_lock=True)

client = TestClient(app)

session_1 = client.post(
"/api/predict/",
json={"data": [0], "session_hash": "1", "fn_index": 0},
)
assert "Original" in session_1.json()["data"][0]
session_2 = client.post(
"/api/predict/",
json={"data": [0], "session_hash": "1", "fn_index": 0},
)
assert "New" in session_2.json()["data"][0]


class TestStateHolder:
@pytest.mark.asyncio
Expand Down

0 comments on commit a1c21cb

Please sign in to comment.