From bdefa24ac6adcc810ff144bd62b08a0d0e6b9c40 Mon Sep 17 00:00:00 2001
From: Drikus Roor
Date: Sat, 15 Apr 2023 14:50:54 +0200
Subject: [PATCH] test: Write unit tests for token_counter

---
 autogpt/token_counter.py    |  2 +-
 tests/test_token_counter.py | 68 +++++++++++++++++++++++++++++++++++++
 2 files changed, 69 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_token_counter.py

diff --git a/autogpt/token_counter.py b/autogpt/token_counter.py
index c12397221f2a..a85a54be0e11 100644
--- a/autogpt/token_counter.py
+++ b/autogpt/token_counter.py
@@ -27,7 +27,7 @@ def count_message_tokens(
         logger.warn("Warning: model not found. Using cl100k_base encoding.")
         encoding = tiktoken.get_encoding("cl100k_base")
     if model == "gpt-3.5-turbo":
-        # !Node: gpt-3.5-turbo may change over time.
+        # !Note: gpt-3.5-turbo may change over time.
         # Returning num tokens assuming gpt-3.5-turbo-0301.")
         return count_message_tokens(messages, model="gpt-3.5-turbo-0301")
     elif model == "gpt-4":
diff --git a/tests/test_token_counter.py b/tests/test_token_counter.py
new file mode 100644
index 000000000000..d13f2ae02a6a
--- /dev/null
+++ b/tests/test_token_counter.py
@@ -0,0 +1,68 @@
+"""Unit tests for count_message_tokens and count_string_tokens."""
+
+import unittest
+
+# Imported for its side effects only (presumably test path setup -- confirm).
+import tests.context  # noqa: F401
+from autogpt.token_counter import count_message_tokens, count_string_tokens
+
+
+class TestTokenCounter(unittest.TestCase):
+    """Token-count checks for chat messages and plain strings."""
+
+    def test_count_message_tokens(self):
+        """Two plain messages are counted with the default model."""
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        self.assertEqual(count_message_tokens(messages), 17)
+
+    def test_count_message_tokens_with_name(self):
+        """A "name" key leaves the default-model total unchanged."""
+        messages = [
+            {"role": "user", "content": "Hello", "name": "John"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        self.assertEqual(count_message_tokens(messages), 17)
+
+    def test_count_message_tokens_empty_input(self):
+        """An empty message list is still counted as 3 tokens."""
+        self.assertEqual(count_message_tokens([]), 3)
+
+    def test_count_message_tokens_invalid_model(self):
+        """An unknown model name raises NotImplementedError."""
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        # count_message_tokens warns and falls back to cl100k_base when the
+        # model is not found, so KeyError is not expected to reach the caller.
+        with self.assertRaises(NotImplementedError):
+            count_message_tokens(messages, model="invalid_model")
+
+    def test_count_message_tokens_gpt_4(self):
+        """The same two messages are counted as 15 tokens for gpt-4-0314."""
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
+
+    def test_count_string_tokens(self):
+        """A short string is counted as 4 tokens for gpt-3.5-turbo-0301."""
+        string = "Hello, world!"
+        self.assertEqual(count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4)
+
+    def test_count_string_tokens_empty_input(self):
+        """An empty string is counted as 0 tokens."""
+        self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
+
+    def test_count_string_tokens_gpt_4(self):
+        """The same short string is counted as 4 tokens for gpt-4-0314."""
+        string = "Hello, world!"
+        self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
+
+
+if __name__ == '__main__':
+    unittest.main()