Skip to content

Commit

Permalink
Add more utility functions (formatters) for creating probes (#8)
Browse files Browse the repository at this point in the history
* add prompting util functions and their formatters for last_letter and is_present

* add formatters for letter from start and letter from end

* add tests for new formatters
fix tolerance error for spelling grader test

* format the files (facepalm)

* fix linting etc
  • Loading branch information
hrdkbhatnagar authored Jul 18, 2024
1 parent 215c124 commit 514fbd9
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 2 deletions.
159 changes: 159 additions & 0 deletions sae_spelling/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,100 @@ def first_letter(
return prefix + first_char


def last_letter(
    word: str,
    prefix: str = " ",
    capitalize: bool = False,
    ignore_leading_space: bool = True,
    ignore_non_alpha_chars: bool = True,
) -> str:
    """
    Return `prefix` followed by the final character of `word`.

    e.g. last_letter("cat") -> " t"

    Args:
        word: the word to inspect.
        prefix: text placed in front of the selected character.
        capitalize: upper-case the selected character.
        ignore_leading_space: run `str.strip()` on the word first (this
            removes whitespace at both ends, not only the leading side).
        ignore_non_alpha_chars: drop every character for which
            `str.isalpha()` is False before selecting.

    Raises:
        IndexError: if no characters remain after stripping/filtering.
    """
    text = word.strip() if ignore_leading_space else word
    # Keep a list (not a str) so the candidates are handled one character at a time.
    candidates = list(filter(str.isalpha, text)) if ignore_non_alpha_chars else list(text)
    final_char = candidates[-1]
    return prefix + (final_char.upper() if capitalize else final_char)


def is_present(
    word: str,
    char_to_check: str,
    prefix: str = " ",
    return_binary: bool = False,
) -> str:
    """
    Report whether `char_to_check` occurs in `word`, as a prefixed string.

    The check is a plain `in` test, so it is case-sensitive.

    e.g. is_present("cat", "t") -> " True"
         is_present("cat", "t", return_binary=True) -> " 1"
         is_present("cat", "z", return_binary=True) -> " 0"
    """
    found = char_to_check in word
    if return_binary:
        # Render the boolean as "1" / "0" instead of "True" / "False".
        return prefix + str(int(found))
    return prefix + str(found)


def letter_from_start(
    word: str,
    index: int,
    prefix: str = " ",
    capitalize: bool = False,
    ignore_leading_space: bool = True,
    ignore_non_alpha_chars: bool = True,
) -> str:
    """
    Return `prefix` plus the character at position `index`, counted from the START.

    Positions are 0-based, exactly like ordinary list indexing.
    e.g. letter_from_start("mobile", 2) -> " b"

    Args:
        word: the word to inspect.
        index: position of the wanted character (0 is the first one).
        prefix: text placed in front of the selected character.
        capitalize: upper-case the selected character.
        ignore_leading_space: run `str.strip()` on the word first (this
            removes whitespace at both ends).
        ignore_non_alpha_chars: drop every character for which
            `str.isalpha()` is False before indexing.

    Raises:
        IndexError: if `index` is outside the remaining characters.
    """
    cleaned = word.strip() if ignore_leading_space else word
    # One pass: keep everything, or only alphabetic characters when filtering is on.
    pool = [ch for ch in cleaned if not ignore_non_alpha_chars or ch.isalpha()]
    picked = pool[index]
    return prefix + (picked.upper() if capitalize else picked)


def letter_from_end(
    word: str,
    index: int,
    prefix: str = " ",
    capitalize: bool = False,
    ignore_leading_space: bool = True,
    ignore_non_alpha_chars: bool = True,
) -> str:
    """
    return the letter of the word at the 'index' position relative to the END, optionally capitalized
    NOTE: counting is 1-based and mirrors Python's negative indexing,
    i.e. index=1 selects "mobile"[-1], which is 'e' and not 'l'
    e.g. letter_from_end("mobile", 2) -> " l"

    Args:
        word: the word to inspect.
        index: 1-based distance from the end (1 is the last character).
        prefix: text placed in front of the selected character.
        capitalize: upper-case the selected character.
        ignore_leading_space: run `str.strip()` on the word first (this
            removes whitespace at both ends).
        ignore_non_alpha_chars: drop every character for which
            `str.isalpha()` is False before indexing.

    Raises:
        ValueError: if `index` is smaller than 1.
        IndexError: if `index` is larger than the number of remaining characters.
    """
    # chars[-0] is chars[0], and a negative index would flip to counting from
    # the start, so anything below 1 would silently return the wrong letter.
    if index < 1:
        raise ValueError(
            f"index must be >= 1 (1 selects the last letter), got {index}"
        )

    if ignore_leading_space:
        word = word.strip()

    chars = list(word)
    if ignore_non_alpha_chars:
        chars = [c for c in chars if c.isalpha()]

    char_at_idx = chars[-index]

    if capitalize:
        char_at_idx = char_at_idx.upper()
    return prefix + char_at_idx


# ----- Formatters -------------------------------
# A Formatter is any callable that takes one string and returns a string.
# The *_formatter factories build one by pre-binding a helper's options with `partial`.
Formatter = Callable[[str], str]


Expand Down Expand Up @@ -101,6 +195,71 @@ def first_letter_formatter(
)


def last_letter_formatter(
    prefix: str = " ",
    capitalize: bool = False,
    ignore_leading_space: bool = True,
    ignore_non_alpha_chars: bool = True,
) -> Formatter:
    """
    Build a Formatter that applies `last_letter` with the given options pre-bound.

    The returned callable only needs the word itself.
    """
    bound_options = {
        "prefix": prefix,
        "capitalize": capitalize,
        "ignore_leading_space": ignore_leading_space,
        "ignore_non_alpha_chars": ignore_non_alpha_chars,
    }
    return partial(last_letter, **bound_options)


def is_present_formatter(
    char_to_check: str,
    return_binary: bool = False,
    prefix: str = " ",
) -> Formatter:
    """
    Build a Formatter that applies `is_present` for a fixed `char_to_check`.

    The returned callable only needs the word itself; `return_binary` and
    `prefix` are forwarded unchanged.
    """
    bound_options = {
        "prefix": prefix,
        "char_to_check": char_to_check,
        "return_binary": return_binary,
    }
    return partial(is_present, **bound_options)


def letter_from_start_formatter(
    index: int,
    prefix: str = " ",
    capitalize: bool = False,
    ignore_leading_space: bool = True,
    ignore_non_alpha_chars: bool = True,
) -> Formatter:
    """
    Build a Formatter that applies `letter_from_start` with the given options pre-bound.

    Args:
        index: 0-based position from the start of the word, forwarded to
            `letter_from_start`.
        prefix: text placed in front of the selected character.
        capitalize: upper-case the selected character.
        ignore_leading_space: forwarded to `letter_from_start`.
        ignore_non_alpha_chars: forwarded to `letter_from_start`.

    Returns:
        A callable that only needs the word itself.
    """
    return partial(
        letter_from_start,
        index=index,
        prefix=prefix,
        capitalize=capitalize,
        ignore_leading_space=ignore_leading_space,
        ignore_non_alpha_chars=ignore_non_alpha_chars,
    )


def letter_from_end_formatter(
    index: int,
    prefix: str = " ",
    capitalize: bool = False,
    ignore_leading_space: bool = True,
    ignore_non_alpha_chars: bool = True,
) -> Formatter:
    """
    Build a Formatter that applies `letter_from_end` with the given options pre-bound.

    Args:
        index: 1-based distance from the end of the word (1 selects the
            last character), forwarded to `letter_from_end`.
        prefix: text placed in front of the selected character.
        capitalize: upper-case the selected character.
        ignore_leading_space: forwarded to `letter_from_end`.
        ignore_non_alpha_chars: forwarded to `letter_from_end`.

    Returns:
        A callable that only needs the word itself.
    """
    return partial(
        letter_from_end,
        index=index,
        prefix=prefix,
        capitalize=capitalize,
        ignore_leading_space=ignore_leading_space,
        ignore_non_alpha_chars=ignore_non_alpha_chars,
    )


# --------------------------------


def create_icl_prompt(
word: str,
examples: list[str],
Expand Down
64 changes: 64 additions & 0 deletions tests/test_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
create_icl_prompt,
first_letter,
first_letter_formatter,
is_present,
last_letter,
letter_from_end,
letter_from_start,
spelling,
)
from sae_spelling.vocab import get_alpha_tokens
Expand Down Expand Up @@ -58,6 +62,66 @@ def test_first_letter_can_respect_non_alphanum_chars():
assert first_letter("1cat", ignore_non_alpha_chars=False) == " 1"


# ---
def test_last_letter_can_capitalize_letter():
    """capitalize=True upper-cases the letter that last_letter returns."""
    answer = last_letter("cat", capitalize=True)
    assert answer == " T"


def test_last_letter_ignores_non_alphanum_chars_and_leading_space_by_default():
    """Default options still give the final letter for words with odd leading characters."""
    expected_by_word = {
        "_cat": " t",
        " cat": " t",
        " CAT": " T",
        "▁cat": " t",
        "1cat": " t",
    }
    for word, expected in expected_by_word.items():
        assert last_letter(word) == expected


def test_last_letter_can_respect_non_alphanum_chars():
    """With ignore_non_alpha_chars=False a trailing digit is returned as-is."""
    cases = [
        (" cat", " t"),
        ("▁cat", " t"),
        ("cat1", " 1"),
    ]
    for word, expected in cases:
        assert last_letter(word, ignore_non_alpha_chars=False) == expected


def test_is_present_can_give_num_binary():
    """return_binary switches the answer between "1" and "True"."""
    as_number = is_present("cat", "a", return_binary=True)
    as_bool = is_present("cat", "a", return_binary=False)
    assert as_number == " 1"
    assert as_bool == " True"


def test_letter_from_start_can_capitalize_letter():
    """capitalize=True upper-cases the letter that letter_from_start returns."""
    answer = letter_from_start("cat", index=1, capitalize=True)
    assert answer == " A"


def test_letter_from_start_ignores_non_alphanum_chars_and_leading_space_by_default():
    """Default options skip a leading underscore, space, '▁' or digit before indexing."""
    expected_by_word = {
        "_cat": " a",
        " cat": " a",
        " CAT": " A",
        "▁cat": " a",
        "1cat": " a",
    }
    for word, expected in expected_by_word.items():
        assert letter_from_start(word, index=1) == expected


def test_letter_from_start_can_respect_non_alphanum_chars():
    """With ignore_non_alpha_chars=False non-letters occupy index positions too."""
    cases = [
        (" cat", 1, " a"),
        ("▁cat", 1, " c"),
        ("cat1", 3, " 1"),
    ]
    for word, position, expected in cases:
        result = letter_from_start(word, index=position, ignore_non_alpha_chars=False)
        assert result == expected


def test_letter_from_end_can_capitalize_letter():
    """capitalize=True upper-cases the letter that letter_from_end returns."""
    answer = letter_from_end("cat", index=1, capitalize=True)
    assert answer == " T"


def test_letter_from_end_ignores_non_alphanum_chars_and_leading_space_by_default():
    """Default options still give the final letter for words with odd leading characters."""
    expected_by_word = {
        "_cat": " t",
        " cat": " t",
        " CAT": " T",
        "▁cat": " t",
        "1cat": " t",
    }
    for word, expected in expected_by_word.items():
        assert letter_from_end(word, index=1) == expected


def test_letter_from_end_can_respect_non_alphanum_chars():
    """With ignore_non_alpha_chars=False a trailing non-letter is returned as-is."""
    cases = [
        (" cat", " t"),
        ("▁cat_", " _"),
        ("cat1", " 1"),
    ]
    for word, expected in cases:
        result = letter_from_end(word, index=1, ignore_non_alpha_chars=False)
        assert result == expected


def test_create_icl_prompt_with_defaults():
prompt = create_icl_prompt("cat", examples=["dog", "bird"], shuffle_examples=False)

Expand Down
6 changes: 4 additions & 2 deletions tests/test_spelling_grader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ def test_spelling_grader_batch_processing_gives_the_same_results_as_individual_p
assert batch_grade.is_correct == individual_grade.is_correct
assert batch_grade.answer == individual_grade.answer
assert batch_grade.prediction == individual_grade.prediction
assert batch_grade.answer_log_prob == approx(individual_grade.answer_log_prob)
assert batch_grade.answer_log_prob == approx(
individual_grade.answer_log_prob, abs=1e-5
)
assert batch_grade.prediction_log_prob == approx(
individual_grade.prediction_log_prob
individual_grade.prediction_log_prob, abs=1e-5
)


Expand Down

0 comments on commit 514fbd9

Please sign in to comment.