Skip to content

Commit

Permalink
test: Write unit tests for token_counter
Browse files Browse the repository at this point in the history
  • Loading branch information
drikusroor committed Apr 15, 2023
1 parent 1073954 commit bdefa24
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
2 changes: 1 addition & 1 deletion autogpt/token_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def count_message_tokens(
logger.warn("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model == "gpt-3.5-turbo":
# !Node: gpt-3.5-turbo may change over time.
# !Note: gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.")
return count_message_tokens(messages, model="gpt-3.5-turbo-0301")
elif model == "gpt-4":
Expand Down
61 changes: 61 additions & 0 deletions tests/test_token_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest
import tests.context

from scripts.token_counter import count_message_tokens, count_string_tokens

class TestTokenCounter(unittest.TestCase):
    """Unit tests for ``count_message_tokens`` and ``count_string_tokens``."""

    def setUp(self):
        """Build the two-message conversation shared by several tests.

        A fresh list is created for every test so that no test can leak
        mutations into another.
        """
        self.messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

    def test_count_message_tokens(self):
        """The default model counts 17 tokens for the shared conversation."""
        self.assertEqual(count_message_tokens(self.messages), 17)

    def test_count_message_tokens_with_name(self):
        """Adding a ``name`` key to a message leaves the expected total at 17."""
        messages = [
            {"role": "user", "content": "Hello", "name": "John"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        self.assertEqual(count_message_tokens(messages), 17)

    def test_count_message_tokens_empty_input(self):
        """An empty message list is expected to count as 3 tokens."""
        self.assertEqual(count_message_tokens([]), 3)

    def test_count_message_tokens_invalid_model(self):
        """An unknown model name is expected to raise ``NotImplementedError``.

        This test used to be defined twice under the same name; the earlier
        copy (expecting ``KeyError``) was silently replaced by this one when
        the class body was executed, so it never ran. Only the definition
        that was actually in effect is kept.
        """
        with self.assertRaises(NotImplementedError):
            count_message_tokens(self.messages, model="invalid_model")

    def test_count_message_tokens_gpt_4(self):
        """With ``gpt-4-0314`` the shared conversation counts 15 tokens."""
        self.assertEqual(
            count_message_tokens(self.messages, model="gpt-4-0314"), 15
        )

    def test_count_string_tokens(self):
        """A short string counts 4 tokens with ``gpt-3.5-turbo-0301``."""
        string = "Hello, world!"
        self.assertEqual(
            count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
        )

    def test_count_string_tokens_empty_input(self):
        """The empty string counts 0 tokens."""
        self.assertEqual(
            count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0
        )

    def test_count_string_tokens_gpt_4(self):
        """The same short string also counts 4 tokens with ``gpt-4-0314``."""
        string = "Hello, world!"
        self.assertEqual(
            count_string_tokens(string, model_name="gpt-4-0314"), 4
        )


if __name__ == '__main__':
    # Allow running this file directly as a script, not only via a test runner.
    unittest.main()

0 comments on commit bdefa24

Please sign in to comment.