feat: Support bloom models #3553
Merged
PPL Test
Test script for torch fp32
import argparse

import torch
from transformers import AutoConfig, BloomForCausalLM, BloomTokenizerFast


def calculate_ppl(
    device, model, tokenizer, sentence: str, max_length: int = 100, stride: int = 50
) -> None:
    """Print the running perplexity of `sentence`, one line per scored chunk."""
    sentence_ids = tokenizer.encode(sentence)  # do not add bos_token_id
    print(sentence_ids)
    seq_len = len(sentence_ids)
    nlls = []
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length + stride // 2, seq_len)
        # Stop at the first window that is shorter than the full window size.
        if (end_loc - begin_loc) != (max_length + stride // 2):
            break
        input_ids = sentence_ids[begin_loc:end_loc]
        input_ids = torch.tensor([input_ids])
        input_ids = input_ids.to(device)
        target_ids = input_ids.clone()
        # Only the last `stride` tokens of each window contribute to the loss.
        target_ids[:, :-stride] = -100
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            neg_log_likelihood = outputs.loss
        nlls.append(neg_log_likelihood)
        if end_loc == seq_len:
            break
    # Cumulative NLL over chunks, then perplexity after each chunk.
    ggml_nlls = torch.cumsum(torch.stack(nlls) * stride, dim=0)
    count = torch.arange(stride, len(nlls) * stride + stride, stride)
    chunk_ppls = torch.exp(ggml_nlls / count).cpu().tolist()
    for i, ppl in enumerate(chunk_ppls):
        print("[{}] {}".format(i + 1, ppl))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="")
    args = parser.parse_args()
    device = torch.device("cpu")
    tokenizer = BloomTokenizerFast.from_pretrained(args.model_name_or_path)
    model_config = AutoConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )
    model = BloomForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.float16,
        config=model_config,
        device_map="auto",
    )
    model.to(device).float()  # type: ignore
    model.eval()  # type: ignore
    # The paragraph break is written as an explicit \n so the literal stays valid Python.
    sens = [
        "About six million children are reported to child protection agencies in America each year. About 400,000 of those children are placed in protective custody because of severe neglect or abuse. About 500,000 children are placed into foster care and adoptive placements. Abused and neglected children are all around us. These children are invisible in our community, yet each one of us is directly responsible for their plight. They live under our laws; they go to our schools; they are convicted by our courts; many of them spend lifetimes in our prisons. They have no say in the laws and policies that rule their lives. Just like they had no say in the neglect and abuse that was their childhood. Neglected and abused children make up a great majority of the crime, drugs, and violence we experience in our communities. Over fifty percent of the children in the juvenile justice system have diagnosable mental illness, about thirty percent of children in child protection services are proscribed psychotropic medications, & almost eighty percent of youth aging out of foster care lead dysfunctional lives. Ninety percent of the juveniles in the Juvenile Justice System have come out of the Child Protection System (Minnesota’s Chief Justice, Kathleen Blatz). Over 90 percent of the adults in the Criminal Justice System come out of the Juvenile Justice System. Justice Blatz (and others) call it a prison “feeder” system. The United States is the only nation in the world to build prisons based on failed third grade reading scores or the number of children in Child Protection. Children are not aware of the rightness or wrongness of their own abuse. They do not know that abuse is abnormal, or even that it is wrong. To a five-year-old, no matter how painful and frightening her life is, her life is normal. A sad and lasting fact of child abuse is that children blame themselves for the abuse they receive. \n"
        "How can sex, drugs, and violence be unlearned by a ten year old child whose entire life has been just that? It takes years of therapy to change a child’s perception of an abusive past. It takes a great deal longer for an abused child to develop a healthy view of the world and a positive self-image. There is no book a child can go to, or code they are born with, that explains the abnormality of what is happening to them. Children can’t call their senators, or complain to the authorities (they can’t even tell their parents). Behaviors learned by abused children to stay alive in toxic homes are terribly counter-productive once the child is out of the abusive circumstances and trying to live a normal life. The behaviors developed for staying alive and avoiding pain dominate and thus can become significant detriments to getting along in society. As a matter of fact, for many troubled youth, their explosive responses and pain avoidance behaviors define them as uneducated social misfits with criminal histories."
    ]
    for _, sen in enumerate(sens):
        calculate_ppl(
            device, model, tokenizer, sen, max_length=100, stride=50
        )

Test script for ggml fp16/q4_1
./build/bin/perplexity -m models/bloom-1b7.fp16.gguf \
-p "About six million children are reported to child protection agencies in America each year. About 400,000 of those children are placed in protective custody because of severe neglect or abuse. About 500,000 children are placed into foster care and adoptive placements. Abused and neglected children are all around us. These children are invisible in our community, yet each one of us is directly responsible for their plight. They live under our laws; they go to our schools; they are convicted by our courts; many of them spend lifetimes in our prisons. They have no say in the laws and policies that rule their lives. Just like they had no say in the neglect and abuse that was their childhood. Neglected and abused children make up a great majority of the crime, drugs, and violence we experience in our communities. Over fifty percent of the children in the juvenile justice system have diagnosable mental illness, about thirty percent of children in child protection services are proscribed psychotropic medications, & almost eighty percent of youth aging out of foster care lead dysfunctional lives. Ninety percent of the juveniles in the Juvenile Justice System have come out of the Child Protection System (Minnesota’s Chief Justice, Kathleen Blatz). Over 90 percent of the adults in the Criminal Justice System come out of the Juvenile Justice System. Justice Blatz (and others) call it a prison “feeder” system. The United States is the only nation in the world to build prisons based on failed third grade reading scores or the number of children in Child Protection. Children are not aware of the rightness or wrongness of their own abuse. They do not know that abuse is abnormal, or even that it is wrong. To a five-year-old, no matter how painful and frightening her life is, her life is normal. A sad and lasting fact of child abuse is that children blame themselves for the abuse they receive. 
How can sex, drugs, and violence be unlearned by a ten year old child whose entire life has been just that? It takes years of therapy to change a child’s perception of an abusive past. It takes a great deal longer for an abused child to develop a healthy view of the world and a positive self-image. There is no book a child can go to, or code they are born with, that explains the abnormality of what is happening to them. Children can’t call their senators, or complain to the authorities (they can’t even tell their parents). Behaviors learned by abused children to stay alive in toxic homes are terribly counter-productive once the child is out of the abusive circumstances and trying to live a normal life. The behaviors developed for staying alive and avoiding pain dominate and thus can become significant detriments to getting along in society. As a matter of fact, for many troubled youth, their explosive responses and pain avoidance behaviors define them as uneducated social misfits with criminal histories." \
--ppl-stride 50 -c 100 -b 512 -s 2023
Results
PPL results look good to me, I think this PR is ready for a final review :)), @ggerganov
Nice job. This still lacks tensor offloading for GPU support, but we can fix this later.
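For readers comparing numbers, the `[i] ppl` lines printed by the torch script are a running perplexity: each chunk's mean NLL is scaled back by `stride`, accumulated, and divided by the number of tokens scored so far. A minimal sketch of that arithmetic in plain Python follows. The helper name `running_ppl` and the sample NLL values are hypothetical and are not measured results from this PR.

```python
import math


def running_ppl(chunk_nlls, stride):
    """Return the cumulative perplexity after each chunk.

    chunk_nlls: mean negative log-likelihood per scored token, one value per chunk
    stride: number of scored tokens in every chunk
    """
    ppls = []
    total_nll = 0.0
    for i, nll in enumerate(chunk_nlls, start=1):
        total_nll += nll * stride  # undo the per-chunk mean
        ppls.append(math.exp(total_nll / (i * stride)))
    return ppls


# Hypothetical NLL values, for illustration only.
for i, ppl in enumerate(running_ppl([2.0, 3.0], stride=50), start=1):
    print("[{}] {}".format(i, ppl))
```

Because every chunk scores the same number of tokens, the running value is simply the exponential of the average of the chunk NLLs seen so far.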
ggerganov added the labels "model" (Model specific) and "need feedback" (Testing and feedback with results are needed) on Oct 9, 2023
Tested on M2 Ultra using Metal - seems to work as expected:
./main -m ./models/bloom-1b/ggml-model-f16.gguf -p "I believe the meaning of life is" --ignore-eos -n 64 -t 4 -ngl 1 -s 1
llama_new_context_with_model: compute buffer total size = 500.13 MB
llama_new_context_with_model: max tensor size = 980.00 MB
ggml_metal_add_buffer: allocated 'data ' buffer, size = 4279.47 MB, ( 4280.09 / 147456.00)
ggml_metal_add_buffer: allocated 'kv ' buffer, size = 98.00 MB, ( 4378.09 / 147456.00)
ggml_metal_add_buffer: allocated 'alloc ' buffer, size = 494.02 MB, ( 4872.11 / 147456.00)
system_info: n_threads = 4 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 |
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 64, n_keep = 0
I believe the meaning of life is determined not by an individual's physical, spiritual or mental well-being but rather their place in a more meaningful context.
The term holistic wellbeing was first coined to describe the concept that people should be healthy and happy as individuals without being forced into health care programs (Barnett & Jones, 2006) . To achieve this
llama_print_timings: load time = 216.92 ms
llama_print_timings: sample time = 330.09 ms / 64 runs ( 5.16 ms per token, 193.89 tokens per second)
llama_print_timings: prompt eval time = 20.13 ms / 7 tokens ( 2.88 ms per token, 347.81 tokens per second)
llama_print_timings: eval time = 609.28 ms / 63 runs ( 9.67 ms per token, 103.40 tokens per second)
llama_print_timings: total time = 1002.27 ms
//////////////////
./main -m ./models/bloom-1b/ggml-model-q4_0.gguf -p "I believe the meaning of life is" --ignore-eos -n 64 -t 4 -ngl 1 -s 1
llama_new_context_with_model: compute buffer total size = 500.13 MB
llama_new_context_with_model: max tensor size = 401.95 MB
ggml_metal_add_buffer: allocated 'data ' buffer, size = 1341.05 MB, ( 1341.67 / 147456.00)
ggml_metal_add_buffer: allocated 'kv ' buffer, size = 98.00 MB, ( 1439.67 / 147456.00)
ggml_metal_add_buffer: allocated 'alloc ' buffer, size = 494.02 MB, ( 1933.69 / 147456.00)
system_info: n_threads = 4 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 |
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 64, n_keep = 0
I believe the meaning of life is finding the right partner. And that you just don't know how to find it when you're young or mature," she said.
"You really need a mentor who will give you guidance and direction in your relationships - whether it's with friends, family, partners, children.
"My advice would be to trust yourself enough so not to let
llama_print_timings: load time = 167.38 ms
llama_print_timings: sample time = 319.66 ms / 64 runs ( 4.99 ms per token, 200.21 tokens per second)
llama_print_timings: prompt eval time = 21.48 ms / 7 tokens ( 3.07 ms per token, 325.84 tokens per second)
llama_print_timings: eval time = 402.63 ms / 63 runs ( 6.39 ms per token, 156.47 tokens per second)
llama_print_timings: total time = 786.13 ms
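The tokens-per-second figures in the two logs above follow directly from the elapsed time and the run count on each `llama_print_timings` line. The sketch below recomputes them, assuming the line format is exactly as shown in these logs. The helper name `tokens_per_second` and the regular expression are illustrative and not part of llama.cpp.

```python
import re

# Matches lines such as:
#   llama_print_timings: eval time = 609.28 ms / 63 runs (...)
TIMING_RE = re.compile(
    r"llama_print_timings:\s*(?P<name>[a-z ]+?) time\s*=\s*"
    r"(?P<ms>[\d.]+) ms\s*/\s*(?P<n>\d+) (?:runs|tokens)"
)


def tokens_per_second(line):
    """Return (stage name, tokens per second) for one timing line, or None."""
    m = TIMING_RE.search(line)
    if m is None:
        return None  # e.g. "load time" and "total time" carry no count
    seconds = float(m.group("ms")) / 1000.0
    return m.group("name"), int(m.group("n")) / seconds


# Eval lines copied from the f16 and q4_0 runs above.
f16 = "llama_print_timings: eval time = 609.28 ms / 63 runs ( 9.67 ms per token, 103.40 tokens per second)"
q4_0 = "llama_print_timings: eval time = 402.63 ms / 63 runs ( 6.39 ms per token, 156.47 tokens per second)"

for label, line in (("f16", f16), ("q4_0", q4_0)):
    name, tps = tokens_per_second(line)
    print(f"{label} {name}: {tps:.2f} tokens/s")
```

Recomputing the rate this way makes it easy to compare the f16 and q4_0 runs from saved logs without rerunning the model.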
ggerganov approved these changes on Oct 10, 2023
joelkuiper added a commit to vortext/llama.cpp that referenced this pull request on Oct 12, 2023
This is a follow-up PR; please see ggerganov/ggml#543.
Test Script
./build/bin/main -m models/bloom-1b7.fp16.gguf \
  -p "Building a website can be done in 10 simple steps:\nStep 1:" \
  -n 100 -e --temp 1.0 --top-k 1 --top-p 1.0 \
  --repeat-last-n 0 -s 2023
Tested Models
https://huggingface.co/bigscience/bloom-1b7 (fp16, q4_1)
https://huggingface.co/bigscience/bloom-3b (fp16, q4_1)
https://huggingface.co/bigscience/bloom-7b1 (fp16, q4_1)
https://huggingface.co/Langboat/bloom-1b4-zh (fp16, q4_1)
TODO