From ec3bac61cbb171ada184e40bea1c8089fcee0ce6 Mon Sep 17 00:00:00 2001 From: gohai Date: Fri, 28 Jul 2023 12:25:42 +0800 Subject: [PATCH] Add Emscripten interface to llama2.c This is based on the initial prototype by @ggerganov: https://github.com/ggerganov/llama2.c/tree/web --- Makefile | 39 +++++ README.md | 258 +++++++++++++------------------ run.c | 219 +++++++++++++++++++++----- web/dist/basic.html | 39 +++++ web/dist/manual.html | 62 ++++++++ web/package.json | 25 +++ web/src/interface.js | 278 +++++++++++++++++++++++++++++++++ web/utils/callcallback.js | 40 +++++ web/utils/handleArguments.js | 287 +++++++++++++++++++++++++++++++++++ web/webpack.config.js | 46 ++++++ 10 files changed, 1098 insertions(+), 195 deletions(-) create mode 100644 web/dist/basic.html create mode 100644 web/dist/manual.html create mode 100644 web/package.json create mode 100644 web/src/interface.js create mode 100644 web/utils/callcallback.js create mode 100644 web/utils/handleArguments.js create mode 100644 web/webpack.config.js diff --git a/Makefile b/Makefile index 360cd2f..e081038 100644 --- a/Makefile +++ b/Makefile @@ -45,6 +45,45 @@ rungnu: runompgnu: $(CC) -Ofast -fopenmp -std=gnu11 run.c -lm -o run +# includes model & tokenizer +.PHONY: emscripten +emscripten: run.c + emcc -O3 run.c \ + -o web/src/llama2.js \ + -s EXPORTED_FUNCTIONS='["_main", "_main_loop", "_malloc", "_free", "_register_callback", "_set_parameters", "_generate", "_manual_start", "_manual_next", "_get_vocab", "_get_vocab_size"]' \ + -s EXPORTED_RUNTIME_METHODS='["ccall", "addFunction", "UTF8ToString"]' \ + -s ALLOW_MEMORY_GROWTH=1 \ + -s ALLOW_TABLE_GROWTH=1 \ + -s MODULARIZE \ + -s EXPORT_NAME='Llama2' \ + --preload-file model.bin \ + --preload-file tokenizer.bin + +# includes tokenizer only, model loaded from URL +.PHONY: emscripten-small +emscripten-small: run.c + emcc -O3 run.c \ + -o web/src/llama2.js \ + -s EXPORTED_FUNCTIONS='["_main", "_main_loop", "_malloc", "_free", "_register_callback", "_set_parameters", "_generate", "_manual_start", "_manual_next", "_get_vocab", "_get_vocab_size"]' \ + -s EXPORTED_RUNTIME_METHODS='["ccall", "addFunction", "UTF8ToString"]' \ + -s ALLOW_MEMORY_GROWTH=1 \ + -s ALLOW_TABLE_GROWTH=1 \ + -s MODULARIZE \ + -s EXPORT_NAME='Llama2' \ + --preload-file tokenizer.bin + +# model & tokenizer loaded from URL +.PHONY: emscripten-min +emscripten-min: run.c + emcc -O3 run.c \ + -o web/src/llama2.js \ + -s EXPORTED_FUNCTIONS='["_main", "_main_loop", "_malloc", "_free", "_register_callback", "_set_parameters", "_generate", "_manual_start", "_manual_next", "_get_vocab", "_get_vocab_size"]' \ + -s EXPORTED_RUNTIME_METHODS='["ccall", "addFunction", "UTF8ToString"]' \ + -s ALLOW_MEMORY_GROWTH=1 \ + -s ALLOW_TABLE_GROWTH=1 \ + -s MODULARIZE \ + -s EXPORT_NAME='Llama2' + .PHONY: clean clean: rm -f run diff --git a/README.md b/README.md index 0e801c5..e395dd0 100644 --- a/README.md +++ b/README.md @@ -1,223 +1,175 @@ -## llama2.c +## llama2.c Emscripten -

-<p align="center">
-  <img src="assets/llama_cute.jpg" width="300" height="300" alt="Cute Llama">
-</p>

+This is an Emscripten (JavaScript) port of [@karpathy](https://github.com/karpathy)'s [llama2.c](https://github.com/karpathy/llama2.c). This was initially accomplished by [@ggerganov](https://github.com/ggerganov) (see PR [#12](https://github.com/karpathy/llama2.c/pull/12)). This repository attempts to build this out some more and stay current with upstream llama2.c. -With the code in this repo you can train the Llama 2 LLM architecture from scratch in PyTorch, then export the weights to a binary file, and load that into one ~simple 500-line C file ([run.c](run.c)) that inferences the model. Alternatively, you can load, finetune, and inference Meta's Llama 2 (but this is still being actively fleshed out). Hence, this repo is a "fullstack" train + inference solution for Llama 2 LLM, with a focus on minimalism and simplicity. You might think that you need many billion parameter LLMs to do anything useful, but in fact very small LLMs can have surprisingly strong performance if you make the domain narrow enough. I recommend looking at the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) paper for inspiration. +See the [llama2.c README](https://github.com/karpathy/llama2.c/blob/master/README.md) for more information. -Please note that this started recently as just a fun weekend project: I took my earlier [nanoGPT](https://github.com/karpathy/nanoGPT), tuned it to implement the Llama-2 architecture instead of GPT-2, and the meat of it was writing the C inference engine in [run.c](run.c). So the project is young and moving quickly. Hat tip to the awesome [llama.cpp](https://github.com/ggerganov/llama.cpp) for inspiring this project. I wanted something super minimal so I chose to hard-code the Llama 2 architecture, stick to fp32, and just roll one inference file of pure C with no dependencies. -## feel the magic +### Features -Let's just run a baby Llama 2 model in C. You need a model checkpoint. Download this 15M parameter model I trained on the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset (~60MB download): +* Model and tokenizer can be optionally loaded from a URL +* Works via Promise (async/await), or event, or callback +* Probabilities are exposed to JavaScript +* Ability to manually pick next token +* Optionally stop on BOS or EOS token +* Simple output word tokenization -```bash -wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin -``` +### Building -Compile and run the C code: +One of: -```bash -make run -./run stories15M.bin ``` - -You'll see the text stream a sample. On my M1 MacBook Air this runs at ~110 tokens/s. See [performance](#performance) or the Makefile for compile flags that can significantly speed this up. We can also try a bit bigger 42M parameter model: - -```bash -wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin -./run stories42M.bin +make emscripten [requires a model.bin, model+tokenizer included in build artifact] +make emscripten-small [model to be loaded from URL, tokenizer included in build artifact] +make emscripten-min [model+tokenizer to be loaded from URL] ``` -This still runs at interactive rates and samples more coherent and diverse stories: - -> Once upon a time, there was a little girl named Lily. She loved playing with her toys on top of her bed. One day, she decided to have a tea party with her stuffed animals. She poured some tea into a tiny teapot and put it on top of the teapot. 
Suddenly, her little brother Max came into the room and wanted to join the tea party too. Lily didn't want to share her tea and she told Max to go away. Max started to cry and Lily felt bad. She decided to yield her tea party to Max and they both shared the teapot. But then, something unexpected happened. The teapot started to shake and wiggle. Lily and Max were scared and didn't know what to do. Suddenly, the teapot started to fly towards the ceiling and landed on the top of the bed. Lily and Max were amazed and they hugged each other. They realized that sharing was much more fun than being selfish. From that day on, they always shared their tea parties and toys. - -You can also prompt the model with a prefix (sadly, because this is currently done via positional arguments, you also have to specify temperature 1.0 and 256 steps, before you enter the prompt): +Followed by: -```bash -./run stories42M.bin 1.0 256 "One day, Lily met a Shoggoth" +``` +cd web +npm install +npm run build ``` -> One day, Lily met a Shoggoth. He was very shy, but was also very generous. Lily said “Hello Shoggy! Can I be your friend?” Shoggy was happy to have a friend and said “Yes, let’s explore the universe together!” So they set off on a journey to explore the universe. As they travelled, Shoggy was happy to explain to Lily about all the wonderful things in the universe. At the end of the day, Lily and Shoggy had gathered lots of wonderful things from the universe, and they both felt very proud. They promised to explore the universe as one big pair and to never stop being generous to each other. - -There is also an even better 110M param model available, see [models](#models). - -## Meta's Llama 2 models +### API -As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format. -For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export_meta_llama_bin.py` file, e.g. for 7B model: +#### Initialization -```bash -python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin ``` - -The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts, the 13B export currently doesn't work for unknown reaons (accepting PRs for fix). We can run the model as normal: - -```bash -./run llama2_7b.bin +const llama2 = await new LLAMA(); ``` -This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my CPU Linux box in the cloud. (On my MacBook Air M1, currently it's closer to 30 seconds per token if you just build with `make runfast`.) Example output: - -> The purpose of this document is to highlight the state-of-the-art of CoO generation technologies, both recent developments and those in commercial use. The focus is on the technologies with the highest merit to become the dominating processes of the future and therefore to be technologies of interest to S&T ... R&D. As such, CoO generation technologies developed in Russia, Japan and Europe are described in some depth. 
The document starts with an introduction to cobalt oxides as complex products and a short view on cobalt as an essential material. The document continues with the discussion of the available CoO generation processes with respect to energy and capital consumption as well as to environmental damage. - -base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo! - -## models - -For the sake of examples of smaller, from-scratch models, I trained a small model series on TinyStories. All of these trained in a few hours on my training setup (4X A100 40GB GPUs). The 110M took around 24 hours. I am hosting them on huggingface hub [tinyllamas](https://huggingface.co/karpathy/tinyllamas), both in the original PyTorch .pt, and also in the llama2.c format .bin: +You can optionally provide the some of the following options: -| model | dim | n_layers | n_heads | max context length | parameters | val loss | download -| --- | --- | --- | --- | --- | --- | --- | --- | -| OG | 288 | 6 | 6 | 256 | 15M | 1.072 | [stories15M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin) | -| 42M| 512 | 8 | 8 | 1024 | 42M | 0.847 | [stories42M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin) | -| 110M| 768 | 12 | 12 | 1024 | 110M | 0.760 | [stories110M.bin](https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin) | - -You'll notice that the 110M model is equivalent to GPT-1 in size. Alternatively, this is also the smallest model in the GPT-2 series (`GPT-2 small`), except the max context length is only 1024 instead of 2048. The only notable changes from GPT-1/2 architecture is that Llama uses RoPE relatively positional embeddings instead of absolute/learned positional embeddings, a bit more fancy SwiGLU non-linearity in the MLP, RMSNorm instead of LayerNorm, bias=False on all Linear layers, and is optionally multiquery (but this is not yet supported in llama2.c). +``` +const options = { + modelUrl: '', // use a custom model from the provided URL instead + tokenizerUrl: '', // use the tokenizer.bin from the provided URL instead + steps: 0, // how many tokens to generate (default: model's maximum) + temperature: 0.9, // 0.0 = (deterministic) argmax sampling, 1.0 = baseline + stopOnBosOrEos: true // stop when encountering beginning-of-sequence or end-of-sequence token +} -## training +const llama2 = await new LLAMA(options); +``` -Let's see how we can train a baby Llama 2 from scratch using the code in this repo. First let's download and pretokenize some source dataset, e.g. I like [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) so this is the only example currently available in this repo. But it should be very easy to add datasets, see the code. +If you are in a context where you can't use `await`, you can instead also provide a callback function that will be invoked when the model is ready: -```bash -python tinystories.py download -python tinystories.py pretokenize ``` +function modelReady() { + console.log('LLAMA2 is ready'); +} -Then train our model: +let llama2 = new LLAMA(modelReady); -```bash -python train.py +// or: let llama2 = new LLAMA(options, modelReady); ``` -**brief training guide**. See the train.py script for more exotic launches and hyperparameter overrides. 
Here is a brief guide to how to set the parameters. Look at the table at the very end of the [Chinchilla paper](https://arxiv.org/abs/2203.15556) to get a sense of how the Transformer parameters (dim, n_layers, n_heads) grow or shrink together. Extrapolate/interpolate this pattern to get bigger or smaller transformers. Set the max context length however you wish, depending on the problem: this should be the max number of tokens that matter to predict the next token. E.g. Llama 2 uses 2048. Next, you want the _total_ batch size per update (printed by the script as "tokens per iteration will be:") to be somewhere around 100K tokens for medium-sized applications. For tiny applications it could be lower, for large training (e.g. GPTs/LLamas) it is usually ~0.5M, or even more. You get there by first maxing out the batch_size to whatever your system allows (e.g. mine was 16 in a recent run because after that my GPU runs out of memory), and then you want to increase gradient_accumulation_steps to be as high as necessary to reach the total batch size of ~100K. Finally, you want to tune your learning_rate (LR). You want this to be as high as your training allows. Very small networks can get away with a large LR (e.g. 1e-3 or even higher). Large networks need lower LRs. 3e-4 is a safe choice in most medium-sized applications, but can be too low for small networks, so try to increase it! Finally, max_iters is the length of training. Play with different settings. I mostly only ever tune these parameters and leave most of the others unchanged. Here is an example of how I trained the 110M model, which I don't think is anywhere near optimal, but looked sensible to me: dim 768, n_layers 12, n_heads 12 (so size of each head is 768 / 12 = 64 channels), seq len of 1024, batch size 16 (this is the most that fit my A100 40GB GPU), gradient_accumulation_steps = 8 was needed to get total tokens batch size to be 16 batch size * 1024 tokens in sequence * 8 grad_accum = 131,072 tokens per update. Good. Learning rate 4e-4 (probably a little too low). max_iters 200K (probably a bit too high). Dropout 0.1, as that usually helps a bit at medium size. That was it. I ran using Distributed Data Parallel (DDP) on 4 GPUs on my cloud machine, training took ~day or so. +#### Generate output -Totally understand if you want to skip model training, for simple demo just download one of the pretrained models (see [models](#models) section), e.g.: +Use the `generate` method to generate output starting with a given prompt string: -```bash -wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin ``` - -Once we have the model.bin file, we can inference in C. Compile the C code first: - -```bash -make run +const out = await llama2.generate('Today was a great day in'); +console.log(out); ``` -You can now run it simply as +You can also pass a callback function to be executed when the generation has finished: -```bash -./run stories15M.bin ``` +function finishedGenerating(llama2) { + console.log(llama2.out); +} -Watch the tokens stream by, fun! We can also run the PyTorch inference script for a comparison. Download one of the models again from huggingface hub and point the `sample.py` script at it: - -```bash -wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt -P out15M -mv out15M/stories15M.pt out15M/ckpt.pt # sorry the sample script current assumes this directory structure / filename... 
-python sample.py --out_dir=out15M +llama2.generate('Today was a great day in', finishedGenerating); ``` -Which gives the same results. More detailed testing will be done in `test_all.py`. Currently you will need two files to test or sample: both the .bin file, and the .ckpt file inside a directory (see `test_all.py` for details). Sorry this is a bit janky right now, I have to think through running the tests without having to download 200MB of data. But run the tests with pytest: +As the second argument, an object with the following option can optionally be passed: `temperature`, `steps`, `stopOnBosOrEos`. Those will overwrite previous options: -```bash -$ pytest ``` +const out = await llama2.generate('Today was a great day in', { temperature: 0.8 }); -## performance - -There are many ways to potentially speed up this code depending on your system. Have a look at the [Makefile](Makefile), which contains a lot of notes. The `make run` command currently uses the `-O3` optimization by default, i.e.: - -```bash -gcc -O3 -o run run.c -lm +// or: llama2.generate('Today was a great day in', { temperature: 0.8 }, finishedGenerating); ``` --O3 includes optimizations that are expensive in terms of compile time and memory usage. Including vectorization, loop unrolling, and predicting branches. - -To get a much better performance, try to compile with `make runfast`. This turns on the `-Ofast` flag, which includes additional optimizations that may break compliance with the C/IEEE specifications, in addition to `-O3`. See [the GCC docs](https://gcc.gnu.org/onlinedocs/gcc/Optimize-Options.html) for more information. +#### Events -Try `-march=native` to compile the program to use the architecture of the machine you're compiling on rather than a more generic CPU. This may enable additional optimizations and hardware-specific tuning such as improved vector instructions/width. +The `generate` method will emit the following events: -The fastest throughput I saw so far on my MacBook Air (M1) so far is with `make runfast`. +##### token Event -You can also experiment with replacing `gcc` with `clang`. +Emitted at every token generated: -### OpenMP -Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors. -You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). I was not able to get improvements from OpenMP on my MacBook, though. Then you can compile with `make runomp`, which does: - -```bash -clang -Ofast -fopenmp -march=native run.c -lm -o run ``` - -When you run inference make sure to use OpenMP flags to set the number of threads, e.g.: - -```bash -OMP_NUM_THREADS=4 ./run out/model.bin +llama2.on('token', function(llama2) { + console.log('token', llama2.tokens[llama2.tokens.length-1]); + // will print e.g.: + // {index: 3057, str: 'Test', probability: -4.414192199707031} +}); ``` -Depending on your system resources you may want to tweak these hyperparameters and use more threads. But more is not always better, usually this is a bit U shaped. +##### word Event -## platforms +Emitted at every detected word added to the output: -On **Windows**, use `build_msvc.bat` in a Visual Studio Command Prompt to build with msvc, or you can use `make win64` to use mingw compiler toolchain from linux or windows to build the windows target. 
MSVC build will automatically use openmp and max threads appropriate for your CPU unless you set `OMP_NUM_THREADS` env. +``` +llama2.on('word', function(word, llama2) { + console.log('word', word); +}); +``` -On **Centos 7**, **Amazon Linux 2018** use `rungnu` Makefile target: `make rungnu` or `make runompgnu` to use openmp. +##### finish Event -## ack +Emitted at the end of the generation: -I trained the llama2.c storyteller models on a 4X A100 40GB box graciously provided by the excellent [Lambda labs](https://lambdalabs.com/service/gpu-cloud), thank you. +``` +llama2.on('finish', function(llama2) { + console.log('finish', llama2.out); +}); +``` -## discord +#### Manual generation -Figured it's possible to reuse my existing discord channel (that I use for my [zero to hero youtube series](https://karpathy.ai/zero-to-hero.html)), see #llama2c channel on [discord](https://discord.gg/3zy8kqD9Cp), for any quick questions, related discussions, etc. +Rather than receiving the finished output as-is, it's also possible to receive an array with possible continuations at each token, and manually - or programatically - select. The methods to do so are: `manualStart` and `manualNext`. The array of continuations are sorted by probability descending. -## contributing +``` +let continuations = await llama2.manualStart('Today was a great day in'); -A few words on this repo and the kinds of PRs that are likely to be accepted. What is the goal of this repo? Basically I think there will be a lot of interest in training or finetuning custom micro-LLMs (think ~100M - ~1B params, but let's say up to ~10B params) across a large diversity of applications, and deploying them in edge-adjacent environments (think MCUs, phones, web browsers, laptops, etc.). I'd like this repo to be the simplest, smallest, most hackable repo to support this workflow, both training and inference. In particular, this repo is not a complex framework with a 1000 knobs controlling inscrutible code across a nested directory structure of hundreds of files. Instead, I expect most applications will wish to create a fork of this repo and hack it to their specific needs and deployment platforms. +// this will return e.g.: +// [{ index: 278, str: ' the', probability: 0.9308871626853943 }, + { index: 3762, str: ' school', probability: 0.014727797359228134 }, + { index: 6709, str: ' spring', probability: 0.013729158788919449 }, ...] -People who care about deployment efficiency above all else should look at [llama.cpp](https://github.com/ggerganov/llama.cpp). This repo still cares about efficiency, but not at the cost of simplicity, readability or portability. Basically, I expect that a lot of people come to this repo because the training code is 2 readable .py files and the inference code is 500 lines of C. So I'd like this to continue to be a kind of simplest "reference implementation" that can be easily hacked in a separate fork into whatever downstream application people are excited about. It shouldn't be full-featured. It shouldn't take 100 different options or settings. It shouldn't be the most efficient. A few examples: +continuations = await llama2.manualNext(continuations[0]); -- someone re-ordered two loops to improve data locality for a small efficieny win => instant merge. -- someone added the one line "pragma omp parallel for", which allows you to compile with OpenMP and dramatically speed up the code, or acts as just a comment if you don't compile it that way => instant merge. -- bug fixes and touchups etc. => happy to merge +// ... 
+``` -A few examples of PRs are that are not an excellent fit: +`manualNext` also accepts the number of the index instead of the full object. Instead of `await`, a callback function can be used as well: -- adding more than several #ifdefs all over the place in code. If they are localized / few, might be okay. -- adding a lot of code that is very specific to some specific platform (e.g. MCUs, or some special version of linux or processor). These may be a better fit for forks of the project, and I am very happy to maintain a list of these forks in section below. -- adding hundreds of lines of code to run.c that are only active in specific scenarios or platforms. +``` +function onTokens(tokens, llama2) { + console.log('tokens', tokens[0]); + llama2.manualNext(tokens[0]); +} -If your candidate PRs have elements of these it doesn't mean they won't get merged, it just means they will make it into the gray territory. TLDR: I am eager to merge any mostly small, mostly localized, broadly applicable, clean changes that improve the efficiency and portability of the repo, while keep its hackability and readability. I appreciate all PRs seeking to help me improve the project, thank you! <3. +llama2.manualStart('Today was a great day in', onTokens); +``` -## notable forks +Note that this API does not keep track of whether the number of tokens generated stays within the reasonable limits set by the model. -- [llama2.rs](https://github.com/gaxler/llama2.rs) by @gaxler: a Rust port of this project -- [go-llama2](https://github.com/tmc/go-llama2) by @tmc: a Go port of this project -- [llama2.go](https://github.com/nikolaydubina/llama2.go) by @nikolaydubina: a Go port of this project -- [llama2.go](https://github.com/haormj/llama2.go) by @haormj: a Go port of this project -- [llama2.go](https://github.com/saracen/llama2.go) by @saracen: a Go port of this project -- [llama2.c-android](https://github.com/Manuel030/llama2.c-android): by @Manuel030: adds Android binaries of this project -- [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @leloykun: a C++ port of this project -- [llama2.js](https://github.com/epicure/llama2.js) by @epicure: a JavaScript port of this project -- [llama2.zig](https://github.com/cgbur/llama2.zig) by @cgbur: A Zig port of this project +Alternatively, it's also possible to use an event instead, as shown below. -## unsorted todos +##### onTokens event -- support Llama 2 7B Chat model and tune run.c to Chat UI/UX -- speed up 7B Llama 2 models sufficiently to work at interactive rates on Apple Silicon MacBooks -- possibly include emscripten / web backend (as seen in @gg PR) -- currently the project only runs in fp32, how easy would it be to different precisions? -- look into quantization and what would be involved -- todo multiquery support? doesn't seem as useful for smaller models that run on CPU (?) -- todo support inferencing beyond max_seq_len steps, have to think through the kv cache -- why is MFU so low (~10%) on my A100 40GB for training? -- weird errors with torch.compile and wandb when using DDP -- (LoRA) finetuning of Llama 2 models -- make more better tests to decrease yolo +``` +llama2.on('tokens', function(tokens, llama2) { + console.log('tokens', tokens[0]); + llama2.manualNext(tokens[0]); +}); +``` -## License +### Examples -MIT +See [basic.html](web/dist/basic.html) and [manual.html](web/dist/manual.html). 
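For orientation, here are two sketches in the spirit of those examples. They are not copies of the shipped pages: they assume the bundle was built with `make emscripten` (model and tokenizer baked in) followed by `npm run build`, that `dist/llama2.js`, `llama2.wasm` and `llama2.data` are served together over HTTP, and that `llama2.js` has already been loaded via a script tag. Note that the UMD global defined in web/webpack.config.js is named `LLAMA2`, which is what the sketches use as the constructor.

```
// Sketch of a basic.html-style page script: stream words as they arrive,
// then log the full output once the promise resolves.
const llama2 = new LLAMA2({ temperature: 0.9, stopOnBosOrEos: true });

// every detected word (and delimiter) is appended to the page as it arrives
llama2.on('word', (word) => document.body.append(word));

// generate() waits internally until the model and tokenizer are loaded;
// the returned promise resolves with the complete output
llama2.generate('Today was a great day in').then((out) => console.log(out));
```

A manual.html-style flow can pick each token itself. The loop below greedily takes the highest-probability continuation; `greedyDecode` and `maxTokens` are illustrative names, and the caller is responsible for keeping the number of generated tokens within the model's context length (see the note above).

```
// Sketch only: greedy decoding built on manualStart()/manualNext()
async function greedyDecode(prompt, maxTokens = 64) {
  const llama2 = new LLAMA2();   // manualStart() waits for the model internally
  let out = prompt;

  // run the transformer over the prompt; returns candidate continuations
  // for the next token, sorted by probability (highest first)
  let continuations = await llama2.manualStart(prompt);

  for (let i = 0; i < maxTokens; i++) {
    const best = continuations[0];                    // greedy: take the top candidate
    if (best.index === 1 || best.index === 2) break;  // stop on BOS/EOS
    out += best.str;
    continuations = await llama2.manualNext(best);    // feed it back, get the next candidates
  }
  return out;
}

greedyDecode('Today was a great day in').then(console.log);
```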
diff --git a/run.c b/run.c index d8f153e..0b130ae 100644 --- a/run.c +++ b/run.c @@ -20,6 +20,10 @@ Then run with: #include #include #endif +#if defined __EMSCRIPTEN__ + #include +#endif + // ---------------------------------------------------------------------------- // Transformer and RunState structs, and related memory management @@ -448,12 +452,72 @@ int argmax(float* v, int n) { } // ---------------------------------------------------------------------------- + +float temperature = 0.9f; // e.g. 1.0, or 0.0 +int steps = 256; // max number of steps to run for, 0: use seq_len +Config config; +TransformerWeights weights; +char** vocab; +float* vocab_scores; +unsigned int max_token_length; +RunState state; +int *prompt_tokens = NULL; +int num_prompt_tokens = 0; +int next; // will store the next token in the sequence +int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer +int pos = 0; // position in the sequence + + +void* on_token_callback = NULL; + + +void main_loop(void * dummy) { +#if defined __EMSCRIPTEN__ + if (pos >= steps) { + emscripten_pause_main_loop(); + return; // pointers might be invalid when this resumes + } +#endif + + // forward the transformer to get logits for the next token + transformer(token, pos, &config, &state, &weights); + + if(pos < num_prompt_tokens) { + // if we are still processing the input prompt, force the next prompt token + next = prompt_tokens[pos]; + } else { + // sample the next token + if (temperature == 0.0f) { + // greedy argmax sampling: take the token with the highest probability + next = argmax(state.logits, config.vocab_size); + } else { + // apply the temperature to the logits + for (int q=0; q= steps)); + } + + // following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89) + char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; + printf("%s", token_str); + fflush(stdout); + + // advance forward + token = next; + pos++; +} + int main(int argc, char *argv[]) { // poor man's C argparse char *checkpoint = NULL; // e.g. out/model.bin - float temperature = 0.9f; // e.g. 
1.0, or 0.0 - int steps = 256; // max number of steps to run for, 0: use seq_len char *prompt = NULL; // prompt string // 'checkpoint' is necessary arg @@ -479,8 +543,6 @@ int main(int argc, char *argv[]) { rng_seed = (unsigned int)time(NULL); // read in the model.bin file - Config config; - TransformerWeights weights; int fd = 0; // file descriptor for memory mapping float* data = NULL; // memory mapped data pointer long file_size; // size of the checkpoint file in bytes @@ -508,9 +570,8 @@ int main(int argc, char *argv[]) { if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } // read in the tokenizer.bin file - char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); - float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float)); - unsigned int max_token_length; + vocab = (char**)malloc(config.vocab_size * sizeof(char*)); + vocab_scores = (float*)malloc(config.vocab_size * sizeof(float)); { FILE *file = fopen("tokenizer.bin", "rb"); if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; } @@ -527,12 +588,9 @@ int main(int argc, char *argv[]) { } // create and init the application RunState - RunState state; malloc_run_state(&state, &config); // process the prompt, if any - int *prompt_tokens = NULL; - int num_prompt_tokens = 0; if (prompt != NULL) { prompt_tokens = (int*)malloc(config.seq_len * sizeof(int)); bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); @@ -540,44 +598,19 @@ int main(int argc, char *argv[]) { // start the main loop long start = 0; // used to time our code, only initialized after first iteration - int next; // will store the next token in the sequence - int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer - int pos = 0; // position in the sequence printf("\n"); // explicit print the initial BOS token for stylistic symmetry reasons - while (pos < steps) { - // forward the transformer to get logits for the next token - transformer(token, pos, &config, &state, &weights); - - if(pos < num_prompt_tokens) { - // if we are still processing the input prompt, force the next prompt token - next = prompt_tokens[pos]; - } else { - // sample the next token - if (temperature == 0.0f) { - // greedy argmax sampling: take the token with the highest probability - next = argmax(state.logits, config.vocab_size); - } else { - // apply the temperature to the logits - for (int q=0; q config.seq_len) { + steps = config.seq_len; + } else { + steps = _steps; + } +} + +void generate(char* prompt) { + // reset state + free_run_state(&state); + if (prompt_tokens != NULL) { + free(prompt_tokens); + prompt_tokens = NULL; + } + malloc_run_state(&state, &config); + + // process prompt + if (prompt != NULL) { + prompt_tokens = (int*)malloc(config.seq_len * sizeof(int)); + bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); + } + + token = 1; + pos = 0; + + // (re-) start the main loop for generation + emscripten_resume_main_loop(); +} + +// +// Besides generate(), which will use the main loop to invoke a +// callback function for every token, the manual_ functions below +// let the caller pick the next token synchronously. You'd want +// to use one or the other. 
+// + +char** get_vocab() { + return vocab; +} + +int get_vocab_size() { + return config.vocab_size; +} + +int manual_start(char* prompt) { + // stop the main loop of any prior generate() + // the manual_ functions aren't using it + emscripten_pause_main_loop(); + + // reset state + free_run_state(&state); + if (prompt_tokens != NULL) { + free(prompt_tokens); + } + malloc_run_state(&state, &config); + + // process prompt + if (prompt != NULL) { + prompt_tokens = (int*)malloc(config.seq_len * sizeof(int)); + bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); + } + + token = 1; + pos = 0; + + // run the transformer over the prompt + while (pos < num_prompt_tokens) { + transformer(token, pos, &config, &state, &weights); + token = prompt_tokens[pos]; + pos++; + } + + return token; // return the first token to pass to _next() +} + +float* manual_next(int _token) { + token = _token; + + transformer(token, pos, &config, &state, &weights); + + if (temperature != 0.0f) { + for (int q=0; q + + + + llama2.c-emscripten example + + + + +
+ + + + diff --git a/web/dist/manual.html b/web/dist/manual.html new file mode 100644 index 0000000..5b0b844 --- /dev/null +++ b/web/dist/manual.html @@ -0,0 +1,62 @@ + + + + + llama2.c-emscripten example + + + + + + + + diff --git a/web/package.json b/web/package.json new file mode 100644 index 0000000..21ef108 --- /dev/null +++ b/web/package.json @@ -0,0 +1,25 @@ +{ + "name": "llama2.c-emscripten", + "version": "0.1.0", + "description": "Emscripten (JS) interface to Andrej Karpathy's llama2.c implementation", + "main": "llama2.js", + "scripts": { + "build": "webpack --config webpack.config.js" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/gohai/llama2.c-emscripten.git" + }, + "author": "", + "license": "MIT", + "bugs": { + "url": "https://github.com/gohai/llama2.c-emscripten/issues" + }, + "homepage": "https://github.com/gohai/llama2.c-emscripten#readme", + "devDependencies": { + "webpack": "^5.88.2" + }, + "dependencies": { + "webpack-cli": "^5.1.4" + } +} diff --git a/web/src/interface.js b/web/src/interface.js new file mode 100644 index 0000000..cc132ef --- /dev/null +++ b/web/src/interface.js @@ -0,0 +1,278 @@ +// Copyright (c) 2023 ml5 +// +// This software is released under the MIT License. +// https://opensource.org/licenses/MIT + +import { EventEmitter } from "events"; +import callCallback from "../utils/callcallback"; +import handleArguments from "../utils/handleArguments"; + +import Llama2 from './llama2.js'; +import Llama2Wasm from './llama2.wasm'; +import Llama2Data from './llama2.data'; + + +class LLAMA2 extends EventEmitter { + constructor(optionsOrCb, cb) { + super(); + + this.options = { + modelUrl: '', // if set, model.bin will be preloaded from provided URL (assumed to be embedded in llama2.data if not) + tokenizerUrl: '', // if set, tokenizer.bin will be preloaded from provided URL (assumed to be embedded in llama2.data if not) + steps: 0, // how many tokens to generate (defaults to model's maximum) + temperature: 0.9, // 0.0 = (deterministic) argmax sampling, 1.0 = baseline + stopOnBosOrEos: true, // stop when encountering beginning-of-sequence or end-of-sequence token + }; + + // handle arguments + let callback; + if (typeof optionsOrCb === 'function') { + callback = optionsOrCb; + } else { + if (typeof optionsOrCb === 'object') { + this.options.modelUrl = (typeof optionsOrCb.modelUrl === 'string') ? optionsOrCb.modelUrl : this.options.modelUrl; + this.options.tokenizerUrl = (typeof optionsOrCb.tokenizerUrl === 'string') ? 
optionsOrCb.tokenizerUrl : this.options.tokenizerUrl; + } + if (typeof cb === 'function') { + callback = cb; + } + } + + this.out = ''; + this.tokens = []; + this.words = []; + this.finished = true; + + this.ready = callCallback(this.loadModel(), callback); + } + + async loadModel() { + const onStdout = (str) => { + //console.log('onStdout', str); + }; + + this.llama2 = await Llama2({ + locateFile(path) { + if (path.endsWith('.wasm')) { + return Llama2Wasm; + } + if (path.endsWith('.data')) { + return Llama2Data; + } + return path; + }, + arguments: ['model.bin'], + print: onStdout, + preRun: [ + (inst) => { + // model.bin and tokenizer.bin can either be baked into the llama2.data file + // (leading to a large library size), or dynamically from an URL provided as + // an option + if (this.options.modelUrl) { + inst.FS_createPreloadedFile('', 'model.bin', this.options.modelUrl, true, false); + } + if (this.options.tokenizerUrl) { + inst.FS_createPreloadedFile('', 'tokenizer.bin', this.options.tokenizerUrl, true, false); + } + } + ] + }); + + const onTokenCallback = await this.llama2.addFunction((tokenStr, token, probability, finished) => { + // ignore tokens after BOS or EOS (with stopOnBosOrEn on) + if (this.finished) { + return; + } + + tokenStr = this.llama2.UTF8ToString(tokenStr); + this.tokens.push({ index: token, str: tokenStr, probability: probability }); + // llama2.c signals finished after completing all steps + if (finished) { + this.finished = true; + } + + // optionally stop after encountering BOS (1) or EOS (2) + if (this.options.stopOnBosOrEos && (token == 1 || token == 2)) { + this.finished = true; + } else { + this.out += tokenStr; + } + + // on-token callback/event + if (this.callback) { + this.callback(this); + } + this.emit('token', this); + + // redo word tokenization + const wordDelimiters = ' .,:;"“?!\n'; + const re = new RegExp('(?=[' + wordDelimiters + '])|(?<=[' + wordDelimiters + '])', 'g'); + const prevNumWords = this.words.length; + this.words = this.out.split(re); + // ignore the last word if we can't be certain it's complete + if (!wordDelimiters.includes(this.out.slice(-1)) && !this.finished) { + this.words.pop(); + } + // on-word event + for (let i=prevNumWords; i < this.words.length; i++) { + this.emit('word', this.words[i], this); + } + + // on-finish promise/event + if (this.finished) { + // fulfill the promise returned by generate() + if (this.promiseResolve) { + this.promiseResolve(this.out); + } + this.emit('finsh', this); + } + }, 'viifi'); + + await this.llama2.ccall('register_callback', null, [ 'number' ], [ onTokenCallback ]); + + //console.log('loadModel done'); + } + + async generate(prompt, optionsOrCb, cb) { + await this.ready; + + // handle arguments + if (typeof optionsOrCb === 'function') { + this.callback = optionsOrCb; + } else { + if (typeof optionsOrCb === 'object') { + this.options.steps = (typeof optionsOrCb.steps === 'number') ? optionsOrCb.steps : this.options.steps; + this.options.temperature = (typeof optionsOrCb.temperature === 'number') ? optionsOrCb.temperature : this.options.temperature; + this.options.stopOnBosOrEos = (typeof optionsOrCb.stopOnBosOrEos == 'boolean') ? 
optionsOrCb.stopPropagation : this.options.stopOnBosOrEos; + } + if (typeof cb === 'function') { + this.callback = cb; + } else { + this.callback = null; + } + } + + // if there are any outstanding requests, resolve them + // with the output received so far + if (this.promiseResolve) { + this.promiseResolve(this.out); + } + + await this.llama2.ccall('set_parameters', null, [ 'number', 'number' ], [ this.options.temperature, this.options.steps ]); + + this.out = ''; + this.tokens = [{ index: 1, str: '', probability: 1 }]; + this.words = []; + this.finished = false; + + await this.llama2.ccall('generate', null, [ 'string' ], [ prompt ]); + + return new Promise((resolve, reject) => { + this.promiseResolve = resolve; + }); + } + + async vocab() { + if (this._vocab) { + return this._vocab; + } + + await this.ready; + const vocabSize = await this.llama2.ccall('get_vocab_size', 'number', [], []); + const vocabPtr = await this.llama2.ccall('get_vocab', 'number', [], []); + this._vocab = new Array(vocabSize); + for (let i=0; i < vocabSize; i++) { + const strPtr = this.llama2.HEAPU32[(vocabPtr+4*i)/4]; + this._vocab[i] = this.llama2.UTF8ToString(strPtr); + } + return this._vocab; + } + + async manualStart(prompt, optionsOrCb, cb) { + await this.ready; + + // handle arguments + if (typeof optionsOrCb === 'function') { + this.callback = optionsOrCb; + } else { + if (typeof optionsOrCb === 'object') { + this.options.steps = (typeof optionsOrCb.steps === 'number') ? optionsOrCb.steps : this.options.steps; + this.options.temperature = (typeof optionsOrCb.temperature === 'number') ? optionsOrCb.temperature : this.options.temperature; + this.options.stopOnBosOrEos = (typeof optionsOrCb.stopOnBosOrEos == 'boolean') ? optionsOrCb.stopPropagation : this.options.stopOnBosOrEos; + } + if (typeof cb === 'function') { + this.callback = cb; + } else { + this.callback = null; + } + } + + // if there are any outstanding requests, resolve them + // with the output received so far + if (this.promiseResolve) { + this.promiseResolve(this.out); + } + + await this.llama2.ccall('set_parameters', null, [ 'number', 'number' ], [ this.options.temperature, this.options.steps ]); + + this.out = ''; + this.tokens = []; + this.words = []; + this.finished = true; + + let token = await this.llama2.ccall('manual_start', 'number', [ 'string' ], [ prompt ]); + return this.manualNext(token); + } + + async manualNext(token) { + await this.ready; + + if (typeof token === 'number') { + // nothing to do + } else if (typeof token === 'object' && typeof token.index === 'number') { + token = token.index; + } else if (typeof token === 'string') { + // check if numeric + if (token.match(/^\d+$/)) { + token = parseInt(token); + } else { + // look up in vocabulary + const vocab = await this.vocab(); + let found = false; + for (let i=0; i < vocab.length; i++) { + if (token === vocab[i]) { + token = i; + found = true; + break; + } + } + if (!found) { + throw 'Not in vocabulary: ' + token; + } + } + } else { + throw 'Unrecognized next token: ' + token; + } + + const vocab = await this.vocab(); + const logitsPtr = await this.llama2.ccall('manual_next', 'number', [ 'number' ], [ token ]); + + const tokens = new Array(vocab.length-1); + for (let i=1; i < vocab.length; i++) { + tokens[i] = { index: i, str: vocab[i], probability: this.llama2.HEAPF32[(logitsPtr+i*4)/4] }; + } + tokens.sort((a, b) => (a.probability > b.probability) ? 
-1 : 1); + + // on-tokens callback/event + if (this.callback) { + this.callback(tokens, this); + } + this.emit('tokens', tokens, this); + + return tokens; + } + +} + + +export default LLAMA2; diff --git a/web/utils/callcallback.js b/web/utils/callcallback.js new file mode 100644 index 0000000..a875d03 --- /dev/null +++ b/web/utils/callcallback.js @@ -0,0 +1,40 @@ +// Copyright (c) 2018 ml5 +// +// This software is released under the MIT License. +// https://opensource.org/licenses/MIT + + +/** + * Most ml5 methods accept a callback function which will be + * called with the arguments (error, result). + * + * Generic type T describes the type of the result. + * @template T + * @callback ML5Callback + * @param {unknown} error - any error thrown during the execution of the function. + * @param {T} [result] - the expected result, if successful. + * @return {void} - callbacks can have side effects, but should not return a value. + */ + +/** + * Generic type T describes the type of the result, ie. the value that the Promise will resolve to. + * @template T + * @param {Promise} promise - the Promise to resolve. + * @param {ML5Callback} [callback] - optional callback function to be called + * with the result or error from the resolved Promise. + * @return {Promise} - returns the underlying Promise, which may be rejected. + */ +export default function callCallback(promise, callback) { + if (!callback) return promise; + return new Promise((resolve, reject) => { + promise + .then((result) => { + callback(undefined, result); + resolve(result); + }) + .catch((error) => { + callback(error); + reject(error); + }); + }); +} diff --git a/web/utils/handleArguments.js b/web/utils/handleArguments.js new file mode 100644 index 0000000..387e33b --- /dev/null +++ b/web/utils/handleArguments.js @@ -0,0 +1,287 @@ +/** + * @typedef ImageElement + * @type {HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} + */ + +/** + * Standard input accepted by most TensorFlow models. + * @typedef InputImage + * @type {ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | tf.Tensor3D} + */ + +/** + * ML5 models accept all TensorFlow image inputs as well as p5 images and videos. + * @typedef ImageArg + * @type {InputImage | p5.Image | p5.Video | p5.Element} + */ + +/** + * Check if a variable is an HTMLVideoElement. + * @param {any} img + * @returns {img is HTMLVideoElement} + */ +export const isVideo = (img) => { + // Must guard all instanceof checks on DOM elements in order to run in node. + return typeof (HTMLVideoElement) !== 'undefined' && + img instanceof HTMLVideoElement; +} + +/** + * Check if a variable is an HTMLAudioElement. + * @param {any} img + * @returns {img is HTMLAudioElement} + */ +export const isAudio = (img) => { + return typeof (HTMLAudioElement) !== 'undefined' && + img instanceof HTMLAudioElement; +} + +/** + * Check if a variable is an HTMLCanvasElement. + * @param {any} img + * @returns {img is HTMLCanvasElement} + */ +export const isCanvas = (img) => { + return typeof (HTMLCanvasElement) !== 'undefined' && + img instanceof HTMLCanvasElement; +} + +/** + * Check if a variable is an HTMLImageElement. + * @param {any} img + * @returns {img is HTMLImageElement} + */ +export const isImg = (img) => { + return typeof (HTMLImageElement) !== 'undefined' && + img instanceof HTMLImageElement; +} + +/** + * Check if a variable is a p5.Image or other p5.Element. 
+ * @param {p5.Element | p5.Image} img + * @returns {img is p5.Element | p5.Image} + */ +export const isP5Image = (img) => { + return 'elt' in img || 'canvas' in img; +} + +/** + * Check if a variable is an instance of ImageData, + * or a plain object with the same properties as ImageData. + * This allows it to work in Node environments where ImageData is not defined. + * @param {any} img + * @returns {img is ImageData} + */ +export const isImageData = (img) => { + if (typeof (ImageData) === 'undefined') { + return ( + typeof img === 'object' && + // TODO: figure out TensorFlow issues with Uint8ClampedArray vs. Uint8Array + (img.data instanceof Uint8ClampedArray || img.data instanceof Uint8Array) && + typeof img.width === 'number' && + typeof img.height === 'number' + ) + } + return img instanceof ImageData; +} + +/** + * Check if an unknown variable is a TensorFlow tensor with rank 3. + * @param {any} img + * @returns {img is tf.Tensor3D} + */ +export const isTensor3D = (img) => { + return false; +} + +/** + * Check if an image is one of HTMLImageElement, HTMLCanvasElement, HTMLVideoElement + * @param {any} img + * @returns {img is ImageElement} + */ +export const isImageElement = (img) => { + return !!img && (isCanvas(img) || isImg(img) || isVideo(img)); +} + +/** + * Check that the provided image is an acceptable format and return it. + * If it is a p5 Image, return the underlying HTML element. + * Otherwise, return null. + * @param {any} img + * @returns {ImageElement | null} + */ +export const getImageElement = (img) => { + if (isImageElement(img)) { + return img; + } + if (typeof img === 'object') { + if (isImageElement(img.canvas)) { + return img.canvas; + } + if (isImageElement(img.elt)) { + return img.elt; + } + } + return null; +} + +/** + * For methods which accept multiple optional arguments, a specific argument might be passed in multiple positions. + * We can determine which argument is which by examining their values. + * + * Creates an object where each argument is assigned to a named property. + * All properties are optional as arguments might be missing. + */ + +/** + * @typedef {object} StandardArguments + * @property {string} [string] - A model name or any other string argument. + * @property {number} [number] - Any numeric argument. + * @property {function} [callback] - A callback function. + * @property {object} [options] - Any object which is not a media object is assumed to be options. + * @property {Array} [array] - Any array. + * @property {HTMLMediaElement} [audio] - Both video and audio-only elements will be assigned to the audio property. + * @property {HTMLVideoElement} [video] - Video elements also get their own property. + * @property {InputImage} [image] - Any video, image, or image data. + */ + +/** + * @class ArgHelper + * @implements {StandardArguments} + */ +class ArgHelper { + /** + * Arguments used to CREATE an image-based model can be: + * - video: an HTMLVideoElement or p5 video. + * - options: an object of options specific to this model. + * - callback: a function to run once the model has been loaded. Called with arguments (error) or (error, result). + * - modelName: some models accept a model name or URL as an argument. + * + * Arguments used to CALL a method an image-based model can be: + * - image: an image or video element or an ImageData object. Valid types: HTMLImageElement, HTMLCanvasElement, + * HTMLVideoElement, ImageData, p5 image, p5 video. + * - options: an object of options specific to this model. 
+ * - callback: a function to run once the method has been completed. + * + * Expected to be provided in order modelName, video/image, options, callback with any omitted. + * This function does not actually require any particular order. + * + * Later arguments will override earlier ones, so `this.video` should always be the first when providing arguments + * from a class method call. + * + * @param {any[]} [args] + */ + constructor(...args) { + args.forEach((arg) => this.addArg(arg)); + } + + /** + * Can add arguments through the constructor or at any time after construction. + * + * @param {any} arg + */ + addArg(arg) { + // skip over falsey arguments and don't throw any error, assuming that these are omissions + // do this check first to prevent accessing properties on null, which is an object + if (arg === undefined || arg === null) { + return; + } + switch (typeof arg) { + case "string": + this.set({ string: arg }); + break; + case "number": + this.set({ number: arg }); + break; + case "function": + this.set({ callback: arg }); + break; + case "object": { + if (isTensor3D(arg) || isImageData(arg)) { + this.set({ image: arg }); + } + // Handle p5 object and HTML elements. + const element = getImageElement(arg); + if (element) { + this.set({ image: element }); + // Videos are also both images and audio. + if (isVideo(element)) { + this.set({ + audio: element, + video: element + }); + } + } + // TODO: handle p5.sound + if (isAudio(arg)) { + this.set({ audio: arg }); + } + // Check for arrays + else if (Array.isArray(arg)) { + this.set({ array: arg }); + } + // All other objects are assumed to be options. + else { + this.set({ options: arg }); + } + break; + } + default: + // Notify user about invalid arguments (would be ok to just skip) + throw new Error("invalid argument"); // TODO: better message. + } + } + + /** + * Set one or more properties and log a warning if it is already set. + * Use the second argument to suppress the warning when overriding behavior is expected. + * + * @param {Partial} values + * @param {boolean} warn + */ + set(values, warn = true) { + Object.keys(values).forEach(property => { + if (warn && this.has(property)) { + console.warn( + `Received multiple ${property} arguments, but only a single ${property} is supported. + The last ${property} will be used.` + ); + } + this[property] = values[property]; + }); + } + + /** + * Check whether or not a given property has been set. + * + * @param {string & keyof StandardArguments} property + * @returns {boolean} + */ + has(property) { + return this[property] !== undefined; + } + + /** + * Check that an argument exists and throw an error if it doesn't. + * + * @param {string & keyof StandardArguments} property + * @param {string} [message] + * @return {this} + */ + require(property, message) { + if (this.has(property)) { + return this; + } + throw new Error(message || `An argument for ${property} must be provided.`); + } +} + +/** + * Export a chainable method instead of the class itself. + * + * @param {any[]} args + * @return {ArgHelper} + */ +export default function handleArguments(...args) { + return new ArgHelper(...args); +}; diff --git a/web/webpack.config.js b/web/webpack.config.js new file mode 100644 index 0000000..0bc57d9 --- /dev/null +++ b/web/webpack.config.js @@ -0,0 +1,46 @@ +const { resolve } = require("path"); + +module.exports = function (env, argv) { + return { + context: __dirname, + entry: "./src/interface.js", + + mode: env.production ? "production" : "development", + devtool: env.production ? 
"source-map" : "inline-source-map", + output: { + filename: "llama2.js", + path: resolve(__dirname, "dist"), + publicPath: "", + library: { + name: "LLAMA2", + type: "umd", + export: "default", + }, + }, + resolve: { + fallback: { + crypto: false, + fs: false, + path: false + }, + }, + module: { + rules: [ + { + test: /llama2\.wasm$/, + type: "asset/resource", + generator: { + filename: "[name].wasm" + } + }, + { + test: /llama2\.data$/, + type: "asset/resource", + generator: { + filename: "[name].data" + } + } + ] + } + }; +};