Merge pull request #1471 from stanfordnlp/assertions_patches
Assertions patches
okhat authored Sep 9, 2024
2 parents ec10028 + b7435c1 commit 2dbc249
Showing 7 changed files with 77 additions and 89 deletions.
4 changes: 2 additions & 2 deletions docs/docs/cheatsheet.md
@@ -422,9 +422,9 @@ Other custom configurations are similar to customizing the `dspy.BootstrapFewShot` optimizer.

### Including `dspy.Assert` and `dspy.Suggest` statements
```python
- dspy.Assert(your_validation_fn(model_outputs), "your feedback message", target_module="YourDSPyModuleSignature")
+ dspy.Assert(your_validation_fn(model_outputs), "your feedback message", target_module="YourDSPyModule")

- dspy.Suggest(your_validation_fn(model_outputs), "your feedback message", target_module="YourDSPyModuleSignature")
+ dspy.Suggest(your_validation_fn(model_outputs), "your feedback message", target_module="YourDSPyModule")
```

### Activating DSPy Program with Assertions
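The edit above tracks the patched backtracking semantics: `target_module` now identifies the predictor instance inside your module rather than its signature. A minimal sketch of the intended usage, assuming a hypothetical `ConciseQA` module and `is_concise` helper (neither is part of this commit):

```python
import dspy

def is_concise(answer: str) -> bool:
    # Hypothetical validation helper: pass when the answer is short.
    return len(answer.split()) <= 20

class ConciseQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        pred = self.generate_answer(question=question)
        # Target the predictor instance (self.generate_answer), not its
        # signature, so backtracking can locate it in the trace by identity.
        dspy.Suggest(is_concise(pred.answer),
                     "Keep the answer under 20 words.",
                     target_module=self.generate_answer)
        return pred
```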
2 changes: 1 addition & 1 deletion dspy/predict/retry.py
@@ -20,7 +20,7 @@ def _create_new_signature(self, signature):
actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":"
signature = signature.append(f"past_{key}", dspy.InputField(
    prefix="Previous " + actual_prefix,
-   desc=f"past {actual_prefix} with errors",
+   desc=f"past {actual_prefix[:-1]} with errors",
    format=value.json_schema_extra.get("format"),
))

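The one-character fix above stops the trailing colon of `actual_prefix` from leaking into the generated field description. A quick illustration of the string arithmetic, using a hypothetical prefix value:

```python
# Rebuild the prefix the way _create_new_signature does, from a field
# prefix such as "Answer: the final answer" (hypothetical value).
actual_prefix = "Answer: the final answer".split(":")[0] + ":"  # "Answer:"

print(f"past {actual_prefix} with errors")       # before: "past Answer: with errors"
print(f"past {actual_prefix[:-1]} with errors")  # after:  "past Answer with errors"
```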
8 changes: 6 additions & 2 deletions dspy/primitives/assertions.py
@@ -240,7 +240,7 @@ def wrapper(*args, **kwargs):
for i in range(len(dsp.settings.trace) - 1, -1, -1):
    trace_element = dsp.settings.trace[i]
    mod = trace_element[0]
-   if mod.signature == error_target_module:
+   if mod == error_target_module:
        error_state = e.state[i]
        dspy.settings.backtrack_to = mod
        break
@@ -257,7 +257,11 @@ def wrapper(*args, **kwargs):
):
    dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to].append(error_msg)

-   output_fields = error_state[0].new_signature.output_fields
+   # use `new_signature` if available (CoT)
+   if hasattr(error_state[0], 'new_signature'):
+       output_fields = error_state[0].new_signature.output_fields
+   else:
+       output_fields = error_state[0].signature.output_fields
    past_outputs = {}
    for field_name in output_fields.keys():
        past_outputs[field_name] = getattr(
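The second hunk above makes the backtracking path work for predictors that never acquired a `new_signature`: only `dspy.ChainOfThought` extends the signature (adding the rationale field), while plain `dspy.Predict` exposes only `signature`. The same guard as a standalone helper, a sketch rather than code from this commit:

```python
def get_output_fields(predictor):
    # ChainOfThought predictors carry an extended `new_signature` (with the
    # injected rationale field); plain Predict modules only have `signature`.
    if hasattr(predictor, "new_signature"):
        return predictor.new_signature.output_fields
    return predictor.signature.output_fields
```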
80 changes: 44 additions & 36 deletions examples/longformqa/longformqa_assertions.ipynb
@@ -32,7 +32,7 @@
"metadata": {},
"source": [
"### 0] Setting Up\n",
"Let's begin by setting things up. The snippet below will retrieve the cached requests for the task."
"Let's begin by setting things up."
]
},
{
@@ -41,28 +41,16 @@
"metadata": {},
"outputs": [],
"source": [
"!git clone https://huggingface.co/arnavs11/DSPy_LongFormQA\n",
"%cd DSPy_LongFormQA\n",
"!git checkout master\n",
"%cd ..\n",
"import os\n",
"repo_clone_path = '/content/DSPy_LongFormQA'\n",
"\n",
"# Check if '/content' is writable\n",
"if not os.access('/content', os.W_OK):\n",
" # If '/content' is not writable, choose an alternative directory\n",
" # Example: using a directory relative to the current working directory\n",
" repo_clone_path = os.path.join(os.getcwd(), 'DSPy_LongFormQA')\n",
"\n",
"# Set up the cache for this notebook\n",
"os.environ[\"DSP_NOTEBOOK_CACHEDIR\"] = repo_clone_path"
"import openai\n",
"openai.api_key = os.getenv('OPENAI_API_KEY')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will also install **DSPy** if it's not there already."
"We will install **DSPy** if it's not there already."
]
},
{
@@ -131,7 +119,7 @@
"source": [
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')\n",
"dspy.settings.configure(rm=colbertv2_wiki17_abstracts)\n",
"turbo = dspy.OpenAI(model='gpt-3.5-turbo-0613', max_tokens=500)\n",
"turbo = dspy.OpenAI(model='gpt-4o-mini', max_tokens=500)\n",
"dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)"
]
},
@@ -150,7 +138,7 @@
"metadata": {},
"outputs": [],
"source": [
"dataset = HotPotQA(train_seed=1, train_size=300, eval_seed=2023, dev_size=300, test_size=0, keep_details=True)\n",
"dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0, keep_details=True)\n",
"trainset = [x.with_inputs('question') for x in dataset.train]\n",
"devset = [x.with_inputs('question') for x in dataset.dev]"
]
@@ -415,11 +403,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = devset[28].question\n",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: What is the name of this region of Italy, referring to the medieval March of Ancona and nearby marches of Camerino and Fermo, where the comune Pollenza is located?\n",
"Predicted Paragraph: The region of Italy that refers to the medieval March of Ancona and nearby marches of Camerino and Fermo is called Marche. This name is derived from the plural form of \"marca,\" which originally denoted the frontier territories established during the Middle Ages, particularly the March of Ancona (Marche) (1). Today, Marche is recognized not only for its historical significance but also for its rich shoemaking tradition, producing some of the finest Italian footwear (1). Within this region lies the comune of Pollenza, located approximately 40 km southwest of Ancona, further illustrating the geographical and cultural significance of Marche (3).\n",
"Citation Faithfulness: False\n"
]
}
],
"source": [
"question = devset[15].question\n",
"pred = longformqa(question)\n",
"citation_faithfulness_score, _ = citation_faithfulness(None, pred, None)\n",
"\n",
@@ -432,9 +430,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the generated paragraph does not properly include citations as intended. In some instances, it follows an incorrect citation format by not keeping the referenced source at before the period at the end of a sentence. In other instances, it does not provide citations for every 1-2 sentences as expected. \n",
"We can see that the generated paragraph does not properly include citations as intended in the format of \"[source]\". \n",
"\n",
"Additionally, we see that not all included citations are faithful to their preceding text (which can be a result of cited references in an incorrect format)."
"Additionally, we see that not all included citations are faithful to their preceding text."
]
},
{
@@ -482,12 +480,12 @@
" context = deduplicate(context + passages)\n",
" pred = self.generate_cited_paragraph(context=context, question=question)\n",
" pred = dspy.Prediction(context=context, paragraph=pred.paragraph)\n",
" dspy.Suggest(citations_check(pred.paragraph), \"Make sure every 1-2 sentences has citations. If any 1-2 sentences lack citations, add them in 'text... [x].' format.\", target_module=GenerateCitedParagraph)\n",
" dspy.Suggest(citations_check(pred.paragraph), \"Make sure every 1-2 sentences has citations. If any 1-2 sentences lack citations, add them in 'text... [x].' format.\", target_module=self.generate_cited_paragraph)\n",
" _, unfaithful_outputs = citation_faithfulness(None, pred, None)\n",
" if unfaithful_outputs:\n",
" unfaithful_pairs = [(output['text'], output['context']) for output in unfaithful_outputs]\n",
" for _, context in unfaithful_pairs:\n",
" dspy.Suggest(len(unfaithful_pairs) == 0, f\"Make sure your output is based on the following context: '{context}'.\", target_module=GenerateCitedParagraph)\n",
" dspy.Suggest(len(unfaithful_pairs) == 0, f\"Make sure your output is based on the following context: '{context}'.\", target_module=self.generate_cited_paragraph)\n",
" else:\n",
" return pred\n",
" return pred"
@@ -534,11 +532,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = devset[28].question\n",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: What is the name of this region of Italy, referring to the medieval March of Ancona and nearby marches of Camerino and Fermo, where the comune Pollenza is located?\n",
"Predicted Paragraph: The region of Italy that refers to the medieval March of Ancona and nearby marches of Camerino and Fermo is called Marche [1]. This name is derived from the plural form of \"marca,\" which originally denoted the frontier territories established during the Middle Ages, particularly the March of Ancona [2]. Today, Marche is recognized not only for its historical significance but also for its rich shoemaking tradition, producing some of the finest Italian footwear [1]. Within this region lies the comune of Pollenza, located approximately 40 km southwest of Ancona and about 9 km southwest of Macerata, further illustrating the geographical and cultural significance of Marche [3].\n",
"Citation Faithfulness: True\n"
]
}
],
"source": [
"question = devset[15].question\n",
"pred = longformqa_with_assertions(question)\n",
"citation_faithfulness_score, _ = citation_faithfulness(None, pred, None)\n",
"\n",
@@ -551,7 +559,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We now see that both computational constraints are indeed met. Every 1-2 sentences includes a citation and from our citation_faithfulness check, we see that each reference is also faithful to its preceding text. "
"We now see that both computational constraints are indeed met. Every 1-2 sentences includes a citation and from our `citation_faithfulness` check, we see that each reference is also faithful to its preceding text. "
]
},
{
Expand All @@ -575,7 +583,7 @@
"source": [
"longformqa = LongFormQA()\n",
"teleprompter = BootstrapFewShotWithRandomSearch(metric = answer_correctness, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
"cited_longformqa = teleprompter.compile(student = longformqa, teacher = longformqa, trainset=trainset, valset=devset[:100])\n",
"cited_longformqa = teleprompter.compile(student = longformqa, teacher = longformqa, trainset=trainset, valset=devset[:25])\n",
"evaluate(cited_longformqa)"
]
},
Expand All @@ -596,7 +604,7 @@
"source": [
"longformqa = LongFormQA()\n",
"teleprompter = BootstrapFewShotWithRandomSearch(metric = answer_correctness, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
"cited_longformqa_teacher = teleprompter.compile(student=longformqa, teacher = assert_transform_module(LongFormQAWithAssertions().map_named_predictors(Retry), backtrack_handler), trainset=trainset, valset=devset[:100])\n",
"cited_longformqa_teacher = teleprompter.compile(student=longformqa, teacher = assert_transform_module(LongFormQAWithAssertions().map_named_predictors(Retry), backtrack_handler), trainset=trainset, valset=devset[:25])\n",
"evaluate(cited_longformqa_teacher)"
]
},
Expand All @@ -615,7 +623,7 @@
"source": [
"longformqa = LongFormQA()\n",
"teleprompter = BootstrapFewShotWithRandomSearch(metric = answer_correctness, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
"cited_longformqa_student_teacher = teleprompter.compile(student=assert_transform_module(LongFormQAWithAssertions().map_named_predictors(Retry), backtrack_handler), teacher = assert_transform_module(LongFormQAWithAssertions().map_named_predictors(Retry), backtrack_handler), trainset=trainset, valset=devset[:100])\n",
"cited_longformqa_student_teacher = teleprompter.compile(student=assert_transform_module(LongFormQAWithAssertions().map_named_predictors(Retry), backtrack_handler), teacher = assert_transform_module(LongFormQAWithAssertions().map_named_predictors(Retry), backtrack_handler), trainset=trainset, valset=devset[:25])\n",
"evaluate(cited_longformqa_student_teacher)"
]
}
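The unchanged compile calls in this notebook show the activation pattern the patches exercise: the assertion-bearing module is wrapped so a failed `dspy.Suggest` backtracks to its `target_module` with feedback. In isolation it looks roughly like this (names come from the notebook; the import paths follow the files touched by this commit):

```python
from dspy.predict.retry import Retry
from dspy.primitives.assertions import assert_transform_module, backtrack_handler

# LongFormQAWithAssertions is the module defined earlier in the notebook.
# map_named_predictors(Retry) wraps each predictor so it can be re-run, and
# backtrack_handler replays a failed Suggest's message as feedback on retry.
longformqa_with_assertions = assert_transform_module(
    LongFormQAWithAssertions().map_named_predictors(Retry),
    backtrack_handler,
)
```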
4 changes: 2 additions & 2 deletions examples/qa/hotpot/hotpotqa_with_assertions.ipynb
@@ -35,7 +35,7 @@
"source": [
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')\n",
"dspy.settings.configure(rm=colbertv2_wiki17_abstracts)\n",
"turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=500)\n",
"turbo = dspy.OpenAI(model='gpt-4o-mini', max_tokens=500)\n",
"dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)"
]
},
@@ -45,7 +45,7 @@
"metadata": {},
"outputs": [],
"source": [
"dataset = HotPotQA(train_seed=1, train_size=300, eval_seed=2023, dev_size=300, test_size=0)\n",
"dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0, keep_details=True)\n",
"trainset = [x.with_inputs('question') for x in dataset.train]\n",
"devset = [x.with_inputs('question') for x in dataset.dev]"
]
32 changes: 10 additions & 22 deletions examples/quiz/quiz_assertions.ipynb
@@ -30,21 +30,9 @@
"metadata": {},
"outputs": [],
"source": [
"!git clone https://huggingface.co/arnavs11/DSPy_QuizGen_Cache\n",
"%cd DSPy_QuizGen_Cache/\n",
"!git checkout master\n",
"%cd ..\n",
"import os\n",
"repo_clone_path = '/content/DSPy_QuizGen_Cache'\n",
"\n",
"# Check if '/content' is writable\n",
"if not os.access('/content', os.W_OK):\n",
" # If '/content' is not writable, choose an alternative directory\n",
" # Example: using a directory relative to the current working directory\n",
" repo_clone_path = os.path.join(os.getcwd(), 'DSPy_QuizGen_Cache')\n",
"\n",
"# Set up the cache for this notebook\n",
"os.environ[\"DSP_NOTEBOOK_CACHEDIR\"] = repo_clone_path"
"import openai\n",
"openai.api_key = os.getenv('OPENAI_API_KEY')"
]
},
{
@@ -95,7 +83,7 @@
"source": [
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')\n",
"dspy.settings.configure(rm=colbertv2_wiki17_abstracts)\n",
"turbo = dspy.OpenAI(model='gpt-3.5-turbo-0613', max_tokens=500)\n",
"turbo = dspy.OpenAI(model='gpt-4o-mini', max_tokens=500)\n",
"dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)"
]
},
Expand All @@ -105,7 +93,7 @@
"metadata": {},
"outputs": [],
"source": [
"dataset = HotPotQA(train_seed=1, train_size=300, eval_seed=2023, dev_size=300, test_size=0, keep_details=True)\n",
"dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0, keep_details=True)\n",
"trainset = [x.with_inputs('question', 'answer') for x in dataset.train]\n",
"devset = [x.with_inputs('question', 'answer') for x in dataset.dev]"
]
@@ -341,11 +329,11 @@
"\n",
" def forward(self, question, answer):\n",
" choice_string = self.generate_choices(question=question, correct_answer=answer, number_of_choices=number_of_choices).answer_choices\n",
" dspy.Suggest(format_checker(choice_string), \"The format of the answer choices should be in JSON format. Please revise accordingly.\", target_module=GenerateAnswerChoices)\n",
" dspy.Suggest(is_correct_answer_included(answer, choice_string), \"The answer choices do not include the correct answer to the question. Please revise accordingly.\", target_module=GenerateAnswerChoices)\n",
" dspy.Suggest(format_checker(choice_string), \"The format of the answer choices should be in JSON format. Please revise accordingly.\", target_module=self.generate_choices)\n",
" dspy.Suggest(is_correct_answer_included(answer, choice_string), \"The answer choices do not include the correct answer to the question. Please revise accordingly.\", target_module=self.generate_choices)\n",
" plausibility_question = \"Are the distractors in the answer choices plausible and not easily identifiable as incorrect?\"\n",
" plausibility_assessment = dspy.Predict(AssessQuizChoices)(question=question, answer_choices=choice_string, assessment_question=plausibility_question)\n",
" dspy.Suggest(is_plausibility_yes(plausibility_assessment.assessment_answer), \"The answer choices are not plausible distractors or are too easily identifiable as incorrect. Please revise to provide more challenging and plausible distractors.\", target_module=GenerateAnswerChoices)\n",
" dspy.Suggest(is_plausibility_yes(plausibility_assessment.assessment_answer), \"The answer choices are not plausible distractors or are too easily identifiable as incorrect. Please revise to provide more challenging and plausible distractors.\", target_module=self.generate_choices)\n",
" return dspy.Prediction(choices = choice_string)\n",
"\n",
"number_of_choices = '4'\n",
@@ -428,7 +416,7 @@
"outputs": [],
"source": [
"teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
"compiled_quiz_generator = teleprompter.compile(student = quiz_generator, teacher = quiz_generator, trainset=trainset, valset=devset[:100])\n",
"compiled_quiz_generator = teleprompter.compile(student = quiz_generator, teacher = quiz_generator, trainset=trainset, valset=devset[:25])\n",
"\n",
"for metric in metrics:\n",
" evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
@@ -453,7 +441,7 @@
"outputs": [],
"source": [
"teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
"compiled_with_assertions_quiz_generator = teleprompter.compile(student=quiz_generator, teacher = quiz_generator_with_assertions, trainset=trainset, valset=devset[:100])\n",
"compiled_with_assertions_quiz_generator = teleprompter.compile(student=quiz_generator, teacher = quiz_generator_with_assertions, trainset=trainset, valset=devset[:25])\n",
"\n",
"\n",
"for metric in metrics:\n",
@@ -468,7 +456,7 @@
"outputs": [],
"source": [
"teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)\n",
"compiled_quiz_generator_with_assertions = teleprompter.compile(student=quiz_generator_with_assertions, teacher = quiz_generator_with_assertions, trainset=trainset, valset=devset[:100])\n",
"compiled_quiz_generator_with_assertions = teleprompter.compile(student=quiz_generator_with_assertions, teacher = quiz_generator_with_assertions, trainset=trainset, valset=devset[:25])\n",
"\n",
"for metric in metrics:\n",
" evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)\n",
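The quiz suggestions rely on small validation helpers (`format_checker`, `is_correct_answer_included`) that the diff references but does not show. A hedged sketch of what they might look like; the notebook's actual implementations may differ:

```python
import json

def format_checker(choice_string: str) -> bool:
    # Sketch: accept the answer choices only if they parse as JSON.
    try:
        json.loads(choice_string)
        return True
    except json.JSONDecodeError:
        return False

def is_correct_answer_included(correct_answer: str, choice_string: str) -> bool:
    # Sketch: the correct answer must appear among the generated choices.
    return correct_answer in choice_string
```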