From f8bc991cb3347900724e0689057b07cc048de479 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Fri, 14 Apr 2023 22:08:23 -0400 Subject: [PATCH 01/13] created a new branch --- extensions/xy_grid/script.py | 112 +++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 extensions/xy_grid/script.py diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py new file mode 100644 index 0000000000..8775db9026 --- /dev/null +++ b/extensions/xy_grid/script.py @@ -0,0 +1,112 @@ +import os +import json + +import gradio as gr +import modules.shared as shared +import modules.ui +import pyparsing as pp + +from modules.chat import chatbot_wrapper +from pathlib import Path + +custom_state = {} +custom_output = [] + +def load_preset_values(preset_menu, state, return_dict=False): + generate_params = { + 'do_sample': True, + 'temperature': 1, + 'top_p': 1, + 'typical_p': 1, + 'repetition_penalty': 1, + 'encoder_repetition_penalty': 1, + 'top_k': 50, + 'num_beams': 1, + 'penalty_alpha': 0, + 'min_length': 0, + 'length_penalty': 1, + 'no_repeat_ngram_size': 0, + 'early_stopping': False, + } + with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile: + preset = infile.read() + for i in preset.splitlines(): + i = i.rstrip(',').strip().split('=') + if len(i) == 2 and i[0].strip() != 'tokens': + generate_params[i[0].strip()] = eval(i[1].strip()) + generate_params['temperature'] = min(1.99, generate_params['temperature']) + + if return_dict: + return generate_params + else: + state.update(generate_params) + return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + + +def get_presets(): + global custom_state + presets = [] + filenames = os.listdir("presets/") + for file in filenames: + preset = file[:-4] + presets.append(preset) + custom_state = load_preset_values(preset, custom_state)[0] + return ", ".join(presets) + +def get_params(*args): + global custom_state + custom_state = modules.ui.gather_interface_values(*args) + return json.dumps(custom_state) + +def run(x="",y=""): + global custom_state + global custom_output + + output = "" + + x_strings = pp.common.comma_separated_list.parseString(x).asList() + y_strings = pp.common.comma_separated_list.parseString(y).asList() + + for i in y_strings: + output = output + f"" + output = output + "" + for i in x_strings: + output = output + f"" + if y_strings[0] != '': + for j in y_strings: + custom_state = load_preset_values(j.strip(), custom_state)[0] + for new in chatbot_wrapper(i.strip(), custom_state): + custom_output = new + output = output + f"" + custom_output.pop() + shared.history['internal'].pop() + output = output + "" + else: + for new in chatbot_wrapper(i.strip(), custom_state): + custom_output = new + output = output + f"" + custom_output.pop() + shared.history['internal'].pop() + output = output + "" + output = output + "
{i.strip()}
{i}{custom_state['name1']}: {custom_output[-1][0]}

{custom_state['name2']}: {custom_output[-1][1]}
{custom_state['name1']}: {custom_output[-1][0]}

{custom_state['name2']}: {custom_output[-1][1]}
" + return output + +def gradio_sucks(flubby): + return flubby + +def ui(): + with gr.Accordion("XY Grid", open=False): + prompt = gr.Textbox(value="name1", label='Input Prompt', interactive=True) + with gr.Row(): + presets_box = gr.Textbox(placeholder="presets go here...", label='Presets', interactive=True) + refresh_presets = modules.ui.ToolButton(value='\U0001f504', elem_id='refresh-button') + refresh_presets.click(fn=get_presets, outputs=presets_box) + make_state = gr.Button("make_state") + generate_grid = gr.Button("generate_grid") + tester = gr.HTML(value="what the fuck is happening?") + state = gr.HTML(value="the state will go here") + custom_chat = gr.HTML(value="for the love of God, is this actually going to work???") + + prompt.change(gradio_sucks, prompt, tester) + make_state.click(get_params, [shared.gradio[k] for k in shared.input_elements], state) + generate_grid.click(get_params, [shared.gradio[k] for k in shared.input_elements], state).then(run, [prompt, presets_box], custom_chat) From 20fa4b4c285ea4df94803601403e66f7b3209a14 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Fri, 14 Apr 2023 23:54:14 -0400 Subject: [PATCH 02/13] interface / write to file Cleaned up the interface a little bit. It will now also write the output to an html file and provide a link to open the file for better viewing. --- extensions/xy_grid/script.py | 55 ++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index 8775db9026..da7e3f7abb 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -1,5 +1,6 @@ import os import json +import datetime import gradio as gr import modules.shared as shared @@ -12,7 +13,8 @@ custom_state = {} custom_output = [] -def load_preset_values(preset_menu, state, return_dict=False): +# I had to steal this from server.py because the program freaks out if I try to `import server` +def load_preset_values(preset_menu, state): generate_params = { 'do_sample': True, 'temperature': 1, @@ -36,13 +38,11 @@ def load_preset_values(preset_menu, state, return_dict=False): generate_params[i[0].strip()] = eval(i[1].strip()) generate_params['temperature'] = min(1.99, generate_params['temperature']) - if return_dict: - return generate_params - else: - state.update(generate_params) - return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + state.update(generate_params) + return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] +# Get all of the presets from the presets folder def get_presets(): global custom_state presets = [] @@ -53,17 +53,22 @@ def get_presets(): custom_state = load_preset_values(preset, custom_state)[0] return ", ".join(presets) +# This is a workaround function because gradio has to access parameters if you want them to be current def get_params(*args): global custom_state custom_state = modules.ui.gather_interface_values(*args) return json.dumps(custom_state) -def run(x="",y=""): +# The main function that generates the output, formats the html table, and returns it to the interface +def run(x="", y=""): global custom_state global 
custom_output + custom_state['seed'] = "420691337" + output = "" + # Have to format the strings because gradio makes it difficult to pass lists around x_strings = pp.common.comma_separated_list.parseString(x).asList() y_strings = pp.common.comma_separated_list.parseString(y).asList() @@ -75,38 +80,46 @@ def run(x="",y=""): if y_strings[0] != '': for j in y_strings: custom_state = load_preset_values(j.strip(), custom_state)[0] + + # This is the part that actually does the generating for new in chatbot_wrapper(i.strip(), custom_state): custom_output = new - output = output + f"" + + output = output + f"" custom_output.pop() shared.history['internal'].pop() + output = output + "" else: for new in chatbot_wrapper(i.strip(), custom_state): custom_output = new - output = output + f"" + output = output + f"" custom_output.pop() shared.history['internal'].pop() output = output + "" output = output + "
{custom_state['name1']}: {custom_output[-1][0]}

{custom_state['name2']}: {custom_output[-1][1]}
{custom_state['name1']}: {custom_output[-1][0]}
{custom_state['name2']}: {custom_output[-1][1]}
{custom_state['name1']}: {custom_output[-1][0]}

{custom_state['name2']}: {custom_output[-1][1]}
{custom_state['name1']}: {custom_output[-1][0]}
{custom_state['name2']}: {custom_output[-1][1]}
" - return output -def gradio_sucks(flubby): - return flubby + # Save the output to a file + # Useful for large grids that don't display well in gradio + save_filename = f"{datetime.datetime.now().strftime('%Y_%m_%d_%f')}.html" + with open(Path(f"extensions/xy_grid/outputs/{save_filename}"), 'w') as outfile: + outfile.write(output) + + # Trying to include a link to easily open the html file in a new tab, but I think this is gonna be more confusing than I expected + output = output + f"

open html file" + return output +# Create the interface for the extension (this runs first) def ui(): - with gr.Accordion("XY Grid", open=False): - prompt = gr.Textbox(value="name1", label='Input Prompt', interactive=True) + with gr.Accordion("XY Grid", open=True): + prompt = gr.Textbox(placeholder="Comma separated prompts go here...", label='Input Prompts', interactive=True) with gr.Row(): - presets_box = gr.Textbox(placeholder="presets go here...", label='Presets', interactive=True) + presets_box = gr.Textbox(placeholder="Presets go here. Click the buttton to the right...", label='Presets', interactive=True) refresh_presets = modules.ui.ToolButton(value='\U0001f504', elem_id='refresh-button') refresh_presets.click(fn=get_presets, outputs=presets_box) - make_state = gr.Button("make_state") generate_grid = gr.Button("generate_grid") - tester = gr.HTML(value="what the fuck is happening?") - state = gr.HTML(value="the state will go here") - custom_chat = gr.HTML(value="for the love of God, is this actually going to work???") + with gr.Accordion("Generation Parameters for testing", open=False): + state = gr.HTML(value="the state will go here") + custom_chat = gr.HTML(value="") - prompt.change(gradio_sucks, prompt, tester) - make_state.click(get_params, [shared.gradio[k] for k in shared.input_elements], state) generate_grid.click(get_params, [shared.gradio[k] for k in shared.input_elements], state).then(run, [prompt, presets_box], custom_chat) From 7293d47a17e2d31bbdc63a7d09ccdff3022aa243 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Mon, 17 Apr 2023 20:14:28 -0400 Subject: [PATCH 03/13] reworked ui and added new axes options Looks a lot better now and you can choose characters now too. TODO: - better error handling, like for a comma at the beginning of a line - clean up the code. 
I know a lot of stuff is not optimal yet - make an easier way to add new axis options - add more axis options - LoRA - Seed - Individual parameters - Models (probably way down the road) - checkbox for "constant seed" --- extensions/xy_grid/script.py | 369 +++++++++++++++++++++++++++++++++-- 1 file changed, 349 insertions(+), 20 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index da7e3f7abb..18aa11644d 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -7,11 +7,12 @@ import modules.ui import pyparsing as pp -from modules.chat import chatbot_wrapper +from modules.chat import chatbot_wrapper, load_character from pathlib import Path +axis_type = {'x': "prompts", 'y': "prompts"} custom_state = {} -custom_output = [] +gen_output = [] # I had to steal this from server.py because the program freaks out if I try to `import server` def load_preset_values(preset_menu, state): @@ -41,16 +42,18 @@ def load_preset_values(preset_menu, state): state.update(generate_params) return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] +# Get all of the characters from the character folder +def get_characters(): + paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) + return ", ".join(['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)) # Get all of the presets from the presets folder def get_presets(): - global custom_state presets = [] filenames = os.listdir("presets/") for file in filenames: preset = file[:-4] presets.append(preset) - custom_state = load_preset_values(preset, custom_state)[0] return ", ".join(presets) # This is a workaround function because gradio has to access parameters if you want them to be current @@ -59,12 +62,223 @@ def get_params(*args): custom_state = modules.ui.gather_interface_values(*args) return json.dumps(custom_state) +# Returns the correct results for the axis type chosen by the axis dropdown box +def fill_axis(option): + global axis_get + global custom_state + if option == "prompts": + return gr.update(label=option, value=custom_state['textbox']) + else: + return gr.update(label=option, value=axis_get.get(option)) + +# Sets the type of data each axis will use +def set_axis(x, y): + global axis_type + axis_type.update({'x': x}) + axis_type.update({'y': y}) + + + +def newrun(x="", y=""): + global custom_state + global gen_output + global axis_type + global testa + + if custom_state['custom_stopping_strings'] == None: + custom_state['custom_stopping_strings'] = "" + + # Have to format the strings because gradio makes it difficult to pass lists around + x_strings = pp.common.comma_separated_list.parseString(x).asList() + y_strings = pp.common.comma_separated_list.parseString(y).asList() + + output = "" + + if axis_type['x'] == axis_type['y']: + return "

ERROR: both axes cannot be the same setting" + + elif axis_type['x'] == "prompts": # Run as if x axis is prompts + for i in x_strings: + output = output + f"

" + output = output + "" + if y_strings[0] != '': + for i in y_strings: + output = output + f"" + for j in x_strings: + + # parse the type of the Y axis and alter custom_state accordingly + if axis_type['y'] == "presets": + custom_state = load_preset_values(i.strip(), custom_state)[0] + elif axis_type['y'] == "characters": + custom_state['character_menu'] = i.strip() + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + + # This is the part that actually does the generating + for new in chatbot_wrapper(j.strip(), custom_state): + gen_output = new + #gen_output = [['test', 'pest'], ['poop', 'floop']] + + output = output + f"" + gen_output.pop() + shared.history['internal'].pop() + + output = output + "" + else: + output = output + "" + for i in x_strings: + for new in chatbot_wrapper(i.strip(), custom_state): + gen_output = new + #gen_output = [['test', 'pest'], ['poop', 'floop']] + output = output + f"" + + # Remove the last outputs so they don't influence future generations + gen_output.pop() + shared.history['internal'].pop() + + output = output + "" + + + elif axis_type['y'] == "prompts": # Run as if y axis is prompts + if x_strings[0] != '': + for i in x_strings: + output = output + f"" + output = output + "" + for i in y_strings: + output = output + f"" + for j in x_strings: + + # parse the type of the X axis and alter custom_state accordingly + if axis_type['x'] == "presets": + custom_state = load_preset_values(j.strip(), custom_state)[0] + elif axis_type['x'] == "characters": + custom_state['character_menu'] = j.strip() + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + + # This is the part that actually does the generating + for new in chatbot_wrapper(i.strip(), custom_state): + gen_output = new + #gen_output = [['test', 'pest'], ['poop', 'floop']] + + output = output + f"" + gen_output.pop() + shared.history['internal'].pop() + + output = output + "" + else: + for i in x_strings: + output = output + f"" + output = output + "" + for i in y_strings: + for new in chatbot_wrapper(i.strip(), custom_state): + gen_output = new + #gen_output = [['test', 'pest'], ['poop', 'floop']] + output = output + f"" + + # Remove the last outputs so they don't influence future generations + gen_output.pop() + shared.history['internal'].pop() + + + + else: # Take the prompts from custom_state['textbox'] + for i in x_strings: + output = output + f"" + output = output + "" + if y_strings[0] != '' and x_strings[0] != '': + for i in y_strings: + output = output + f"" + for j in x_strings: + # parse the types of the axes and alter custom_state accordingly + if axis_type['y'] == "presets": + custom_state = load_preset_values(i.strip(), custom_state)[0] + elif axis_type['y'] == "characters": + custom_state['character_menu'] = i.strip() + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + + if axis_type['x'] == "presets": + custom_state = load_preset_values(j.strip(), custom_state)[0] + elif axis_type['x'] == 
"characters": + custom_state['character_menu'] = j.strip() + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + + # This is the part that actually does the generating + for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): + gen_output = new + + output = output + f"" + gen_output.pop() + shared.history['internal'].pop() + + output = output + "" + + elif x_strings[0] != '': + output = output + "" + for i in x_strings: + + # parse the types of the axes and alter custom_state accordingly + if axis_type['x'] == "presets": + custom_state = load_preset_values(i.strip(), custom_state)[0] + elif axis_type['x'] == "characters": + custom_state['character_menu'] = i.strip() + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + + # Run the actual text generator + for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): + gen_output = new + output = output + f"" + + # Remove the last outputs so they don't influence future generations + gen_output.pop() + shared.history['internal'].pop() + + output = output + "" + + + elif y_strings[0] != '': + output = output + "" + for i in y_strings: + # parse the types of the axes and alter custom_state accordingly + if axis_type['y'] == "presets": + custom_state = load_preset_values(i.strip(), custom_state)[0] + elif axis_type['y'] == "characters": + custom_state['character_menu'] = i.strip() + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + + # Run the actual text generator + for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): + gen_output = new + output = output + f"" + + # Remove the last outputs so they don't influence future generations + gen_output.pop() + shared.history['internal'].pop() + + else: + return "

ERROR: both fields are empty" + + output = output + "

{i.strip()}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{i.strip()}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{i.strip()}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{i.strip()}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
" + + # Save the output to a file + # Useful for large grids that don't display well in gradio + save_filename = f"{datetime.datetime.now().strftime('%Y_%m_%d_%f')}.html" + with open(Path(f"extensions/xy_grid/outputs/{save_filename}"), 'w') as outfile: + outfile.write(output) + + # Trying to include a link to easily open the html file in a new tab, but I think this is gonna be more confusing than I expected + output = output + f"

[ open html file ]

" + + return output + + + + + + + # The main function that generates the output, formats the html table, and returns it to the interface def run(x="", y=""): global custom_state - global custom_output - custom_state['seed'] = "420691337" - + global gen_output output = "" @@ -83,19 +297,22 @@ def run(x="", y=""): # This is the part that actually does the generating for new in chatbot_wrapper(i.strip(), custom_state): - custom_output = new + gen_output = new - output = output + f"" - custom_output.pop() + output = output + f"" + gen_output.pop() shared.history['internal'].pop() output = output + "" else: for new in chatbot_wrapper(i.strip(), custom_state): - custom_output = new - output = output + f"" - custom_output.pop() + gen_output = new + output = output + f"" + + # Remove the last outputs so they don't influence future generations + gen_output.pop() shared.history['internal'].pop() + output = output + "" output = output + "
{custom_state['name1']}: {custom_output[-1][0]}
{custom_state['name2']}: {custom_output[-1][1]}
{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{custom_state['name1']}: {custom_output[-1][0]}
{custom_state['name2']}: {custom_output[-1][1]}
{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
" @@ -109,17 +326,129 @@ def run(x="", y=""): output = output + f"

open html file" return output + + +# Necessary for som stuff because gradio +def kickback(flubby=""): + return flubby + +axis_get = { + 'presets': get_presets(), + 'prompts': kickback(), + 'characters': get_characters() + } + # Create the interface for the extension (this runs first) def ui(): + global custom_state + global testa + global axis_type + + # Grab all of the variable from shared.gradio and put them in the custom_state dictionary + custom_state.update({k: v for k, v in zip([key for key in shared.gradio if not isinstance(shared.gradio[key], (gr.Blocks, gr.Button, gr.State))], [shared.gradio[k].value for k in [key for key in shared.gradio] if not isinstance(shared.gradio[k], (gr.Blocks, gr.Button, gr.State))])}) + + # Track changes to all variables in shared.gradio + shared.gradio['add_bos_token'].change(lambda x: custom_state.update({'add_bos_token': x}), shared.gradio['add_bos_token'], []) + shared.gradio['auto_devices'].change(lambda x: custom_state.update({'auto_devices': x}), shared.gradio['auto_devices'], []) + shared.gradio['ban_eos_token'].change(lambda x: custom_state.update({'ban_eos_token': x}), shared.gradio['ban_eos_token'], []) + shared.gradio['bf16'].change(lambda x: custom_state.update({'bf16': x}), shared.gradio['bf16'], []) + shared.gradio['bool_menu'].change(lambda x: custom_state.update({'bool_menu': x}), shared.gradio['bool_menu'], []) + shared.gradio['character_menu'].change(lambda x: custom_state.update({'character_menu': x}), shared.gradio['character_menu'], []) + shared.gradio['character_picture'].change(lambda x: custom_state.update({'character_picture': x}), shared.gradio['character_picture'], []) + shared.gradio['chat_generation_attempts'].change(lambda x: custom_state.update({'chat_generation_attempts': x}), shared.gradio['chat_generation_attempts'], []) + #shared.gradio['Chat input'].change(lambda x: custom_state.update({'Chat input': x}), shared.gradio['Chat input'], []) + shared.gradio['chat_prompt_size'].change(lambda x: custom_state.update({'chat_prompt_size': x}), shared.gradio['chat_prompt_size'], []) + #shared.gradio['Clear history'].change(lambda x: custom_state.update({'Clear history': x}), shared.gradio['Clear history'], []) + #shared.gradio['Clear history-cancel'].change(lambda x: custom_state.update({'Clear history-cancel': x}), shared.gradio['Clear history-cancel'], []) + #shared.gradio['Clear history-confirm'].change(lambda x: custom_state.update({'Clear history-confirm': x}), shared.gradio['Clear history-confirm'], []) + shared.gradio['context'].change(lambda x: custom_state.update({'context': x}), shared.gradio['context'], []) + #shared.gradio['Continue'].change(lambda x: custom_state.update({'Continue': x}), shared.gradio['Continue'], []) + #shared.gradio['Copy last reply'].change(lambda x: custom_state.update({'Copy last reply': x}), shared.gradio['Copy last reply'], []) + shared.gradio['cpu'].change(lambda x: custom_state.update({'cpu': x}), shared.gradio['cpu'], []) + shared.gradio['cpu_memory'].change(lambda x: custom_state.update({'cpu_memory': x}), shared.gradio['cpu_memory'], []) + shared.gradio['custom_model_menu'].change(lambda x: custom_state.update({'custom_model_menu': x}), shared.gradio['custom_model_menu'], []) + shared.gradio['custom_stopping_strings'].change(lambda x: custom_state.update({'custom_stopping_strings': x}), shared.gradio['custom_stopping_strings'], []) + shared.gradio['disk'].change(lambda x: custom_state.update({'disk': x}), shared.gradio['disk'], []) + shared.gradio['display'].change(lambda x: 
custom_state.update({'display': x}), shared.gradio['display'], []) + shared.gradio['do_sample'].change(lambda x: custom_state.update({'do_sample': x}), shared.gradio['do_sample'], []) + shared.gradio['download'].change(lambda x: custom_state.update({'download': x}), shared.gradio['download'], []) + #shared.gradio['download_button'].change(lambda x: custom_state.update({'download_button': x}), shared.gradio['download_button'], []) + #shared.gradio['download_model_button'].change(lambda x: custom_state.update({'download_model_button': x}), shared.gradio['download_model_button'], []) + shared.gradio['early_stopping'].change(lambda x: custom_state.update({'early_stopping': x}), shared.gradio['early_stopping'], []) + shared.gradio['encoder_repetition_penalty'].change(lambda x: custom_state.update({'encoder_repetition_penalty': x}), shared.gradio['encoder_repetition_penalty'], []) + shared.gradio['end_of_turn'].change(lambda x: custom_state.update({'end_of_turn': x}), shared.gradio['end_of_turn'], []) + shared.gradio['extensions_menu'].change(lambda x: custom_state.update({'extensions_menu': x}), shared.gradio['extensions_menu'], []) + #shared.gradio['Generate'].change(lambda x: custom_state.update({'Generate': x}), shared.gradio['Generate'], []) + shared.gradio['gpu_memory_0'].change(lambda x: custom_state.update({'gpu_memory_0': x}), shared.gradio['gpu_memory_0'], []) + shared.gradio['greeting'].change(lambda x: custom_state.update({'greeting': x}), shared.gradio['greeting'], []) + shared.gradio['groupsize'].change(lambda x: custom_state.update({'groupsize': x}), shared.gradio['groupsize'], []) + #shared.gradio['Impersonate'].change(lambda x: custom_state.update({'Impersonate': x}), shared.gradio['Impersonate'], []) + shared.gradio['instruction_template'].change(lambda x: custom_state.update({'instruction_template': x}), shared.gradio['instruction_template'], []) + #shared.gradio['interface'].change(lambda x: custom_state.update({'interface': x}), shared.gradio['interface'], []) + shared.gradio['interface_modes_menu'].change(lambda x: custom_state.update({'interface_modes_menu': x}), shared.gradio['interface_modes_menu'], []) + #shared.gradio['interface_state'].change(lambda x: custom_state.update({'interface_state': x}), shared.gradio['interface_state'], []) + shared.gradio['length_penalty'].change(lambda x: custom_state.update({'length_penalty': x}), shared.gradio['length_penalty'], []) + shared.gradio['load_in_8bit'].change(lambda x: custom_state.update({'load_in_8bit': x}), shared.gradio['load_in_8bit'], []) + shared.gradio['lora_menu'].change(lambda x: custom_state.update({'lora_menu': x}), shared.gradio['lora_menu'], []) + #shared.gradio['lora_menu_apply'].change(lambda x: custom_state.update({'lora_menu_apply': x}), shared.gradio['lora_menu_apply'], []) + shared.gradio['max_new_tokens'].change(lambda x: custom_state.update({'max_new_tokens': x}), shared.gradio['max_new_tokens'], []) + shared.gradio['min_length'].change(lambda x: custom_state.update({'min_length': x}), shared.gradio['min_length'], []) + shared.gradio['mode'].change(lambda x: custom_state.update({'mode': x}), shared.gradio['mode'], []) + shared.gradio['model_menu'].change(lambda x: custom_state.update({'model_menu': x}), shared.gradio['model_menu'], []) + shared.gradio['model_status'].change(lambda x: custom_state.update({'model_status': x}), shared.gradio['model_status'], []) + shared.gradio['model_type'].change(lambda x: custom_state.update({'model_type': x}), shared.gradio['model_type'], []) + 
shared.gradio['name1'].change(lambda x: custom_state.update({'name1': x}), shared.gradio['name1'], []) + shared.gradio['name2'].change(lambda x: custom_state.update({'name2': x}), shared.gradio['name2'], []) + shared.gradio['no_repeat_ngram_size'].change(lambda x: custom_state.update({'no_repeat_ngram_size': x}), shared.gradio['no_repeat_ngram_size'], []) + shared.gradio['num_beams'].change(lambda x: custom_state.update({'num_beams': x}), shared.gradio['num_beams'], []) + shared.gradio['penalty_alpha'].change(lambda x: custom_state.update({'penalty_alpha': x}), shared.gradio['penalty_alpha'], []) + shared.gradio['pre_layer'].change(lambda x: custom_state.update({'pre_layer': x}), shared.gradio['pre_layer'], []) + shared.gradio['preset_menu'].change(lambda x: custom_state.update({'preset_menu': x}), shared.gradio['preset_menu'], []) + #shared.gradio['Regenerate'].change(lambda x: custom_state.update({'Regenerate': x}), shared.gradio['Regenerate'], []) + #shared.gradio['Remove last'].change(lambda x: custom_state.update({'Remove last': x}), shared.gradio['Remove last'], []) + shared.gradio['repetition_penalty'].change(lambda x: custom_state.update({'repetition_penalty': x}), shared.gradio['repetition_penalty'], []) + #shared.gradio['Replace last reply'].change(lambda x: custom_state.update({'Replace last reply': x}), shared.gradio['Replace last reply'], []) + #shared.gradio['reset_interface'].change(lambda x: custom_state.update({'reset_interface': x}), shared.gradio['reset_interface'], []) + shared.gradio['seed'].change(lambda x: custom_state.update({'seed': x}), shared.gradio['seed'], []) + #shared.gradio['Send dummy message'].change(lambda x: custom_state.update({'Send dummy message': x}), shared.gradio['Send dummy message'], []) + #shared.gradio['Send dummy reply'].change(lambda x: custom_state.update({'Send dummy reply': x}), shared.gradio['Send dummy reply'], []) + shared.gradio['skip_special_tokens'].change(lambda x: custom_state.update({'skip_special_tokens': x}), shared.gradio['skip_special_tokens'], []) + shared.gradio['softprompts_menu'].change(lambda x: custom_state.update({'softprompts_menu': x}), shared.gradio['softprompts_menu'], []) + #shared.gradio['Stop'].change(lambda x: custom_state.update({'Stop': x}), shared.gradio['Stop'], []) + shared.gradio['stop_at_newline'].change(lambda x: custom_state.update({'stop_at_newline': x}), shared.gradio['stop_at_newline'], []) + shared.gradio['temperature'].change(lambda x: custom_state.update({'temperature': x}), shared.gradio['temperature'], []) + shared.gradio['textbox'].change(lambda x: custom_state.update({'textbox': x}), shared.gradio['textbox'], []) + shared.gradio['top_k'].change(lambda x: custom_state.update({'top_k': x}), shared.gradio['top_k'], []) + shared.gradio['top_p'].change(lambda x: custom_state.update({'top_p': x}), shared.gradio['top_p'], []) + shared.gradio['truncation_length'].change(lambda x: custom_state.update({'truncation_length': x}), shared.gradio['truncation_length'], []) + shared.gradio['typical_p'].change(lambda x: custom_state.update({'typical_p': x}), shared.gradio['typical_p'], []) + #shared.gradio['Upload character'].change(lambda x: custom_state.update({'Upload character': x}), shared.gradio['Upload character'], []) + shared.gradio['upload_chat_history'].change(lambda x: custom_state.update({'upload_chat_history': x}), shared.gradio['upload_chat_history'], []) + shared.gradio['upload_img_bot'].change(lambda x: custom_state.update({'upload_img_bot': x}), shared.gradio['upload_img_bot'], []) + 
shared.gradio['upload_img_tavern'].change(lambda x: custom_state.update({'upload_img_tavern': x}), shared.gradio['upload_img_tavern'], []) + shared.gradio['upload_json'].change(lambda x: custom_state.update({'upload_json': x}), shared.gradio['upload_json'], []) + shared.gradio['upload_softprompt'].change(lambda x: custom_state.update({'upload_softprompt': x}), shared.gradio['upload_softprompt'], []) + shared.gradio['wbits'].change(lambda x: custom_state.update({'wbits': x}), shared.gradio['wbits'], []) + shared.gradio['your_picture'].change(lambda x: custom_state.update({'your_picture': x}), shared.gradio['your_picture'], []) + with gr.Accordion("XY Grid", open=True): - prompt = gr.Textbox(placeholder="Comma separated prompts go here...", label='Input Prompts', interactive=True) + + # Axis selections and inputs + with gr.Row(): + xType = gr.Dropdown(label='X Axis', choices=list(["prompts","presets","characters"]), value="prompts", interactive=True) + xInput = gr.Textbox(label=xType.value, interactive=True) with gr.Row(): - presets_box = gr.Textbox(placeholder="Presets go here. Click the buttton to the right...", label='Presets', interactive=True) - refresh_presets = modules.ui.ToolButton(value='\U0001f504', elem_id='refresh-button') - refresh_presets.click(fn=get_presets, outputs=presets_box) + yType = gr.Dropdown(label='Y Axis', choices=["prompts","presets","characters"], value="presets", interactive=True) + yInput = gr.Textbox(label=yType.value, interactive=True) + xType.change(set_axis, [xType, yType], []).then(fill_axis, xType, xInput) + yType.change(set_axis, [xType, yType], []).then(fill_axis, yType, yInput) + + # Testing variables and whatnot + testd = gr.Button(value="breakpoint", visible=False) + testd.click(kickback, [], []) + + generate_grid = gr.Button("generate_grid") - with gr.Accordion("Generation Parameters for testing", open=False): - state = gr.HTML(value="the state will go here") custom_chat = gr.HTML(value="") - generate_grid.click(get_params, [shared.gradio[k] for k in shared.input_elements], state).then(run, [prompt, presets_box], custom_chat) + generate_grid.click(newrun, [xInput, yInput], custom_chat) From 5f80461b7469a9a8fe505e53b7b873a2f65113cf Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Tue, 18 Apr 2023 22:21:16 -0400 Subject: [PATCH 04/13] Ready for review --- extensions/xy_grid/script.py | 350 +++++++++++++++++------------------ 1 file changed, 165 insertions(+), 185 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index 18aa11644d..c29e5ae87e 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -1,19 +1,20 @@ import os import json import datetime +import random import gradio as gr import modules.shared as shared -import modules.ui import pyparsing as pp from modules.chat import chatbot_wrapper, load_character from pathlib import Path -axis_type = {'x': "prompts", 'y': "prompts"} +axis_type = {'x': "prompts", 'y': "presets"} custom_state = {} gen_output = [] + # I had to steal this from server.py because the program freaks out if I try to `import server` def load_preset_values(preset_menu, state): generate_params = { @@ -40,13 +41,16 @@ def load_preset_values(preset_menu, state): generate_params['temperature'] = min(1.99, generate_params['temperature']) state.update(generate_params) + custom_state['preset_menu'] = preset_menu return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 
'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + # Get all of the characters from the character folder def get_characters(): paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return ", ".join(['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)) + # Get all of the presets from the presets folder def get_presets(): presets = [] @@ -56,11 +60,6 @@ def get_presets(): presets.append(preset) return ", ".join(presets) -# This is a workaround function because gradio has to access parameters if you want them to be current -def get_params(*args): - global custom_state - custom_state = modules.ui.gather_interface_values(*args) - return json.dumps(custom_state) # Returns the correct results for the axis type chosen by the axis dropdown box def fill_axis(option): @@ -71,6 +70,7 @@ def fill_axis(option): else: return gr.update(label=option, value=axis_get.get(option)) + # Sets the type of data each axis will use def set_axis(x, y): global axis_type @@ -78,47 +78,98 @@ def set_axis(x, y): axis_type.update({'y': y}) +# Parse the type of the X axis and alter custom_state accordingly +# If you want to add more axes, this is where you would do it. +# Add logic here, add an entry to axis_type{}, and add it to the dropdown menus +def parse_axis(axis, value): + global custom_state + global axis_type + + # PRESETS + if axis_type[axis] == "presets": + if value.strip() != "": + custom_state = load_preset_values(value.strip(), custom_state)[0] + else: + custom_state = load_preset_values(shared.gradio['preset_menu'].value, custom_state)[0] + # CHARACTERS + elif axis_type[axis] == "characters": + if value.strip() != "": + custom_state['character_menu'] = value.strip() + else: + custom_state['character_menu'] = shared.gradio["character_menu"].value + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + # SEEDS + elif axis_type[axis] == "seeds": + if value.strip() != "": + custom_state['seed'] = value.strip() + else: + custom_state['seed'] = shared.gradio['seed'].value +# # TEMPLATE +# elif axis_type[axis] == "": +# if value.strip() != "": +# custom_state[''] = value.strip() +# else: +# custom_state[''] = shared.gradio[''].value + return None + + +def run(constant_seed, seed_value, use_history, x="", y=""): -def newrun(x="", y=""): global custom_state global gen_output global axis_type - global testa - if custom_state['custom_stopping_strings'] == None: + if constant_seed: + if seed_value == "-1": + custom_state['seed'] = random.randint(1, 2**31) + else: + custom_state['seed'] = seed_value + + temp_history = shared.history['internal'] + + # Gather output json info, from before the X/Y parameters take effect + output_json = {k: custom_state[k] for k in shared.input_elements} + + if custom_state['custom_stopping_strings'] is None: custom_state['custom_stopping_strings'] = "" # Have to format the strings because gradio makes it difficult to pass lists around - x_strings = pp.common.comma_separated_list.parseString(x).asList() - y_strings = pp.common.comma_separated_list.parseString(y).asList() + if x == "": + x_strings = "" + else: + x_strings = pp.common.comma_separated_list.parseString(x).asList() + if y == "": + y_strings = "" + else: + y_strings = 
pp.common.comma_separated_list.parseString(y).asList() output = "" if axis_type['x'] == axis_type['y']: return "

ERROR: both axes cannot be the same setting" - elif axis_type['x'] == "prompts": # Run as if x axis is prompts + # Run as if x axis is prompts + elif axis_type['x'] == "prompts": for i in x_strings: output = output + f"

" output = output + "" - if y_strings[0] != '': + if y_strings != '': for i in y_strings: - output = output + f"" + output = output + f"" for j in x_strings: # parse the type of the Y axis and alter custom_state accordingly - if axis_type['y'] == "presets": - custom_state = load_preset_values(i.strip(), custom_state)[0] - elif axis_type['y'] == "characters": - custom_state['character_menu'] = i.strip() - custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + parse_axis("y", i) + + # This was at the top of the function, but for some reason it broke with a recent update + if not use_history: + shared.history['internal'] = shared.history['internal'][:1] # This is the part that actually does the generating - for new in chatbot_wrapper(j.strip(), custom_state): + for new in chatbot_wrapper(j.strip().strip('"'), custom_state): gen_output = new - #gen_output = [['test', 'pest'], ['poop', 'floop']] - output = output + f"" + output = output + f"" gen_output.pop() shared.history['internal'].pop() @@ -126,130 +177,116 @@ def newrun(x="", y=""): else: output = output + "" for i in x_strings: - for new in chatbot_wrapper(i.strip(), custom_state): + for new in chatbot_wrapper(i.strip().strip('"'), custom_state): gen_output = new - #gen_output = [['test', 'pest'], ['poop', 'floop']] - output = output + f"" + output = output + f"" - # Remove the last outputs so they don't influence future generations + # Remove the last outputs, so they don't influence future generations gen_output.pop() shared.history['internal'].pop() output = output + "" - - elif axis_type['y'] == "prompts": # Run as if y axis is prompts - if x_strings[0] != '': - for i in x_strings: - output = output + f"" - output = output + "" + # Run as if y axis is prompts + elif axis_type['y'] == "prompts": + for i in x_strings: + output = output + f"" + output = output + "" + if x_strings != '': for i in y_strings: - output = output + f"" + output = output + f"" for j in x_strings: # parse the type of the X axis and alter custom_state accordingly - if axis_type['x'] == "presets": - custom_state = load_preset_values(j.strip(), custom_state)[0] - elif axis_type['x'] == "characters": - custom_state['character_menu'] = j.strip() - custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + parse_axis("x", j) + + # This was at the top of the function, but for some reason it broke with a recent update + if not use_history: + shared.history['internal'] = shared.history['internal'][:1] # This is the part that actually does the generating - for new in chatbot_wrapper(i.strip(), custom_state): + for new in chatbot_wrapper(i.strip().strip('"'), custom_state): gen_output = new - #gen_output = [['test', 'pest'], ['poop', 'floop']] - output = output + f"" + output = output + f"" gen_output.pop() shared.history['internal'].pop() output = output + "" else: - for i in x_strings: - output = output + f"" - output = output + "" for i in y_strings: - for new in chatbot_wrapper(i.strip(), custom_state): + for new in chatbot_wrapper(i.strip().strip('"'), custom_state): gen_output = new - #gen_output = [['test', 'pest'], ['poop', 'floop']] - output = output + f"" + output = output + f"" 
- # Remove the last outputs so they don't influence future generations + # Remove the last outputs, so they don't influence future generations gen_output.pop() shared.history['internal'].pop() - - - else: # Take the prompts from custom_state['textbox'] + # Take the prompts from custom_state['textbox'] + else: for i in x_strings: output = output + f"" output = output + "" - if y_strings[0] != '' and x_strings[0] != '': + if y_strings != '' and x_strings != '': for i in y_strings: - output = output + f"" + output = output + f"" for j in x_strings: # parse the types of the axes and alter custom_state accordingly - if axis_type['y'] == "presets": - custom_state = load_preset_values(i.strip(), custom_state)[0] - elif axis_type['y'] == "characters": - custom_state['character_menu'] = i.strip() - custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) - - if axis_type['x'] == "presets": - custom_state = load_preset_values(j.strip(), custom_state)[0] - elif axis_type['x'] == "characters": - custom_state['character_menu'] = j.strip() - custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + parse_axis("y", i) + parse_axis("x", j) + + # This was at the top of the function, but for some reason it broke with a recent update + if not use_history: + shared.history['internal'] = shared.history['internal'][:1] # This is the part that actually does the generating for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new - output = output + f"" + output = output + f"" gen_output.pop() shared.history['internal'].pop() output = output + "" - elif x_strings[0] != '': + elif x_strings != '': output = output + "" - for i in x_strings: + for j in x_strings: # parse the types of the axes and alter custom_state accordingly - if axis_type['x'] == "presets": - custom_state = load_preset_values(i.strip(), custom_state)[0] - elif axis_type['x'] == "characters": - custom_state['character_menu'] = i.strip() - custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + parse_axis("x", j) + + # This was at the top of the function, but for some reason it broke with a recent update + if not use_history: + shared.history['internal'] = shared.history['internal'][:1] # Run the actual text generator for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new - output = output + f"" + output = output + f"" - # Remove the last outputs so they don't influence future generations + # Remove the last outputs, so they don't influence future generations gen_output.pop() shared.history['internal'].pop() output = output + "" - - elif y_strings[0] != '': - output = output + "" + elif y_strings != '': for i in y_strings: # parse the types of the axes and alter custom_state accordingly - if axis_type['y'] == "presets": - custom_state = load_preset_values(i.strip(), custom_state)[0] - elif axis_type['y'] == "characters": - custom_state['character_menu'] = i.strip() - custom_state.update({k: v for k, v in 
zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + parse_axis("y", i) + + # This was at the top of the function, but for some reason it broke with a recent update + if not use_history: + shared.history['internal'] = shared.history['internal'][:1] # Run the actual text generator for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new - output = output + f"" + output = output + f"" - # Remove the last outputs so they don't influence future generations + # Remove the last outputs, so they don't influence future generations gen_output.pop() shared.history['internal'].pop() @@ -259,92 +296,50 @@ def newrun(x="", y=""): output = output + "
{i.strip()}
{i}
{i.strip()}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{i.strip()}
{i.strip()}
{i}
{i.strip()}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{i.strip()}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{i.strip()}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{i.strip()}
{i}
{i.strip()}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{i.strip()}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
" # Save the output to a file - # Useful for large grids that don't display well in gradio - save_filename = f"{datetime.datetime.now().strftime('%Y_%m_%d_%f')}.html" - with open(Path(f"extensions/xy_grid/outputs/{save_filename}"), 'w') as outfile: + output_folder = Path("extensions/xy_grid/outputs") + if not Path(output_folder).exists(): + os.mkdir(output_folder) + output_filename = Path(f"{datetime.datetime.now().strftime('%Y_%m_%d_%H%M%S')}") + with open(Path(f"{output_folder}/{output_filename}.html"), 'w') as outfile: outfile.write(output) + with open(Path(f"{output_folder}/{output_filename}.json"), 'w') as outparams: + outparams.write(json.dumps(output_json)) - # Trying to include a link to easily open the html file in a new tab, but I think this is gonna be more confusing than I expected - output = output + f"

[ open html file ]

" + # Include a link to the generated HTML file + output = output + f"

[ open html file 🔗 ]

" + # Clean up some of the changes that were made during this generation + custom_state['seed'] = -1 + shared.history['internal'] = temp_history return output +# Necessary for some stuff because gradio +def swap_axes(x_menu, x_data, y_menu, y_data): + return y_menu, y_data, gr.update(label=y_menu), x_menu, x_data, gr.update(label=x_menu) +def toggle_visible(var): + if not var: + custom_state['seed'] = -1 + return gr.update(visible=var) - -# The main function that generates the output, formats the html table, and returns it to the interface -def run(x="", y=""): - global custom_state - global gen_output - - output = "" - - # Have to format the strings because gradio makes it difficult to pass lists around - x_strings = pp.common.comma_separated_list.parseString(x).asList() - y_strings = pp.common.comma_separated_list.parseString(y).asList() - - for i in y_strings: - output = output + f"" - output = output + "" - for i in x_strings: - output = output + f"" - if y_strings[0] != '': - for j in y_strings: - custom_state = load_preset_values(j.strip(), custom_state)[0] - - # This is the part that actually does the generating - for new in chatbot_wrapper(i.strip(), custom_state): - gen_output = new - - output = output + f"" - gen_output.pop() - shared.history['internal'].pop() - - output = output + "" - else: - for new in chatbot_wrapper(i.strip(), custom_state): - gen_output = new - output = output + f"" - - # Remove the last outputs so they don't influence future generations - gen_output.pop() - shared.history['internal'].pop() - - output = output + "" - output = output + "
{i.strip()}
{i}{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
{custom_state['name1']}: {gen_output[-1][0]}
{custom_state['name2']}: {gen_output[-1][1]}
" - - # Save the output to a file - # Useful for large grids that don't display well in gradio - save_filename = f"{datetime.datetime.now().strftime('%Y_%m_%d_%f')}.html" - with open(Path(f"extensions/xy_grid/outputs/{save_filename}"), 'w') as outfile: - outfile.write(output) - - # Trying to include a link to easily open the html file in a new tab, but I think this is gonna be more confusing than I expected - output = output + f"

open html file" - return output - - - -# Necessary for som stuff because gradio -def kickback(flubby=""): - return flubby - axis_get = { 'presets': get_presets(), - 'prompts': kickback(), - 'characters': get_characters() + 'prompts': "", + 'characters': get_characters(), + 'seeds': "-1" } + # Create the interface for the extension (this runs first) def ui(): global custom_state - global testa global axis_type + global axis_get - # Grab all of the variable from shared.gradio and put them in the custom_state dictionary + # Grab all the variable from shared.gradio and put them in the custom_state dictionary custom_state.update({k: v for k, v in zip([key for key in shared.gradio if not isinstance(shared.gradio[key], (gr.Blocks, gr.Button, gr.State))], [shared.gradio[k].value for k in [key for key in shared.gradio] if not isinstance(shared.gradio[k], (gr.Blocks, gr.Button, gr.State))])}) # Track changes to all variables in shared.gradio @@ -356,14 +351,8 @@ def ui(): shared.gradio['character_menu'].change(lambda x: custom_state.update({'character_menu': x}), shared.gradio['character_menu'], []) shared.gradio['character_picture'].change(lambda x: custom_state.update({'character_picture': x}), shared.gradio['character_picture'], []) shared.gradio['chat_generation_attempts'].change(lambda x: custom_state.update({'chat_generation_attempts': x}), shared.gradio['chat_generation_attempts'], []) - #shared.gradio['Chat input'].change(lambda x: custom_state.update({'Chat input': x}), shared.gradio['Chat input'], []) shared.gradio['chat_prompt_size'].change(lambda x: custom_state.update({'chat_prompt_size': x}), shared.gradio['chat_prompt_size'], []) - #shared.gradio['Clear history'].change(lambda x: custom_state.update({'Clear history': x}), shared.gradio['Clear history'], []) - #shared.gradio['Clear history-cancel'].change(lambda x: custom_state.update({'Clear history-cancel': x}), shared.gradio['Clear history-cancel'], []) - #shared.gradio['Clear history-confirm'].change(lambda x: custom_state.update({'Clear history-confirm': x}), shared.gradio['Clear history-confirm'], []) shared.gradio['context'].change(lambda x: custom_state.update({'context': x}), shared.gradio['context'], []) - #shared.gradio['Continue'].change(lambda x: custom_state.update({'Continue': x}), shared.gradio['Continue'], []) - #shared.gradio['Copy last reply'].change(lambda x: custom_state.update({'Copy last reply': x}), shared.gradio['Copy last reply'], []) shared.gradio['cpu'].change(lambda x: custom_state.update({'cpu': x}), shared.gradio['cpu'], []) shared.gradio['cpu_memory'].change(lambda x: custom_state.update({'cpu_memory': x}), shared.gradio['cpu_memory'], []) shared.gradio['custom_model_menu'].change(lambda x: custom_state.update({'custom_model_menu': x}), shared.gradio['custom_model_menu'], []) @@ -372,25 +361,18 @@ def ui(): shared.gradio['display'].change(lambda x: custom_state.update({'display': x}), shared.gradio['display'], []) shared.gradio['do_sample'].change(lambda x: custom_state.update({'do_sample': x}), shared.gradio['do_sample'], []) shared.gradio['download'].change(lambda x: custom_state.update({'download': x}), shared.gradio['download'], []) - #shared.gradio['download_button'].change(lambda x: custom_state.update({'download_button': x}), shared.gradio['download_button'], []) - #shared.gradio['download_model_button'].change(lambda x: custom_state.update({'download_model_button': x}), shared.gradio['download_model_button'], []) shared.gradio['early_stopping'].change(lambda x: 
custom_state.update({'early_stopping': x}), shared.gradio['early_stopping'], []) shared.gradio['encoder_repetition_penalty'].change(lambda x: custom_state.update({'encoder_repetition_penalty': x}), shared.gradio['encoder_repetition_penalty'], []) shared.gradio['end_of_turn'].change(lambda x: custom_state.update({'end_of_turn': x}), shared.gradio['end_of_turn'], []) shared.gradio['extensions_menu'].change(lambda x: custom_state.update({'extensions_menu': x}), shared.gradio['extensions_menu'], []) - #shared.gradio['Generate'].change(lambda x: custom_state.update({'Generate': x}), shared.gradio['Generate'], []) shared.gradio['gpu_memory_0'].change(lambda x: custom_state.update({'gpu_memory_0': x}), shared.gradio['gpu_memory_0'], []) shared.gradio['greeting'].change(lambda x: custom_state.update({'greeting': x}), shared.gradio['greeting'], []) shared.gradio['groupsize'].change(lambda x: custom_state.update({'groupsize': x}), shared.gradio['groupsize'], []) - #shared.gradio['Impersonate'].change(lambda x: custom_state.update({'Impersonate': x}), shared.gradio['Impersonate'], []) shared.gradio['instruction_template'].change(lambda x: custom_state.update({'instruction_template': x}), shared.gradio['instruction_template'], []) - #shared.gradio['interface'].change(lambda x: custom_state.update({'interface': x}), shared.gradio['interface'], []) shared.gradio['interface_modes_menu'].change(lambda x: custom_state.update({'interface_modes_menu': x}), shared.gradio['interface_modes_menu'], []) - #shared.gradio['interface_state'].change(lambda x: custom_state.update({'interface_state': x}), shared.gradio['interface_state'], []) shared.gradio['length_penalty'].change(lambda x: custom_state.update({'length_penalty': x}), shared.gradio['length_penalty'], []) shared.gradio['load_in_8bit'].change(lambda x: custom_state.update({'load_in_8bit': x}), shared.gradio['load_in_8bit'], []) shared.gradio['lora_menu'].change(lambda x: custom_state.update({'lora_menu': x}), shared.gradio['lora_menu'], []) - #shared.gradio['lora_menu_apply'].change(lambda x: custom_state.update({'lora_menu_apply': x}), shared.gradio['lora_menu_apply'], []) shared.gradio['max_new_tokens'].change(lambda x: custom_state.update({'max_new_tokens': x}), shared.gradio['max_new_tokens'], []) shared.gradio['min_length'].change(lambda x: custom_state.update({'min_length': x}), shared.gradio['min_length'], []) shared.gradio['mode'].change(lambda x: custom_state.update({'mode': x}), shared.gradio['mode'], []) @@ -404,17 +386,10 @@ def ui(): shared.gradio['penalty_alpha'].change(lambda x: custom_state.update({'penalty_alpha': x}), shared.gradio['penalty_alpha'], []) shared.gradio['pre_layer'].change(lambda x: custom_state.update({'pre_layer': x}), shared.gradio['pre_layer'], []) shared.gradio['preset_menu'].change(lambda x: custom_state.update({'preset_menu': x}), shared.gradio['preset_menu'], []) - #shared.gradio['Regenerate'].change(lambda x: custom_state.update({'Regenerate': x}), shared.gradio['Regenerate'], []) - #shared.gradio['Remove last'].change(lambda x: custom_state.update({'Remove last': x}), shared.gradio['Remove last'], []) shared.gradio['repetition_penalty'].change(lambda x: custom_state.update({'repetition_penalty': x}), shared.gradio['repetition_penalty'], []) - #shared.gradio['Replace last reply'].change(lambda x: custom_state.update({'Replace last reply': x}), shared.gradio['Replace last reply'], []) - #shared.gradio['reset_interface'].change(lambda x: custom_state.update({'reset_interface': x}), shared.gradio['reset_interface'], 
[]) shared.gradio['seed'].change(lambda x: custom_state.update({'seed': x}), shared.gradio['seed'], []) - #shared.gradio['Send dummy message'].change(lambda x: custom_state.update({'Send dummy message': x}), shared.gradio['Send dummy message'], []) - #shared.gradio['Send dummy reply'].change(lambda x: custom_state.update({'Send dummy reply': x}), shared.gradio['Send dummy reply'], []) shared.gradio['skip_special_tokens'].change(lambda x: custom_state.update({'skip_special_tokens': x}), shared.gradio['skip_special_tokens'], []) shared.gradio['softprompts_menu'].change(lambda x: custom_state.update({'softprompts_menu': x}), shared.gradio['softprompts_menu'], []) - #shared.gradio['Stop'].change(lambda x: custom_state.update({'Stop': x}), shared.gradio['Stop'], []) shared.gradio['stop_at_newline'].change(lambda x: custom_state.update({'stop_at_newline': x}), shared.gradio['stop_at_newline'], []) shared.gradio['temperature'].change(lambda x: custom_state.update({'temperature': x}), shared.gradio['temperature'], []) shared.gradio['textbox'].change(lambda x: custom_state.update({'textbox': x}), shared.gradio['textbox'], []) @@ -422,7 +397,6 @@ def ui(): shared.gradio['top_p'].change(lambda x: custom_state.update({'top_p': x}), shared.gradio['top_p'], []) shared.gradio['truncation_length'].change(lambda x: custom_state.update({'truncation_length': x}), shared.gradio['truncation_length'], []) shared.gradio['typical_p'].change(lambda x: custom_state.update({'typical_p': x}), shared.gradio['typical_p'], []) - #shared.gradio['Upload character'].change(lambda x: custom_state.update({'Upload character': x}), shared.gradio['Upload character'], []) shared.gradio['upload_chat_history'].change(lambda x: custom_state.update({'upload_chat_history': x}), shared.gradio['upload_chat_history'], []) shared.gradio['upload_img_bot'].change(lambda x: custom_state.update({'upload_img_bot': x}), shared.gradio['upload_img_bot'], []) shared.gradio['upload_img_tavern'].change(lambda x: custom_state.update({'upload_img_tavern': x}), shared.gradio['upload_img_tavern'], []) @@ -435,20 +409,26 @@ def ui(): # Axis selections and inputs with gr.Row(): - xType = gr.Dropdown(label='X Axis', choices=list(["prompts","presets","characters"]), value="prompts", interactive=True) - xInput = gr.Textbox(label=xType.value, interactive=True) + x_type = gr.Dropdown(label='X Axis', choices=list(["prompts", "presets", "characters", "seeds"]), value="prompts", interactive=True) + x_input = gr.Textbox(label=x_type.value, interactive=True) with gr.Row(): - yType = gr.Dropdown(label='Y Axis', choices=["prompts","presets","characters"], value="presets", interactive=True) - yInput = gr.Textbox(label=yType.value, interactive=True) - xType.change(set_axis, [xType, yType], []).then(fill_axis, xType, xInput) - yType.change(set_axis, [xType, yType], []).then(fill_axis, yType, yInput) - - # Testing variables and whatnot - testd = gr.Button(value="breakpoint", visible=False) - testd.click(kickback, [], []) - + y_type = gr.Dropdown(label='Y Axis', choices=["prompts", "presets", "characters", "seeds"], value="presets", interactive=True) + y_input = gr.Textbox(label=y_type.value, value=axis_get[y_type.value], interactive=True) + x_type.select(set_axis, [x_type, y_type], []).then(fill_axis, x_type, x_input) + y_type.select(set_axis, [x_type, y_type], []).then(fill_axis, y_type, y_input) + x_type.change(set_axis, [x_type, y_type], []) + y_type.change(set_axis, [x_type, y_type], []) + with gr.Row(): + swap_xy = gr.Button(value='Swap X/Y Axes 🔀') + with 
gr.Row(): + seed_input = gr.Checkbox(label='Use a constant seed', value=False) + use_history = gr.Checkbox(label='Use character\'s chat history', value=False) + with gr.Row(): + seed_value = gr.Textbox(label='Seed', value="-1", visible=False, interactive=True) + seed_input.change(toggle_visible, seed_input, seed_value) + swap_xy.click(swap_axes, [x_type, x_input, y_type, y_input], [x_type, x_input, x_input, y_type, y_input, y_input]) generate_grid = gr.Button("generate_grid") custom_chat = gr.HTML(value="") - generate_grid.click(newrun, [xInput, yInput], custom_chat) + generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat) From 3826daa95d8d9b0b4c7b9c0766363e8706982570 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Tue, 18 Apr 2023 23:10:21 -0400 Subject: [PATCH 05/13] update button title, change accordion to false --- extensions/xy_grid/script.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index c29e5ae87e..3b0ecb5f46 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -147,7 +147,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): if axis_type['x'] == axis_type['y']: return "

ERROR: both axes cannot be the same setting" - + # Run as if x axis is prompts elif axis_type['x'] == "prompts": for i in x_strings: @@ -160,7 +160,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # parse the type of the Y axis and alter custom_state accordingly parse_axis("y", i) - + # This was at the top of the function, but for some reason it broke with a recent update if not use_history: shared.history['internal'] = shared.history['internal'][:1] @@ -271,7 +271,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): shared.history['internal'].pop() output = output + "" - + elif y_strings != '': for i in y_strings: # parse the types of the axes and alter custom_state accordingly @@ -280,7 +280,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # This was at the top of the function, but for some reason it broke with a recent update if not use_history: shared.history['internal'] = shared.history['internal'][:1] - + # Run the actual text generator for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new @@ -405,7 +405,7 @@ def ui(): shared.gradio['wbits'].change(lambda x: custom_state.update({'wbits': x}), shared.gradio['wbits'], []) shared.gradio['your_picture'].change(lambda x: custom_state.update({'your_picture': x}), shared.gradio['your_picture'], []) - with gr.Accordion("XY Grid", open=True): + with gr.Accordion("XY Grid", open=False): # Axis selections and inputs with gr.Row(): @@ -428,7 +428,7 @@ def ui(): seed_input.change(toggle_visible, seed_input, seed_value) swap_xy.click(swap_axes, [x_type, x_input, y_type, y_input], [x_type, x_input, x_input, y_type, y_input, y_input]) - generate_grid = gr.Button("generate_grid") + generate_grid = gr.Button("Generate Grid") custom_chat = gr.HTML(value="") generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat) From 4d650f93a48ee90e91aa4a34b6a2ceb43222a672 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Fri, 21 Apr 2023 12:09:38 -0400 Subject: [PATCH 06/13] added more axis options, formatting Adding these axis options was causing weird bugs to start appearing that, as far as I can tell, should have appeared before. I added some logic to the main run function to clean them up. It's getting pretty crowded in there though. I might move the run function to its own file. Not sure yet. I also changed some formatting around and got rid of some variables. Still haven't pruned the trackers. 
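For reference, the new per-type handling in parse_axis() boils down to coercing each comma-separated cell value to the type its setting expects before it gets written into the state dictionary. A rough, self-contained sketch of that idea (the helper name and key groupings below are illustrative, not code from this patch):

# Illustrative sketch only, not part of the extension's code
FLOAT_KEYS = {"seed", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty"}
INT_KEYS = {"top_k", "max_new_tokens", "no_repeat_ngram_size", "min_length"}

def coerce_axis_value(key, raw, fallback):
    # Empty cells fall back to whatever is currently set in the UI
    raw = raw.strip()
    if raw == "":
        return fallback
    if key in FLOAT_KEYS:
        return float(raw)
    if key in INT_KEYS:
        return int(raw)
    return raw  # presets, characters and prompts stay as plain strings

# Example: overrides for one grid cell
defaults = {"temperature": 0.7, "top_k": 50, "preset_menu": "Default"}
cell = {"temperature": "1.2", "top_k": "", "preset_menu": "Naive"}
print({k: coerce_axis_value(k, v, defaults[k]) for k, v in cell.items()})
# {'temperature': 1.2, 'top_k': 50, 'preset_menu': 'Naive'}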
--- extensions/xy_grid/script.py | 161 ++++++++++++++++++++++++----------- 1 file changed, 110 insertions(+), 51 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index 3b0ecb5f46..4501932e7d 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -8,12 +8,13 @@ import pyparsing as pp from modules.chat import chatbot_wrapper, load_character +from modules.html_generator import convert_to_markdown from pathlib import Path axis_type = {'x': "prompts", 'y': "presets"} custom_state = {} gen_output = [] - +axis_options = ["prompts", "presets", "characters", "seed", "max_new_tokens", "temperature", "top_p", "top_k", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "no_repeat_ngram_size", "min_length"] # I had to steal this from server.py because the program freaks out if I try to `import server` def load_preset_values(preset_menu, state): @@ -54,21 +55,25 @@ def get_characters(): # Get all of the presets from the presets folder def get_presets(): presets = [] - filenames = os.listdir("presets/") + filenames = sorted(os.listdir("presets/")) for file in filenames: preset = file[:-4] presets.append(preset) + presets.remove("Verbose (Beam Search)") return ", ".join(presets) # Returns the correct results for the axis type chosen by the axis dropdown box def fill_axis(option): - global axis_get global custom_state - if option == "prompts": + if option == "presets": + return gr.update(label=option, value=get_presets()) + elif option == "characters": + return gr.update(label=option, value=get_characters()) + elif option == "prompts": return gr.update(label=option, value=custom_state['textbox']) else: - return gr.update(label=option, value=axis_get.get(option)) + return gr.update(label=option, value=custom_state[option]) # Sets the type of data each axis will use @@ -80,7 +85,7 @@ def set_axis(x, y): # Parse the type of the X axis and alter custom_state accordingly # If you want to add more axes, this is where you would do it. 
-# Add logic here, add an entry to axis_type{}, and add it to the dropdown menus +# Add logic here and include it in axis_options def parse_axis(axis, value): global custom_state global axis_type @@ -98,18 +103,24 @@ def parse_axis(axis, value): else: custom_state['character_menu'] = shared.gradio["character_menu"].value custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) - # SEEDS - elif axis_type[axis] == "seeds": + # FLOATS + elif axis_type[axis] in ("seed", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty"): + if value.strip() != "": + custom_state[axis_type[axis]] = float(value.strip()) + else: + custom_state[axis_type[axis]] = shared.gradio[axis_type[axis]].value + # INTS + elif axis_type[axis] in ("top_k", "max_new_tokens", "no_repeat_ngram_size", "min_length"): + if value.strip() != "": + custom_state[axis_type[axis]] = int(value.strip()) + else: + custom_state[axis_type[axis]] = shared.gradio[axis_type[axis]].value + # ANY + else: if value.strip() != "": - custom_state['seed'] = value.strip() + custom_state[axis_type[axis]] = value.strip() else: - custom_state['seed'] = shared.gradio['seed'].value -# # TEMPLATE -# elif axis_type[axis] == "": -# if value.strip() != "": -# custom_state[''] = value.strip() -# else: -# custom_state[''] = shared.gradio[''].value + custom_state[axis_type[axis]] = shared.gradio[axis_type[axis]].value return None @@ -119,13 +130,15 @@ def run(constant_seed, seed_value, use_history, x="", y=""): global gen_output global axis_type + shared.args.no_stream = True if constant_seed: if seed_value == "-1": custom_state['seed'] = random.randint(1, 2**31) else: custom_state['seed'] = seed_value - temp_history = shared.history['internal'] + temp_internal = shared.history['internal'] + temp_visible = shared.history['visible'] # Gather output json info, from before the X/Y parameters take effect output_json = {k: custom_state[k] for k in shared.input_elements} @@ -143,11 +156,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): else: y_strings = pp.common.comma_separated_list.parseString(y).asList() - output = "" + output = "
" + f"" if axis_type['x'] == axis_type['y']: return "

ERROR: both axes cannot be the same setting" - + # Run as if x axis is prompts elif axis_type['x'] == "prompts": for i in x_strings: @@ -160,30 +173,45 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # parse the type of the Y axis and alter custom_state accordingly parse_axis("y", i) - + # This was at the top of the function, but for some reason it broke with a recent update if not use_history: shared.history['internal'] = shared.history['internal'][:1] + shared.history['visible'] = shared.history['visible'][:1] # This is the part that actually does the generating for new in chatbot_wrapper(j.strip().strip('"'), custom_state): gen_output = new - output = output + f"

" + if len(gen_output) == 0: + gen_output = [['','']] + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + output = output + f"" + + # Remove the last outputs, so they don't influence future generations gen_output.pop() - shared.history['internal'].pop() + if len(shared.history['internal']) > 1: + shared.history['internal'].pop() output = output + "" + else: output = output + "" for i in x_strings: for new in chatbot_wrapper(i.strip().strip('"'), custom_state): gen_output = new - output = output + f"" + + if len(gen_output) == 0: + gen_output = [['','']] + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() - shared.history['internal'].pop() + if len(shared.history['internal']) > 1: + shared.history['internal'].pop() output = output + "" @@ -203,25 +231,40 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # This was at the top of the function, but for some reason it broke with a recent update if not use_history: shared.history['internal'] = shared.history['internal'][:1] + shared.history['visible'] = shared.history['visible'][:1] # This is the part that actually does the generating for new in chatbot_wrapper(i.strip().strip('"'), custom_state): gen_output = new - output = output + f"" + if len(gen_output) == 0: + gen_output = [['','']] + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + output = output + f"" + + # Remove the last outputs, so they don't influence future generations gen_output.pop() - shared.history['internal'].pop() + if len(shared.history['internal']) > 1: + shared.history['internal'].pop() output = output + "" + else: for i in y_strings: for new in chatbot_wrapper(i.strip().strip('"'), custom_state): gen_output = new - output = output + f"" + + if len(gen_output) == 0: + gen_output = [['','']] + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() - shared.history['internal'].pop() + if len(shared.history['internal']) > 1: + shared.history['internal'].pop() # Take the prompts from custom_state['textbox'] else: @@ -239,14 +282,22 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # This was at the top of the function, but for some reason it broke with a recent update if not use_history: shared.history['internal'] = shared.history['internal'][:1] + shared.history['visible'] = shared.history['visible'][:1] # This is the part that actually does the generating for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new - output = output + f"" + if len(gen_output) == 0: + gen_output = [['','']] + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + output = output + f"" + + # Remove the last outputs, so they don't influence future generations gen_output.pop() - shared.history['internal'].pop() + if len(shared.history['internal']) > 1: + shared.history['internal'].pop() output = output + "" @@ -260,18 +311,25 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # This was at the top of the function, but for some reason it broke with a recent update if not use_history: shared.history['internal'] = 
shared.history['internal'][:1] + shared.history['visible'] = shared.history['visible'][:1] # Run the actual text generator for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new - output = output + f"" + + if len(gen_output) == 0: + gen_output = [['','']] + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() - shared.history['internal'].pop() + if len(shared.history['internal']) > 1: + shared.history['internal'].pop() output = output + "" - + elif y_strings != '': for i in y_strings: # parse the types of the axes and alter custom_state accordingly @@ -280,15 +338,22 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # This was at the top of the function, but for some reason it broke with a recent update if not use_history: shared.history['internal'] = shared.history['internal'][:1] - + shared.history['visible'] = shared.history['visible'][:1] + # Run the actual text generator for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new - output = output + f"" + + if len(gen_output) == 0: + gen_output = [['','']] + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() - shared.history['internal'].pop() + if len(shared.history['internal']) > 1: + shared.history['internal'].pop() else: return "

ERROR: both fields are empty" @@ -306,14 +371,17 @@ def run(constant_seed, seed_value, use_history, x="", y=""): outparams.write(json.dumps(output_json)) # Include a link to the generated HTML file - output = output + f"

[ open html file 🔗 ]

" + output = f"

[ open html file 🔗 ]



" + output # Clean up some of the changes that were made during this generation custom_state['seed'] = -1 - shared.history['internal'] = temp_history + shared.history['internal'] = temp_internal + shared.history['visible'] = temp_visible return output + + # Necessary for some stuff because gradio def swap_axes(x_menu, x_data, y_menu, y_data): return y_menu, y_data, gr.update(label=y_menu), x_menu, x_data, gr.update(label=x_menu) @@ -325,19 +393,10 @@ def toggle_visible(var): return gr.update(visible=var) -axis_get = { - 'presets': get_presets(), - 'prompts': "", - 'characters': get_characters(), - 'seeds': "-1" - } - - # Create the interface for the extension (this runs first) def ui(): global custom_state global axis_type - global axis_get # Grab all the variable from shared.gradio and put them in the custom_state dictionary custom_state.update({k: v for k, v in zip([key for key in shared.gradio if not isinstance(shared.gradio[key], (gr.Blocks, gr.Button, gr.State))], [shared.gradio[k].value for k in [key for key in shared.gradio] if not isinstance(shared.gradio[k], (gr.Blocks, gr.Button, gr.State))])}) @@ -409,11 +468,11 @@ def ui(): # Axis selections and inputs with gr.Row(): - x_type = gr.Dropdown(label='X Axis', choices=list(["prompts", "presets", "characters", "seeds"]), value="prompts", interactive=True) + x_type = gr.Dropdown(label='X Axis', choices=axis_options, value="prompts", interactive=True) x_input = gr.Textbox(label=x_type.value, interactive=True) with gr.Row(): - y_type = gr.Dropdown(label='Y Axis', choices=["prompts", "presets", "characters", "seeds"], value="presets", interactive=True) - y_input = gr.Textbox(label=y_type.value, value=axis_get[y_type.value], interactive=True) + y_type = gr.Dropdown(label='Y Axis', choices=axis_options, value="presets", interactive=True) + y_input = gr.Textbox(label=y_type.value, value=get_presets, interactive=True) x_type.select(set_axis, [x_type, y_type], []).then(fill_axis, x_type, x_input) y_type.select(set_axis, [x_type, y_type], []).then(fill_axis, y_type, y_input) x_type.change(set_axis, [x_type, y_type], []) @@ -428,7 +487,7 @@ def ui(): seed_input.change(toggle_visible, seed_input, seed_value) swap_xy.click(swap_axes, [x_type, x_input, y_type, y_input], [x_type, x_input, x_input, y_type, y_input, y_input]) - generate_grid = gr.Button("Generate Grid") + generate_grid = gr.Button("generate_grid") custom_chat = gr.HTML(value="") generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat) From 7d2152e0ebc1710d9716f5d7f5526b1815e165d7 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Fri, 21 Apr 2023 16:22:35 -0400 Subject: [PATCH 07/13] added instruct mode when you load characters, the instruction following characters will also be loaded with the "instruction-following/prefix" --- extensions/xy_grid/script.py | 87 ++++++++++++++++++++++++++++++------ 1 file changed, 74 insertions(+), 13 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index 4501932e7d..beaa5d59ff 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -14,7 +14,7 @@ axis_type = {'x': "prompts", 'y': "presets"} custom_state = {} gen_output = [] -axis_options = ["prompts", "presets", "characters", "seed", "max_new_tokens", "temperature", "top_p", "top_k", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "no_repeat_ngram_size", "min_length"] +axis_options = ["prompts", "presets", "characters", "instruction template", "seed", "max_new_tokens", "temperature", 
"top_p", "top_k", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "no_repeat_ngram_size", "min_length"] # I had to steal this from server.py because the program freaks out if I try to `import server` def load_preset_values(preset_menu, state): @@ -49,7 +49,18 @@ def load_preset_values(preset_menu, state): # Get all of the characters from the character folder def get_characters(): paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) - return ", ".join(['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)) + instructors = [] + filenames = sorted(os.listdir("characters/instruction-following/")) + for file in filenames: + instructor = "instruction-following/" + file[:-5] + instructors.append(instructor) + return ", ".join(['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) + instructors) + + +# Get all of the instruction following templates from the character folder +def get_instruct(): + paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) + return ", ".join(['None'] + sorted(set((k.stem for k in paths)), key=str.lower)) # Get all of the presets from the presets folder @@ -70,6 +81,8 @@ def fill_axis(option): return gr.update(label=option, value=get_presets()) elif option == "characters": return gr.update(label=option, value=get_characters()) + elif option == "instruction template": + return gr.update(label=option, value=get_instruct()) elif option == "prompts": return gr.update(label=option, value=custom_state['textbox']) else: @@ -98,11 +111,28 @@ def parse_axis(axis, value): custom_state = load_preset_values(shared.gradio['preset_menu'].value, custom_state)[0] # CHARACTERS elif axis_type[axis] == "characters": + if value.split("/")[0] == "instruction-following": + custom_state['mode'] = "instruct" + else: + custom_state['mode'] = "cai-chat" + value = value.split("/")[-1] + if custom_state['mode'] == "instruct": + char_type = 'instruction_template' + else: + char_type = 'character_menu' if value.strip() != "": - custom_state['character_menu'] = value.strip() + custom_state[char_type] = value.strip() else: - custom_state['character_menu'] = shared.gradio["character_menu"].value - custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['character_menu'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + custom_state[char_type] = shared.gradio[char_type].value + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state[char_type], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) + # INSTRUCT + elif axis_type[axis] == "instruction template": + custom_state['mode'] = 'instruct' + if value.strip() != "": + custom_state['instruction_template'] = value.strip() + else: + custom_state['instruction_template'] = shared.gradio["instruction_template"].value + custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['instruction_template'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) # FLOATS elif axis_type[axis] in ("seed", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty"): if 
value.strip() != "": @@ -139,6 +169,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): temp_internal = shared.history['internal'] temp_visible = shared.history['visible'] + temp_custom_state = custom_state # Gather output json info, from before the X/Y parameters take effect output_json = {k: custom_state[k] for k in shared.input_elements} @@ -187,7 +218,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): gen_output = [['','']] user_output = convert_to_markdown(gen_output[-1][0]) bot_output = convert_to_markdown(gen_output[-1][1]) - output = output + f"" + + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -206,7 +241,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): gen_output = [['','']] user_output = convert_to_markdown(gen_output[-1][0]) bot_output = convert_to_markdown(gen_output[-1][1]) - output = output + f"" + + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -241,7 +280,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): gen_output = [['','']] user_output = convert_to_markdown(gen_output[-1][0]) bot_output = convert_to_markdown(gen_output[-1][1]) - output = output + f"" + + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -259,7 +302,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): gen_output = [['','']] user_output = convert_to_markdown(gen_output[-1][0]) bot_output = convert_to_markdown(gen_output[-1][1]) - output = output + f"" + + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -292,7 +339,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): gen_output = [['','']] user_output = convert_to_markdown(gen_output[-1][0]) bot_output = convert_to_markdown(gen_output[-1][1]) - output = output + f"" + + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -321,7 +372,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): gen_output = [['','']] user_output = convert_to_markdown(gen_output[-1][0]) bot_output = convert_to_markdown(gen_output[-1][1]) - output = output + f"" + + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -348,7 +403,11 @@ def run(constant_seed, seed_value, use_history, x="", y=""): gen_output = [['','']] user_output = convert_to_markdown(gen_output[-1][0]) bot_output = convert_to_markdown(gen_output[-1][1]) - output = output + f"" + + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -377,6 +436,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): custom_state['seed'] = -1 shared.history['internal'] = temp_internal shared.history['visible'] = temp_visible + custom_state = temp_custom_state return output @@ -463,8 +523,9 
@@ def ui(): shared.gradio['upload_softprompt'].change(lambda x: custom_state.update({'upload_softprompt': x}), shared.gradio['upload_softprompt'], []) shared.gradio['wbits'].change(lambda x: custom_state.update({'wbits': x}), shared.gradio['wbits'], []) shared.gradio['your_picture'].change(lambda x: custom_state.update({'your_picture': x}), shared.gradio['your_picture'], []) + shared.gradio['mode'].change(lambda x: custom_state.update({'mode': x}), shared.gradio['mode'], []) - with gr.Accordion("XY Grid", open=False): + with gr.Accordion("XY Grid", open=True): # Axis selections and inputs with gr.Row(): From 4869aa063ceb87991b2c33427fdd4fdf91048205 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Fri, 21 Apr 2023 17:25:20 -0400 Subject: [PATCH 08/13] removed old instruct code, bug fixes --- extensions/xy_grid/script.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index beaa5d59ff..c5267d0aae 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -14,7 +14,7 @@ axis_type = {'x': "prompts", 'y': "presets"} custom_state = {} gen_output = [] -axis_options = ["prompts", "presets", "characters", "instruction template", "seed", "max_new_tokens", "temperature", "top_p", "top_k", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "no_repeat_ngram_size", "min_length"] +axis_options = ["prompts", "presets", "characters", "seed", "max_new_tokens", "temperature", "top_p", "top_k", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "no_repeat_ngram_size", "min_length"] # I had to steal this from server.py because the program freaks out if I try to `import server` def load_preset_values(preset_menu, state): @@ -125,14 +125,6 @@ def parse_axis(axis, value): else: custom_state[char_type] = shared.gradio[char_type].value custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state[char_type], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) - # INSTRUCT - elif axis_type[axis] == "instruction template": - custom_state['mode'] = 'instruct' - if value.strip() != "": - custom_state['instruction_template'] = value.strip() - else: - custom_state['instruction_template'] = shared.gradio["instruction_template"].value - custom_state.update({k: v for k, v in zip(['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display'], load_character(custom_state['instruction_template'], custom_state['name1'], custom_state['name2'], custom_state['mode']))}) # FLOATS elif axis_type[axis] in ("seed", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty"): if value.strip() != "": @@ -169,7 +161,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): temp_internal = shared.history['internal'] temp_visible = shared.history['visible'] - temp_custom_state = custom_state + temp_custom_state = custom_state.copy() # Gather output json info, from before the X/Y parameters take effect output_json = {k: custom_state[k] for k in shared.input_elements} @@ -222,7 +214,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): if custom_state['mode'] == 'instruct': output = output + f"" else: - output = output + f"" + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -245,7 +237,7 @@ def run(constant_seed, seed_value, 
use_history, x="", y=""): if custom_state['mode'] == 'instruct': output = output + f"" else: - output = output + f"" + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -284,7 +276,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): if custom_state['mode'] == 'instruct': output = output + f"" else: - output = output + f"" + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -304,9 +296,9 @@ def run(constant_seed, seed_value, use_history, x="", y=""): bot_output = convert_to_markdown(gen_output[-1][1]) if custom_state['mode'] == 'instruct': - output = output + f"" + output = output + f"" else: - output = output + f"" + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -343,7 +335,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): if custom_state['mode'] == 'instruct': output = output + f"" else: - output = output + f"" + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -376,7 +368,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): if custom_state['mode'] == 'instruct': output = output + f"" else: - output = output + f"" + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -405,9 +397,9 @@ def run(constant_seed, seed_value, use_history, x="", y=""): bot_output = convert_to_markdown(gen_output[-1][1]) if custom_state['mode'] == 'instruct': - output = output + f"" + output = output + f"" else: - output = output + f"" + output = output + f"" # Remove the last outputs, so they don't influence future generations gen_output.pop() @@ -436,7 +428,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): custom_state['seed'] = -1 shared.history['internal'] = temp_internal shared.history['visible'] = temp_visible - custom_state = temp_custom_state + custom_state = temp_custom_state.copy() return output From 31a9d4977eddca57caa90d888faa06b8648ffa43 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Fri, 21 Apr 2023 18:28:51 -0400 Subject: [PATCH 09/13] bug fix, closed accordian --- extensions/xy_grid/script.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index c5267d0aae..92d3635de4 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -315,8 +315,13 @@ def run(constant_seed, seed_value, use_history, x="", y=""): output = output + f"" for j in x_strings: # parse the types of the axes and alter custom_state accordingly - parse_axis("y", i) - parse_axis("x", j) + # in this case, we need to make sure we parse presets first, so it doesn't overwrite lower level settings + if axis_type['y'] == "presets": + parse_axis("y", i) + parse_axis("x", j) + else: + parse_axis("x", j) + parse_axis("y", i) # This was at the top of the function, but for some reason it broke with a recent update if not use_history: @@ -517,7 +522,7 @@ def ui(): shared.gradio['your_picture'].change(lambda x: custom_state.update({'your_picture': x}), shared.gradio['your_picture'], []) shared.gradio['mode'].change(lambda x: custom_state.update({'mode': x}), shared.gradio['mode'], []) - with gr.Accordion("XY Grid", open=True): + with gr.Accordion("XY Grid", open=False): # Axis selections and inputs with gr.Row(): From 658fd6a85268389ef9dcd28e2dc9ab29f1af7f97 Mon Sep 17 
00:00:00 2001 From: Clay Shoaf Date: Sun, 23 Apr 2023 15:22:10 -0400 Subject: [PATCH 10/13] cleaned up logic, bug fixes, formatting - Cleaned up that hideous pile of code blocks. I don't know why I decided to do it like that originally, but I had a day off with just my old laptop and it let me think about my formatting better. It looks a lot cleaner now and should be a lot easier to work on. - Fixed the context poisoning bug when using instruct mode or instruct characters. - Slightly changed the way the table is made, so it should look better when making large grids. --- extensions/xy_grid/script.py | 355 +++++++++++------------------------ 1 file changed, 107 insertions(+), 248 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index 92d3635de4..f68b5767ff 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -2,6 +2,7 @@ import json import datetime import random +import time import gradio as gr import modules.shared as shared @@ -11,6 +12,7 @@ from modules.html_generator import convert_to_markdown from pathlib import Path +# Global variables axis_type = {'x': "prompts", 'y': "presets"} custom_state = {} gen_output = [] @@ -97,7 +99,7 @@ def set_axis(x, y): # Parse the type of the X axis and alter custom_state accordingly -# If you want to add more axes, this is where you would do it. +# If you want to add more axes, this is where you would do it. # Add logic here and include it in axis_options def parse_axis(axis, value): global custom_state @@ -146,274 +148,131 @@ def parse_axis(axis, value): return None +# The main function that generates the grid def run(constant_seed, seed_value, use_history, x="", y=""): - global custom_state global gen_output global axis_type + # Error handling + if axis_type['x'] == axis_type['y']: + return "

ERROR: both axes cannot be the same setting" + if x.strip() == '' and y.strip() == '': + return "

ERROR: both fields are empty" + shared.args.no_stream = True + + # Backup our parameters so we can put everything back how it was before we started + temp_internal = shared.history['internal'].copy() + temp_visible = shared.history['visible'].copy() + temp_custom_state = custom_state.copy() + + # Handle the constant seed value if constant_seed: if seed_value == "-1": custom_state['seed'] = random.randint(1, 2**31) else: custom_state['seed'] = seed_value - temp_internal = shared.history['internal'] - temp_visible = shared.history['visible'] - temp_custom_state = custom_state.copy() # Gather output json info, from before the X/Y parameters take effect output_json = {k: custom_state[k] for k in shared.input_elements} + # This was causing problems when the custom stopping strings was set to None if custom_state['custom_stopping_strings'] is None: custom_state['custom_stopping_strings'] = "" # Have to format the strings because gradio makes it difficult to pass lists around - if x == "": - x_strings = "" - else: - x_strings = pp.common.comma_separated_list.parseString(x).asList() - if y == "": - y_strings = "" - else: - y_strings = pp.common.comma_separated_list.parseString(y).asList() - - output = "

X={axis_type['x']}
Y={axis_type['y']}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}

{custom_state['name1']}:

{gen_output[-1][0]}

{custom_state['name2']}:

{gen_output[-1][1]}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}
" + f"" - - if axis_type['x'] == axis_type['y']: - return "

ERROR: both axes cannot be the same setting" - - # Run as if x axis is prompts - elif axis_type['x'] == "prompts": - for i in x_strings: - output = output + f"

" - output = output + "" - if y_strings != '': - for i in y_strings: - output = output + f"" - for j in x_strings: - - # parse the type of the Y axis and alter custom_state accordingly + x_strings = pp.common.comma_separated_list.parseString(x).asList() + y_strings = pp.common.comma_separated_list.parseString(y).asList() + + # If someone uses "-1" for a seed axis, we don't want it generating a new seed for every cell of the grid + if axis_type['x'] == "seed": + x_strings = [str(random.randint(1, 2**31)) if seed in ('-1','-1.0') else seed for seed in x_strings] + if axis_type['y'] == "seed": + y_strings = [str(random.randint(1, 2**31)) if seed in ('-1','-1.0') else seed for seed in y_strings] + + cell_count = len(x_strings) + 1 + output = "
X={axis_type['x']}
Y={axis_type['y']}
{i.strip()}
{i.strip()}
" + f"" + + # Make the grid + for i in x_strings: + output = output + f"" + output = output + "" + for i in y_strings: + output = output + f"" + for j in x_strings: + + # parse the type of the axes and alter custom_state accordingly + if axis_type['x'] == "prompts": + parse_axis("y", i) + elif axis_type['y'] == "prompts": + parse_axis("x", j) + elif y_strings != '' and x_strings != '': + # in this case, we need to make sure we parse presets first, so it doesn't overwrite lower level settings + if axis_type['y'] == "presets": parse_axis("y", i) - - # This was at the top of the function, but for some reason it broke with a recent update - if not use_history: - shared.history['internal'] = shared.history['internal'][:1] - shared.history['visible'] = shared.history['visible'][:1] - - # This is the part that actually does the generating - for new in chatbot_wrapper(j.strip().strip('"'), custom_state): - gen_output = new - - if len(gen_output) == 0: - gen_output = [['','']] - user_output = convert_to_markdown(gen_output[-1][0]) - bot_output = convert_to_markdown(gen_output[-1][1]) - - if custom_state['mode'] == 'instruct': - output = output + f"" - else: - output = output + f"" - - # Remove the last outputs, so they don't influence future generations - gen_output.pop() - if len(shared.history['internal']) > 1: - shared.history['internal'].pop() - - output = output + "" - - else: - output = output + "" - for i in x_strings: - for new in chatbot_wrapper(i.strip().strip('"'), custom_state): - gen_output = new - - if len(gen_output) == 0: - gen_output = [['','']] - user_output = convert_to_markdown(gen_output[-1][0]) - bot_output = convert_to_markdown(gen_output[-1][1]) - - if custom_state['mode'] == 'instruct': - output = output + f"" - else: - output = output + f"" - - # Remove the last outputs, so they don't influence future generations - gen_output.pop() - if len(shared.history['internal']) > 1: - shared.history['internal'].pop() - - output = output + "" - - # Run as if y axis is prompts - elif axis_type['y'] == "prompts": - for i in x_strings: - output = output + f"" - output = output + "" - if x_strings != '': - for i in y_strings: - output = output + f"" - for j in x_strings: - - # parse the type of the X axis and alter custom_state accordingly parse_axis("x", j) - - # This was at the top of the function, but for some reason it broke with a recent update - if not use_history: - shared.history['internal'] = shared.history['internal'][:1] - shared.history['visible'] = shared.history['visible'][:1] - - # This is the part that actually does the generating - for new in chatbot_wrapper(i.strip().strip('"'), custom_state): - gen_output = new - - if len(gen_output) == 0: - gen_output = [['','']] - user_output = convert_to_markdown(gen_output[-1][0]) - bot_output = convert_to_markdown(gen_output[-1][1]) - - if custom_state['mode'] == 'instruct': - output = output + f"" - else: - output = output + f"" - - # Remove the last outputs, so they don't influence future generations - gen_output.pop() - if len(shared.history['internal']) > 1: - shared.history['internal'].pop() - - output = output + "" - - else: - for i in y_strings: - for new in chatbot_wrapper(i.strip().strip('"'), custom_state): - gen_output = new - - if len(gen_output) == 0: - gen_output = [['','']] - user_output = convert_to_markdown(gen_output[-1][0]) - bot_output = convert_to_markdown(gen_output[-1][1]) - - if custom_state['mode'] == 'instruct': - output = output + f"" else: - output = output + f"" - - # Remove the last outputs, so they 
don't influence future generations - gen_output.pop() - if len(shared.history['internal']) > 1: - shared.history['internal'].pop() - - # Take the prompts from custom_state['textbox'] - else: - for i in x_strings: - output = output + f"" - output = output + "" - if y_strings != '' and x_strings != '': - for i in y_strings: - output = output + f"" - for j in x_strings: - # parse the types of the axes and alter custom_state accordingly - # in this case, we need to make sure we parse presets first, so it doesn't overwrite lower level settings - if axis_type['y'] == "presets": - parse_axis("y", i) - parse_axis("x", j) - else: - parse_axis("x", j) - parse_axis("y", i) - - # This was at the top of the function, but for some reason it broke with a recent update - if not use_history: - shared.history['internal'] = shared.history['internal'][:1] - shared.history['visible'] = shared.history['visible'][:1] - - # This is the part that actually does the generating - for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): - gen_output = new - - if len(gen_output) == 0: - gen_output = [['','']] - user_output = convert_to_markdown(gen_output[-1][0]) - bot_output = convert_to_markdown(gen_output[-1][1]) - - if custom_state['mode'] == 'instruct': - output = output + f"" - else: - output = output + f"" - - # Remove the last outputs, so they don't influence future generations - gen_output.pop() - if len(shared.history['internal']) > 1: - shared.history['internal'].pop() - - output = output + "" - - elif x_strings != '': - output = output + "" - for j in x_strings: - - # parse the types of the axes and alter custom_state accordingly + parse_axis("x", j) + parse_axis("y", i) + elif x_strings != '': parse_axis("x", j) - - # This was at the top of the function, but for some reason it broke with a recent update - if not use_history: - shared.history['internal'] = shared.history['internal'][:1] - shared.history['visible'] = shared.history['visible'][:1] - - # Run the actual text generator - for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): - gen_output = new - - if len(gen_output) == 0: - gen_output = [['','']] - user_output = convert_to_markdown(gen_output[-1][0]) - bot_output = convert_to_markdown(gen_output[-1][1]) - - if custom_state['mode'] == 'instruct': - output = output + f"" - else: - output = output + f"" - - # Remove the last outputs, so they don't influence future generations - gen_output.pop() - if len(shared.history['internal']) > 1: - shared.history['internal'].pop() - - output = output + "" - - elif y_strings != '': - for i in y_strings: - # parse the types of the axes and alter custom_state accordingly + elif y_strings != '': parse_axis("y", i) - - # This was at the top of the function, but for some reason it broke with a recent update - if not use_history: - shared.history['internal'] = shared.history['internal'][:1] - shared.history['visible'] = shared.history['visible'][:1] - - # Run the actual text generator + else: + return "

ERROR: unknown error" + + # Determine whether or not we are including the character's chat history with the user + if not use_history: + shared.history['internal'] = shared.history['internal'][:1] + shared.history['visible'] = shared.history['visible'][:1] + + # Clear all history for instruct mode + if custom_state['mode'] == "instruct": + shared.history['internal'].clear() + shared.history['visible'].clear() + + # This is the part that actually does the generating + if axis_type['x'] == "prompts": + for new in chatbot_wrapper(j.strip().strip('"'), custom_state): + gen_output = new + elif axis_type['y'] == "prompts": + for new in chatbot_wrapper(i.strip().strip('"'), custom_state): + gen_output = new + else: for new in chatbot_wrapper(custom_state['textbox'].strip(), custom_state): gen_output = new - if len(gen_output) == 0: - gen_output = [['','']] - user_output = convert_to_markdown(gen_output[-1][0]) - bot_output = convert_to_markdown(gen_output[-1][1]) - - if custom_state['mode'] == 'instruct': - output = output + f"

" - else: - output = output + f"" - - # Remove the last outputs, so they don't influence future generations - gen_output.pop() + # Sometimes it the generation kicks back nothing and it causes problems + if len(gen_output) == 0: + gen_output = [['','']] + + # Turn the output into HTML for our table + user_output = convert_to_markdown(gen_output[-1][0]) + bot_output = convert_to_markdown(gen_output[-1][1]) + if custom_state['mode'] == 'instruct': + output = output + f"" + else: + output = output + f"" + + # Remove the last outputs, so they don't influence future generations + if custom_state['mode'] == 'instruct': + shared.history['internal'].clear() + shared.history['visible'].clear() + else: if len(shared.history['internal']) > 1: shared.history['internal'].pop() - - else: - return "

ERROR: both fields are empty" - + elif len(shared.history['internal']) == 1: + if shared.history['internal'] == gen_output: + shared.history['internal'].clear() + if len(shared.history['visible']) > 1: + shared.history['visible'].pop() + elif len(shared.history['visible']) == 1: + if shared.history['visible'] == gen_output: + shared.history['visible'].clear() + + output = output + "

" output = output + "
X={axis_type['x']}
Y={axis_type['y']}
{i.strip()}
{i.strip()}{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}
{i.strip()}{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}
{i.strip()}{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
{i.strip()}{custom_state['name1']} {user_output}{custom_state['name2']} {bot_output}
{i.strip()}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}

{custom_state['name1']}

{user_output}

{custom_state['name2']}

{bot_output}

{custom_state['name1']}:

{user_output}

{custom_state['name2']}:

{bot_output}
" # Save the output to a file @@ -429,14 +288,12 @@ def run(constant_seed, seed_value, use_history, x="", y=""): # Include a link to the generated HTML file output = f"

[ open html file 🔗 ]



" + output - # Clean up some of the changes that were made during this generation - custom_state['seed'] = -1 - shared.history['internal'] = temp_internal - shared.history['visible'] = temp_visible + # Clean up the changes that were made during this generation + shared.history['internal'] = temp_internal.copy() + shared.history['visible'] = temp_visible.copy() custom_state = temp_custom_state.copy() - return output - + return output # Necessary for some stuff because gradio @@ -522,7 +379,8 @@ def ui(): shared.gradio['your_picture'].change(lambda x: custom_state.update({'your_picture': x}), shared.gradio['your_picture'], []) shared.gradio['mode'].change(lambda x: custom_state.update({'mode': x}), shared.gradio['mode'], []) - with gr.Accordion("XY Grid", open=False): + # UI for the extension + with gr.Accordion("XY Grid", open=True): # Axis selections and inputs with gr.Row(): @@ -535,8 +393,9 @@ def ui(): y_type.select(set_axis, [x_type, y_type], []).then(fill_axis, y_type, y_input) x_type.change(set_axis, [x_type, y_type], []) y_type.change(set_axis, [x_type, y_type], []) + with gr.Row(): - swap_xy = gr.Button(value='Swap X/Y Axes 🔀') + swap_xy = gr.Button(value='Swap X/Y Axes') with gr.Row(): seed_input = gr.Checkbox(label='Use a constant seed', value=False) use_history = gr.Checkbox(label='Use character\'s chat history', value=False) @@ -548,4 +407,4 @@ def ui(): generate_grid = gr.Button("generate_grid") custom_chat = gr.HTML(value="") - generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat) + generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat) \ No newline at end of file From 9b9f83518889dbb30dde2dfeed5038d6d2218bb9 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 23 Apr 2023 21:13:01 -0300 Subject: [PATCH 11/13] Importing from server now works --- extensions/xy_grid/script.py | 38 +++++------------------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index f68b5767ff..b1dfef6d4a 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -1,16 +1,17 @@ -import os -import json import datetime +import json +import os import random import time +from pathlib import Path import gradio as gr -import modules.shared as shared import pyparsing as pp +import modules.shared as shared from modules.chat import chatbot_wrapper, load_character from modules.html_generator import convert_to_markdown -from pathlib import Path +from server import load_preset_values # Global variables axis_type = {'x': "prompts", 'y': "presets"} @@ -18,35 +19,6 @@ gen_output = [] axis_options = ["prompts", "presets", "characters", "seed", "max_new_tokens", "temperature", "top_p", "top_k", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "no_repeat_ngram_size", "min_length"] -# I had to steal this from server.py because the program freaks out if I try to `import server` -def load_preset_values(preset_menu, state): - generate_params = { - 'do_sample': True, - 'temperature': 1, - 'top_p': 1, - 'typical_p': 1, - 'repetition_penalty': 1, - 'encoder_repetition_penalty': 1, - 'top_k': 50, - 'num_beams': 1, - 'penalty_alpha': 0, - 'min_length': 0, - 'length_penalty': 1, - 'no_repeat_ngram_size': 0, - 'early_stopping': False, - } - with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile: - preset = infile.read() - for i in preset.splitlines(): - i = 
i.rstrip(',').strip().split('=') - if len(i) == 2 and i[0].strip() != 'tokens': - generate_params[i[0].strip()] = eval(i[1].strip()) - generate_params['temperature'] = min(1.99, generate_params['temperature']) - - state.update(generate_params) - custom_state['preset_menu'] = preset_menu - return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] - # Get all of the characters from the character folder def get_characters(): From ef57d58d1395df035a32cb00360c5d249b2e7652 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Sun, 23 Apr 2023 22:52:14 -0400 Subject: [PATCH 12/13] reworked axis parser to use server.py --- extensions/xy_grid/script.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index b1dfef6d4a..f777caf5e4 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -31,12 +31,6 @@ def get_characters(): return ", ".join(['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) + instructors) -# Get all of the instruction following templates from the character folder -def get_instruct(): - paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) - return ", ".join(['None'] + sorted(set((k.stem for k in paths)), key=str.lower)) - - # Get all of the presets from the presets folder def get_presets(): presets = [] @@ -55,8 +49,6 @@ def fill_axis(option): return gr.update(label=option, value=get_presets()) elif option == "characters": return gr.update(label=option, value=get_characters()) - elif option == "instruction template": - return gr.update(label=option, value=get_instruct()) elif option == "prompts": return gr.update(label=option, value=custom_state['textbox']) else: @@ -80,7 +72,9 @@ def parse_axis(axis, value): # PRESETS if axis_type[axis] == "presets": if value.strip() != "": - custom_state = load_preset_values(value.strip(), custom_state)[0] + temp_dict = load_preset_values(value.strip(), custom_state, return_dict=True) + custom_state.update({k: temp_dict[k] for k in temp_dict.keys()}) + custom_state['preset_menu'] = value.strip() else: custom_state = load_preset_values(shared.gradio['preset_menu'].value, custom_state)[0] # CHARACTERS @@ -379,4 +373,4 @@ def ui(): generate_grid = gr.Button("generate_grid") custom_chat = gr.HTML(value="") - generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat) \ No newline at end of file + generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat) From 448aaf7a90c4acb898a3df122b071871e51265c7 Mon Sep 17 00:00:00 2001 From: Clay Shoaf Date: Mon, 24 Apr 2023 11:18:56 -0400 Subject: [PATCH 13/13] added pause/stop buttons, realtime output --- extensions/xy_grid/script.py | 45 ++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/extensions/xy_grid/script.py b/extensions/xy_grid/script.py index f777caf5e4..a00405d4b7 100644 --- a/extensions/xy_grid/script.py +++ b/extensions/xy_grid/script.py @@ -11,13 +11,14 @@ import modules.shared as shared from modules.chat import chatbot_wrapper, load_character from modules.html_generator import convert_to_markdown -from server import load_preset_values +from server import 
load_preset_values, stop_everything_event # Global variables axis_type = {'x': "prompts", 'y': "presets"} custom_state = {} gen_output = [] axis_options = ["prompts", "presets", "characters", "seed", "max_new_tokens", "temperature", "top_p", "top_k", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "no_repeat_ngram_size", "min_length"] +is_paused = False # Get all of the characters from the character folder @@ -116,9 +117,11 @@ def parse_axis(axis, value): # The main function that generates the grid def run(constant_seed, seed_value, use_history, x="", y=""): + global custom_state global gen_output global axis_type + global is_paused # Error handling if axis_type['x'] == axis_type['y']: @@ -221,6 +224,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): output = output + f"
<td><b>{custom_state['name1']}</b><br>{user_output}<br><br><b>{custom_state['name2']}</b><br>{bot_output}"
             else:
                 output = output + f"<td><b>{custom_state['name1']}:</b><br>{user_output}<br><br><b>{custom_state['name2']}:</b><br>
{bot_output}" + yield output # Remove the last outputs, so they don't influence future generations if custom_state['mode'] == 'instruct': @@ -238,6 +242,21 @@ def run(constant_seed, seed_value, use_history, x="", y=""): if shared.history['visible'] == gen_output: shared.history['visible'].clear() + # Check to see if the user stopped or paused the generation + if shared.stop_everything: + shared.history['internal'] = temp_internal.copy() + shared.history['visible'] = temp_visible.copy() + custom_state = temp_custom_state.copy() + return output + while is_paused: + if shared.stop_everything: + shared.history['internal'] = temp_internal.copy() + shared.history['visible'] = temp_visible.copy() + custom_state = temp_custom_state.copy() + return output + + + output = output + "" output = output + "" @@ -259,7 +278,7 @@ def run(constant_seed, seed_value, use_history, x="", y=""): shared.history['visible'] = temp_visible.copy() custom_state = temp_custom_state.copy() - return output + yield output # Necessary for some stuff because gradio @@ -272,12 +291,29 @@ def toggle_visible(var): custom_state['seed'] = -1 return gr.update(visible=var) +# These could be one function if gradio allowed me to pass boolean values with a button click +# There are other workarounds, but they feel just as dirty +def pause_switch(): + global is_paused + is_paused = not is_paused + if is_paused: + return gr.update(value="Resume") + if not is_paused: + return gr.update(value="Pause") + +def unpause(): + global is_paused + is_paused = False + return gr.update(value="Pause") + + # Create the interface for the extension (this runs first) def ui(): global custom_state global axis_type + # Grab all the variable from shared.gradio and put them in the custom_state dictionary custom_state.update({k: v for k, v in zip([key for key in shared.gradio if not isinstance(shared.gradio[key], (gr.Blocks, gr.Button, gr.State))], [shared.gradio[k].value for k in [key for key in shared.gradio] if not isinstance(shared.gradio[k], (gr.Blocks, gr.Button, gr.State))])}) @@ -371,6 +407,11 @@ def ui(): swap_xy.click(swap_axes, [x_type, x_input, y_type, y_input], [x_type, x_input, x_input, y_type, y_input, y_input]) generate_grid = gr.Button("generate_grid") + with gr.Row(): + pause_gen = gr.Button(value='Pause') + stop_gen = gr.Button(value='Stop') + pause_gen.click(pause_switch, [], pause_gen, queue=False) + stop_gen.click(stop_everything_event, queue=False).then(unpause, [], pause_gen) custom_chat = gr.HTML(value="") generate_grid.click(run, [seed_input, seed_value, use_history, x_input, y_input], custom_chat)
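
Note on the pause/stop mechanism in PATCH 13/13: the handler becomes a generator (`yield output` instead of `return output`), so Gradio streams the partially built table while a module-level `is_paused` flag and `shared.stop_everything` are polled between cells. The sketch below is a minimal, self-contained illustration of that pattern only; every name in it (`toggle_pause`, `run_grid`, `stop_requested`) is hypothetical and not part of the extension, and the `time.sleep` call is an assumption the patch itself does not make (its `while is_paused:` loop spins without sleeping).

import time

# Stand-ins for the patch's `is_paused` flag and `shared.stop_everything`.
is_paused = False
stop_requested = False

def toggle_pause():
    # Flip the pause flag and return the label the button should show next.
    global is_paused
    is_paused = not is_paused
    return "Resume" if is_paused else "Pause"

def run_grid(cells):
    # Generator-style handler: yield after every cell so the UI updates in real time.
    output = "<table><tr>"
    for cell in cells:
        output += f"<td>{cell}</td>"
        yield output + "</tr></table>"   # partial table for live display
        while is_paused:                 # paused: idle until resumed or stopped
            if stop_requested:
                return
            time.sleep(0.5)              # avoid busy-waiting while paused
        if stop_requested:               # stopped: keep whatever was generated
            return

Wired into a Blocks UI the way the patch does it, a Pause button would call toggle_pause with queue=False so the click is handled even while run_grid is still streaming.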
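
On the preset handling after PATCH 12/13: `parse_axis` now asks `server.load_preset_values(value, custom_state, return_dict=True)` for a flat dict of generation parameters, merges it into `custom_state`, and records the preset name under `custom_state['preset_menu']`. The sketch below shows the same "parse key = value lines, then update the state dict" idea in isolation; `preset_to_dict` is a hypothetical stand-in, not the server function, and it uses `ast.literal_eval` where the original preset loader uses `eval`.

import ast
from pathlib import Path

def preset_to_dict(preset_name, presets_dir="presets"):
    # Read presets/<name>.txt and turn "key = value" lines into a dict.
    params = {}
    for line in Path(presets_dir, f"{preset_name}.txt").read_text().splitlines():
        line = line.rstrip(',').strip()
        key, sep, value = (part.strip() for part in line.partition('='))
        if sep and key != 'tokens':
            # literal_eval parses numbers/booleans without executing arbitrary code.
            params[key] = ast.literal_eval(value)
    return params

# The merge step from parse_axis, in isolation (preset name is illustrative only):
custom_state = {}
custom_state.update(preset_to_dict("default"))
custom_state['preset_menu'] = "default"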