Skip to content

Chain-of-thought prompting in BLOOM-7b1 on the Grade School Math Dataset

Notifications You must be signed in to change notification settings

armanbolatov/tinkoff_nlp

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

5 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Chain-of-Thought Prompting in BLOOM-7b1 on the Grade School Math Dataset

Описание

Решение задачи Tinkoff NLP — Провести эксперименты со сравнением CoT и ансамблированного CoT на GSM8K с BLOOM-176B.

Когда я пытался запустить эксперименты с BLOOM-176B в колабе, постоянно вылезало предупреждение ControlFailure('Connect failed. msg=failed to find peers: routing: not found') и не находило пир соединения, а на A100 жаловалась на несостыковку во времени с пир серверами. Через Inference API генерация длилась супер долго и возвращало {'error': 'Model bigscience/bloom-petals time out'}. Поэтому эксперименты я проводил с BLOOM-7b1 на NVIDIA A100 40GB. Надеюсь примете решение 🥲

Установка

git clone https://github.com/armanbolatov/tinkoff_nlp.git
cd tinkoff_nlp
git clone https://github.com/openai/grade-school-math.git
pip install -r requirements.txt

Пример запуска эксперимента

python experiment.py \
    --gpu_id=2 \            # Номер GPU
    --checkpoint="bigscience/bloom-7b1" \ # Путь к модели
    --seed=1337 \           # Сид для воспроизводимости few-shot промптов
    --greedy \              # Использовать greedy алгоритм
    --temperature=0.7 \     # Температура для семплирования
    --batch_size=5 \        # Размер батча, нужно подобрать по видеокарте
    --num_thoughts=5 \      # Количество размышлений
    --num_few_shots=8 \     # Количество few-shot промптов
    --only_equations        # Только уравнения в few-shot промптах

Результаты

Я провел 12 экспериментов исследуя влияние количества few-shot промптов, использования только уравнений, ансамблирования CoT, температуры и количества размышлений.

Также для каждого эксперимента, чтобы минимизировать влияние случайности, я сделал 3 запуска (за исключением двух экспериментов) на разных сидах, и посчитал среднее и стандартное отклонение. Конкретный сид соответствовал конкретным few-shot промптам из тренировочного датасета. В таблицах ниже показаны результаты. Все в процентах.

Эксперименты с 8 few-shot промптами с полным решением (t=0.7, N=20 значит ансабмль CoT с температурой 0.7 и 20-ю размышлениями):

Run Greedy t=0.7, N=20 t=0.7, N=5 t=0.3, N=20
1 2.505 1.367 1.292 2.505
2 2.201 1.08 1.233 2.201
3 2.397 1.103 1.149 2.39
AVG 2.368 1.183 1.225 2.365
STD 0.154 0.159 0.072 0.153

Эксперименты с 4 few-shot промптами с полным решением:

Run Greedy t=0.7, N=20 t=0.7, N=5 t=0.3, N=20
1 2.18 1.244 1.135 2.183
2 2.266 1.011 1.202 2.271
3 2.023 1.2 0.998 2.023
AVG 2.156 1.152 1.112 2.159
STD 0.123 0.124 0.104 0.126

Эксперименты с 8 few-shot промптами только с уравнениями:

Run Greedy t=0.7, N=20 t=0.7, N=5 t=0.3, N=20
1 0.432 0.257 0.21 0.432
2 0.484 0.2 0.196 0.484
3 0.399 - - -
AVG 0.438 0.229 0.203 0.458
STD 0.043 0.04 0.01 0.037

Выводы

  1. Ансамблирование CoT на порядок проигрывает жадному CoT на "маленьких" ( < 8B ) моделях. Лучший результат у жадного CoT с 8 few-shot промптами с полным решением — 2.37 ± 0.154.

  2. Результаты данной модели лучше некоторых моделей со сравнительным числом параметров, а именно LaMDA-8B (CoT: 1.6), GPT-6.7B (CoT: 2.4), но хуже чем PaLM (CoT: 4.1), например.

  3. При увеличении числа размышлений с 5 до 20, точность модели всегда увеличивается.

  4. Ансамбль CoT с температурой 0.3 дает почти тот же результат что и детерминированный CoT.

  5. При уменьшении числа few-shot промптов с 8 до 4-х, точность падает. Также есть достаточно заметная девиация в результатах с разными сидами, что говорит о том, что качество модели сильно зависит от качества few-shot промптов, как и подмечено в статье.

  6. Использовать полное решение вместе с уравнениями в few-shot промптах всегда лучше чем использовать исключительно уравнения.

Идеи для улучшения

  1. Попробовать разное количество few-shot промптов. Слишком много — больше шансов переобучиться, слишком мало — недообучиться. Нужно подобрать золотую середину.

  2. Оценивать сложность задачи через определенные эвристики (например длина решения, количество шагов размышлений), и в зависимости от этого подбирать few-shot промпты. Например, попробовать создать модель оценивающую количество шагов решения по условию и выбрать few-shot промпты в которых количество шагов находится в районе предсказанного моделью.

  3. Поменять структуру запросов. К примеру, вместо предложенных решений подавать последовательно шаги решения с соответствующим уравнением.

  4. Попробовать поиграться с Top P и Top K и отсеивать "слишком слабые" размышления.

  5. Взвешенное голосование: брать несколько предсказаний с разной температурой и домножать каждое на определенный коэффициент зависящий от температуры, чтобы более "рандомные" ответы имели меньший вес, чем более "детерминированные".

About

Chain-of-thought prompting in BLOOM-7b1 on the Grade School Math Dataset

Resources

Stars

Watchers

Forks

Languages