From e1a7042943a0016dc979554372641112150dc346 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sun, 2 Jul 2023 00:10:56 +0800 Subject: [PATCH] try out the new rwkv but it seems worse, may revert --- gpttype_adapter.cpp | 25 +- otherarch/llama_v2.cpp | 4 +- otherarch/rwkv_v3.cpp | 1111 +++++++++++++++++++++++++++------------- otherarch/rwkv_v3.h | 65 ++- 4 files changed, 830 insertions(+), 375 deletions(-) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index ce2b6da150514..bdfe326ac3c41 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -431,6 +431,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in else //rwkv_2 { rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads); + + if(inputs.gpulayers>0) + { + rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers); + } + const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header; const size_t n_vocab = header.n_vocab; printf("\nDetected Vocab: %d",n_vocab); @@ -811,7 +817,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o { params.top_k = 120; //to disable top_k we actually need to increase this value to a very high number } - if (params.seed <= 0) + if (params.seed <= 0 || params.seed==0xFFFFFFFF) { params.seed = time(NULL); } @@ -1060,14 +1066,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o } else { - if(embd.size()>1) - { - evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out); - } - else - { - evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out); - } + // if(embd.size()>1) + // { + // evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out); + // } + // else + // { + bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2)); + evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out); + //} memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size()); rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out; diff --git a/otherarch/llama_v2.cpp b/otherarch/llama_v2.cpp index ff9f4e6f3c06c..9ba3b236b6338 100644 --- a/otherarch/llama_v2.cpp +++ b/otherarch/llama_v2.cpp @@ -2204,7 +2204,7 @@ struct llama_v2_context * llama_v2_init_from_file( llama_v2_context * ctx = new llama_v2_context; - if (params.seed < 0) { + if (params.seed < 0 || params.seed==0xFFFFFFFF) { params.seed = time(NULL); } @@ -2552,7 +2552,7 @@ int llama_v2_get_kv_cache_token_count(const struct llama_v2_context * ctx) { #define LLAMA_V2_MAX_RNG_STATE (64*1024) void llama_v2_set_rng_seed(struct llama_v2_context * ctx, int seed) { - if (seed < 0) { + if (seed < 0 || seed==0xFFFFFFFF) { seed = time(NULL); } ctx->rng.seed(seed); diff --git a/otherarch/rwkv_v3.cpp b/otherarch/rwkv_v3.cpp index 0396f99340167..42b4820861b80 100644 --- a/otherarch/rwkv_v3.cpp +++ b/otherarch/rwkv_v3.cpp @@ -6,6 +6,13 @@ #include "rwkv_v3.h" #include "ggml.h" +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif +#if defined(GGML_USE_CLBLAST) +#include "ggml-opencl.h" +#endif + #include #include #include @@ -17,6 +24,7 @@ #include #define _FILE_OFFSET_BITS 64 +// Puts an optional break point, if debug is enabled. 
#define RWKV_MAYBE_BREAK #include @@ -38,9 +46,6 @@ #endif #endif -// static_assert(sizeof(stat::st_size) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2GB"); -// static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2GB"); - // --- Error handling --- thread_local enum rwkv_error_flags global_last_error = RWKV_ERROR_NONE; @@ -124,20 +129,17 @@ inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_err #define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__) #define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__) + #define RWKV_CTX_ASSERT_FALSE_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, false, x, __VA_ARGS__) -#define RWKV_CTX_ASSERT_NULL_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, NULL, x, __VA_ARGS__) #define RWKV_ASSERT_FALSE(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, false, x) #define RWKV_ASSERT_NULL(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, NULL, x) + #define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x) -#define RWKV_CTX_ASSERT_NULL(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, NULL, x) #define RWKV_ENSURE_OR_FALSE(x) RWKV_ENSURE(false, x) #define RWKV_ENSURE_OR_NULL(x) RWKV_ENSURE(NULL, x) #define RWKV_ENSURE_OR_FALSE_MSG(x, ...) RWKV_ENSURE_MSG(false, x, __VA_ARGS__) -#define RWKV_ENSURE_OR_NULL_MSG(x, ...) RWKV_ENSURE_MSG(NULL, x, __VA_ARGS__) -#define RWKV_CTX_ENSURE_OR_FALSE_MSG(ctx, x, ...) RWKV_CTX_ENSURE_MSG(ctx, false, x, __VA_ARGS__) -#define RWKV_CTX_ENSURE_OR_NULL_MSG(ctx, x, ...) RWKV_CTX_ENSURE_MSG(ctx, NULL, x, __VA_ARGS__) // --- Utilities --- @@ -172,13 +174,13 @@ bool rwkv_fwrite_data(FILE * file, const void * data, const size_t length) { return fwrite(data, length, 1, file) == 1; } -// --- File data structures --- +// --- File handling --- #define TYPE_UNKNOWN TYPE_COUNT enum rwkv_type { - TYPE_F32, - TYPE_F16, + TYPE_FP32, + TYPE_FP16, TYPE_Q4_0, TYPE_Q4_1, TYPE_Q4_1_O, // Unsupported @@ -193,8 +195,8 @@ enum rwkv_type { #define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT extern const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { - GGML_TYPE_F32, /* F32 */ - GGML_TYPE_F16, /* F16 */ + GGML_TYPE_F32, /* FP32 */ + GGML_TYPE_F16, /* FP16 */ GGML_TYPE_Q4_0, /* Q4_0 */ GGML_TYPE_Q4_1, /* Q4_1 */ GGML_TYPE_UNKNOWN, /* Q4_1_O */ @@ -207,8 +209,8 @@ extern const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { }; extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { - TYPE_F32, /* F32 */ - TYPE_F16, /* F16 */ + TYPE_FP32, /* FP32 */ + TYPE_FP16, /* FP16 */ TYPE_Q4_0, /* Q4_0 */ TYPE_Q4_1, /* Q4_1 */ TYPE_Q4_2, /* Q4_2 */ @@ -223,7 +225,7 @@ extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { TYPE_COUNT, /* COUNT */ }; -extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"float32", "float16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; +extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"FP32", "FP16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; enum rwkv_type rwkv_type_from_string(const char * str) { for (int ord = 0; ord < TYPE_COUNT; ord++) { @@ -290,6 +292,8 @@ struct rwkv_tensor_header { uint32_t data_type; uint32_t width; uint32_t height; + + const size_t size() const; }; struct rwkv_tensor { @@ -303,7 +307,12 @@ bool rwkv_fread_tensor_header(FILE * file, struct 
rwkv_tensor_header & header) { header.height = 1; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_SHAPE, header.dim_count == 1 || header.dim_count == 2, "Tensor has an invalid shape (%" PRId32 " dimensions)", header.dim_count); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Tensor data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN, "Tensor data type (%s) is no longer supported", rwkv_type_to_string[header.data_type]); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_DATA_TYPE, + rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN, + "Tensor data type (%s) is no longer supported", + rwkv_type_to_string[header.data_type] + ); if (header.dim_count == 2) { RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.height)); @@ -317,22 +326,8 @@ bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & he return true; } -size_t rwkv_tensor_size(enum ggml_type type, const int64_t width, const int64_t height = 1) { - struct ggml_tensor decoy {}; - decoy.type = type; - decoy.ne[0] = width; - decoy.ne[1] = height; - decoy.ne[2] = 1; - decoy.ne[3] = 1; - return ggml_nbytes(&decoy); -} - -size_t rwkv_tensor_size(const struct rwkv_tensor_header & header) { - return rwkv_tensor_size(rwkv_type_to_ggml[header.data_type], header.width, header.height); -} - bool rwkv_fskip_tensor_data(FILE * file, const struct rwkv_tensor_header & header) { - return fseek(file, header.key_length + rwkv_tensor_size(header), SEEK_CUR) == 0; + return fseek(file, header.key_length + header.size(), SEEK_CUR) == 0; } bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header & header) { @@ -342,7 +337,7 @@ bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header & } bool rwkv_fread_tensor_data(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) { - size_t data_size = rwkv_tensor_size(output.header); + size_t data_size = output.header.size(); RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, output.header.key_length, output.name)); if (buffer) { @@ -361,10 +356,33 @@ bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer = return true; } +bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); + + enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str()); + + tensor = header.dim_count == 1 + ? 
ggml_new_tensor_1d(ctx, ggml_type, header.width) + : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); + ggml_set_name(tensor, name.c_str()); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); + return true; +} + +bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { + struct rwkv_tensor_header header; + RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); + return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor); +} + bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) { RWKV_ENSURE_OR_FALSE(rwkv_fwrite_tensor_header(file, tensor.header)); RWKV_ENSURE_OR_FALSE(rwkv_fwrite_string(file, tensor.name)); - RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, rwkv_tensor_size(tensor.header))); + RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, tensor.header.size())); return true; } @@ -404,7 +422,7 @@ struct rwkv_model { struct ggml_tensor * ln0_weight; struct ggml_tensor * ln0_bias; - std::unique_ptr layers; + std::unique_ptr layers; struct ggml_tensor * ln_out_weight; struct ggml_tensor * ln_out_bias; @@ -457,28 +475,153 @@ struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) { // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias` // Looks like ggml_norm does the first part, we only need to apply weight & bias. - return ggml_add_inplace(ctx, ggml_mul(ctx, ggml_norm(ctx, x), weight), bias); + return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x), weight), bias); } // --- Implementation --- -// Used to calculate the memory usage of GGML contexts before allocating them. -// Since GGML uses an internal bump allocator that can't be grown at runtime, we need to ensure we have enough space, -// while at the same time, not using more memory than necessary. -struct rwkv_ctx_size { +// Used as a helper during rwkv_ctx_size calculation. +struct rwkv_future_tensor; + +// Used to calculate the memory usage of ggml contexts before allocating them. +// Since ggml uses an internal bump allocator that can't be grown at runtime, we need to ensure we have enough space, +// while at the same time not using more memory than necessary. 
+struct rwkv_future_ctx { size_t objects_count = 0; - size_t objects_size = 0; + size_t memory_size = 0; size_t scratch_size = 0; + + // Align to GGML_MEM_ALIGN, which can currently be up to 16 + static const size_t align(const size_t size) { + return ((size + 15) & ~15); + } + + void add_objects(const size_t size, const size_t count = 1) { + this->objects_count += count; + + if (size && count) { + this->add_memory(size, count); + } + } + + void add_memory(const size_t size, const size_t count = 1) { + this->memory_size += this->align(size) * count; + } + + void add_scratch(const size_t size, const size_t count = 1) { + this->scratch_size += this->align(size) * count; + } + + void add_data(const bool use_scratch, const size_t size, const size_t count = 1) { + if (use_scratch) { + this->add_scratch(size, count); + } else { + this->add_memory(size, count); + } + } + + struct rwkv_future_tensor declare(const enum ggml_type type, const uint64_t width, const uint64_t height = 1); + + struct rwkv_future_tensor alloc(const enum ggml_type type, const uint64_t width, const uint64_t height = 1, const bool use_scratch = true); }; +struct rwkv_future_tensor { + enum ggml_type type = GGML_TYPE_COUNT; + uint64_t width = 0; + uint64_t height = 0; + + static const size_t size(const enum ggml_type type, const uint64_t width, const uint64_t height) { + struct ggml_tensor decoy {}; + decoy.type = type; + decoy.ne[0] = width; + decoy.ne[1] = height; + decoy.ne[2] = 1; + decoy.ne[3] = 1; + return ggml_nbytes(&decoy); + } + + rwkv_future_tensor() {} + rwkv_future_tensor(const enum ggml_type type, const uint64_t width, const uint64_t height = 1): type(type), width(width), height(height) {} + rwkv_future_tensor(const struct ggml_tensor * ref): type(ref->type), width(ref->ne[0]), height(ref->ne[1]) {} + + struct rwkv_future_tensor alloc(struct rwkv_future_ctx & ctx, const bool use_scratch = true) const { + ctx.add_objects(sizeof(struct ggml_tensor)); + ctx.add_data(use_scratch, rwkv_future_tensor::size(type, width, height)); + return *this; + } + + struct rwkv_future_tensor view(struct rwkv_future_ctx & ctx) const { + ctx.add_objects(sizeof(struct ggml_tensor)); + return *this; + } + + struct rwkv_future_tensor subview(struct rwkv_future_ctx & ctx, const uint32_t width, const uint32_t height = 1) const { + ctx.add_objects(sizeof(struct ggml_tensor), 2); + ctx.add_memory(sizeof(uint32_t) * 2); + return rwkv_future_tensor(type, width, height); + } + + struct rwkv_future_tensor dup(struct rwkv_future_ctx & ctx) const { + return this->alloc(ctx); + } + + struct rwkv_future_tensor layer_norm(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & weight, const struct rwkv_future_tensor & bias) const { + return this->dup(ctx).view(ctx).view(ctx); + } + + struct rwkv_future_tensor repeat(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor reference) const { + return reference.dup(ctx); + } + + struct rwkv_future_tensor set_inplace(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor src) { + ctx.add_objects(sizeof(struct ggml_tensor)); + ctx.add_memory(sizeof(uint32_t) * 5); + return this->view(ctx); + } + + struct rwkv_future_tensor consume(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) { + return this->view(ctx); + } + + struct rwkv_future_tensor combine(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { + return this->dup(ctx); + } + + struct rwkv_future_tensor fn(struct rwkv_future_ctx & ctx) const { + ctx.add_objects(sizeof(struct 
ggml_tensor)); + ctx.add_memory(sizeof(void *) / sizeof(uint32_t)); + return this->dup(ctx); + } + + struct rwkv_future_tensor mul_mat(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { + return ctx.alloc(GGML_TYPE_F32, this->height, other.height); + } + + struct rwkv_future_tensor get_rows(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { + return ctx.alloc(GGML_TYPE_F32, this->width, other.width); + } +}; + +const size_t rwkv_tensor_header::size() const { + return rwkv_future_tensor::size(rwkv_type_to_ggml[this->data_type], this->width, this->height); +} + +struct rwkv_future_tensor rwkv_future_ctx::declare(const enum ggml_type type, const uint64_t width, const uint64_t height) { + return rwkv_future_tensor(type, width, height); +} + +struct rwkv_future_tensor rwkv_future_ctx::alloc(const enum ggml_type type, const uint64_t width, const uint64_t height, const bool use_scratch) { + return this->declare(type, width, height).alloc(*this, use_scratch); +} + struct rwkv_ggml_context { - std::unique_ptr scratch; + std::unique_ptr scratch; struct ggml_context * ctx; rwkv_ggml_context(): ctx(NULL) {} - rwkv_ggml_context(struct rwkv_ctx_size size): ctx(NULL) { - scratch.reset(new(std::nothrow) uint8_t [size.scratch_size]); + rwkv_ggml_context(const struct rwkv_future_ctx future_ctx): ctx(NULL) { + scratch.reset(new(std::nothrow) uint8_t[future_ctx.scratch_size]); if (!scratch) { return; @@ -487,13 +630,13 @@ struct rwkv_ggml_context { const size_t memory_required_overhead = size_t(128) * 1024 * 1024; const size_t memory_required_overhead_sc = size_t(64) * 1024 * 1024; - ctx = ggml_init({ size.objects_count * GGML_OBJECT_SIZE + size.objects_size + memory_required_overhead, NULL, false}); + ctx = ggml_init({ future_ctx.objects_count * GGML_OBJECT_SIZE + future_ctx.memory_size + memory_required_overhead, NULL, false}); if (!ctx) { return; } - ggml_set_scratch(ctx, { 0, memory_required_overhead_sc + size.scratch_size, scratch.get() }); + ggml_set_scratch(ctx, { 0, memory_required_overhead_sc + future_ctx.scratch_size, scratch.get() }); } struct rwkv_ggml_context & operator=(struct rwkv_ggml_context && source) { @@ -516,7 +659,7 @@ struct rwkv_instance { struct rwkv_ggml_context ctx; struct rwkv_model model; - // TODO come up with a better solution to estimate "work tensor" size. + // TODO Come up with a better solution to estimate "work tensor" size // The ggml_cgraph allocates a "work tensor" the first time it is used. // Currently, the height of blocks.0.ffn.key.weight is the bottleneck in our implementation of RWKV. // Since it is the largest dimension used in any matrix multiply, it is the size used for the "work tensor". @@ -528,8 +671,8 @@ struct rwkv_instance { // The hidden state of a single RWKV layer. // These are mostly used for dividing up the input state, and writing portions of the output state. -// But they're also used in building the computation graphs, to represent the operations used from input->output -// (operating "in place" on a rwkv_layer_state). +// But they're also used in building the computation graphs to represent the operations +// used from input->output (operating "in place" on a rwkv_layer_state). struct rwkv_layer_state { struct ggml_tensor * ffn_xx; struct ggml_tensor * att_xx; @@ -538,7 +681,7 @@ struct rwkv_layer_state { struct ggml_tensor * att_pp; }; -// Holds a single computation graph and its GGML context. +// Holds a single computation graph and its ggml context. 
// Graphs each have their own context so that they can be individually freed and rebuilt. // Graphs read hidden state from the rwkv_context and then write it back to the rwkv_context. // (see rwkv_context.input_layers and rwkv_context.output_layers) @@ -548,6 +691,11 @@ struct rwkv_graph { // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap std::unique_ptr cgraph; + + size_t pre_logits_nodes; + size_t pre_logits_leafs; + size_t post_logits_nodes; + size_t post_logits_leafs; }; // RWKV context for a specific instance. @@ -558,9 +706,9 @@ struct rwkv_context { // Reused by all graphs. struct rwkv_ggml_context ctx; struct ggml_tensor * input_state; - std::unique_ptr input_layers; + std::unique_ptr input_layers; struct ggml_tensor * output_state; - std::unique_ptr output_layers; + std::unique_ptr output_layers; struct ggml_tensor * logits; uint32_t n_threads; @@ -581,40 +729,17 @@ struct rwkv_context { float * logits_out = 0; //stores address of output logit buffer size_t gpu_layers; - size_t vram_total; }; -bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); - - enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str()); - - tensor = header.dim_count == 1 - ? ggml_new_tensor_1d(ctx, ggml_type, header.width) - : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); - ggml_set_name(tensor, name.c_str()); - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); - return true; -} - -bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { - struct rwkv_tensor_header header; - RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); - return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor); -} - -template // https://stackoverflow.com/a/6458689 +// https://stackoverflow.com/a/6458689 +template bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb)); RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight)); RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); uint32_t n_layer = model.header.n_layer; - std::unique_ptr layers(new(std::nothrow) struct rwkv_layer [n_layer]); + std::unique_ptr layers(new(std::nothrow) struct rwkv_layer[n_layer]); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); model.layers = std::move(layers); @@ -652,121 +777,108 @@ bool rwkv_set_params(struct rwkv_model & model, F callback) { return true; } -void rwkv_ctx_size_add_objects(struct rwkv_ctx_size & ctx_size, size_t objects, size_t object_size = sizeof(struct ggml_tensor)) { - ctx_size.objects_count += objects; - ctx_size.objects_size += ((object_size + 15) & ~15) * objects; -} - -void rwkv_ctx_size_add_scratch(struct rwkv_ctx_size & ctx_size, size_t length, size_t count = 1) { - ctx_size.scratch_size += ((length + 15) & ~15) * count; -} - 
-void rwkv_ctx_size_add(struct rwkv_ctx_size & ctx_size, size_t objects, size_t scratch = 0, size_t scratches = 1) { - rwkv_ctx_size_add_objects(ctx_size, objects); - rwkv_ctx_size_add_scratch(ctx_size, scratch, scratches); -} - -void rwkv_ctx_size_add(struct rwkv_ctx_size & ctx_size, size_t count, const struct rwkv_ctx_size & other) { - ctx_size.objects_count += other.objects_count * count; - ctx_size.objects_size += other.objects_size * count; - ctx_size.scratch_size += other.scratch_size * count; -} - -void rwkv_ctx_size_add_tensor(struct rwkv_ctx_size & ctx_size, const uint64_t tensors, const uint64_t views, const enum ggml_type type, const uint64_t width, const uint64_t height = 1) { - rwkv_ctx_size_add_objects(ctx_size, tensors + views); - rwkv_ctx_size_add_scratch(ctx_size, rwkv_tensor_size(type, width, height), tensors); -} - -void rwkv_ctx_size_add_tensor(struct rwkv_ctx_size & size, const uint64_t tensors, const uint64_t views, const struct rwkv_tensor_header & header) { - rwkv_ctx_size_add_tensor(size, tensors, views, rwkv_type_to_ggml[header.data_type], header.width, header.height); -} - -struct rwkv_ctx_size rwkv_xx_size(const size_t n_embed = 0, const size_t sequence_len = 1) { - struct rwkv_ctx_size ctx_size; - - if (sequence_len == 1) { - /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); +void rwkv_future_carry_x(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor weight, + const struct rwkv_future_tensor bias, + struct rwkv_future_tensor & x, + struct rwkv_future_tensor & x_prev, + struct rwkv_future_tensor & carry +) { + if (x.height == 1) { + x = x.layer_norm(ctx, weight, bias); + x_prev = carry; + carry = x; } else { - /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 4, 1, GGML_TYPE_F32, n_embed, sequence_len); + x = x.layer_norm(ctx, weight.repeat(ctx, x), bias.repeat(ctx, x)); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 1, 2, GGML_TYPE_F32, n_embed, sequence_len); - /* xx */ rwkv_ctx_size_add_objects(ctx_size, 2, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I32, 5)); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 1, GGML_TYPE_F32, n_embed * sequence_len - 1); + x_prev = x.dup(ctx) + .set_inplace(ctx, carry) + .set_inplace(ctx, x.subview(ctx, x.width, x.height - 1)); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 1, GGML_TYPE_F32, n_embed); + carry = x.subview(ctx, x.width); } - - return ctx_size; } -void rwkv_xx(struct ggml_context * ctx, struct ggml_tensor * weight, struct ggml_tensor * bias, struct ggml_tensor *& x, struct ggml_tensor *& xx, struct ggml_tensor *& state) { - size_t n_embed = x->ne[0]; - size_t sequence_len = x->ne[1]; +void rwkv_carry_x(struct ggml_context * ctx, + struct ggml_tensor * weight, + struct ggml_tensor * bias, + struct ggml_tensor *& x, + struct ggml_tensor *& x_prev, + struct ggml_tensor *& carry +) { + const size_t n_embed = x->ne[0]; + const size_t sequence_len = x->ne[1]; if (sequence_len == 1) { // self.layer_norm(x, self.w.blocks[i].ln2) x = rwkv_layer_norm(ctx, x, weight, bias); // xx = state[5*i+0] - xx = state; + x_prev = carry; // state[5*i+0] = x - state = x; + carry = x; } else { // self.layer_norm(x, self.w.blocks[i].ln2) x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x)); // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) - xx = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); - xx = ggml_set_1d_inplace(ctx, xx, state, 0); - xx = ggml_set_1d_inplace(ctx, xx, ggml_view_1d(ctx, x, n_embed * 
(sequence_len - 1), 0), n_embed * sizeof(float)); + x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); + x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0); + x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); // state[5*i+0] = x[-1,:] - state = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); + carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); } } -struct rwkv_ctx_size rwkv_att_rkv_size(const size_t n_embed = 0, const size_t sequence_len = 1) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - - struct rwkv_ctx_size ctx_size; - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* k */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* v */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); +void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor time_mix_k, + const struct rwkv_future_tensor time_mix_v, + const struct rwkv_future_tensor time_mix_r, + const struct rwkv_future_tensor x, + const struct rwkv_future_tensor x_prev, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + struct rwkv_future_tensor & r, + struct rwkv_future_tensor & k, + struct rwkv_future_tensor & v +) { + const struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); + const struct rwkv_future_tensor xv = x.combine(ctx, time_mix_v).consume(ctx, x_prev.combine(ctx, time_mix_v.fn(ctx))); + const struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); - return ctx_size; + r = att_r.mul_mat(ctx, xr).fn(ctx); + k = att_k.mul_mat(ctx, xk); + v = att_v.mul_mat(ctx, xv); } -void rwkv_att_rkv(struct ggml_context * ctx, struct rwkv_layer layer, struct ggml_tensor * x0, struct ggml_tensor * xx, struct ggml_tensor *& r, struct ggml_tensor *& k, struct ggml_tensor *& v) { +void rwkv_att_rkv( + struct ggml_context * ctx, + struct rwkv_layer layer, + struct ggml_tensor * x, + struct ggml_tensor * x_prev, + struct ggml_tensor *& r, + struct ggml_tensor *& k, + struct ggml_tensor *& v +) { // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) struct ggml_tensor * xk = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_k), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ggml_mul(ctx, x, layer.att_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) ); // xv = x * time_mix_v + state[5 * i 
+ 1] * (1 - time_mix_v) struct ggml_tensor * xv = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_v), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ggml_mul(ctx, x, layer.att_time_mix_v), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) ); // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_r), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ggml_mul(ctx, x, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) @@ -777,39 +889,47 @@ void rwkv_att_rkv(struct ggml_context * ctx, struct rwkv_layer layer, struct ggm v = ggml_mul_mat(ctx, layer.att_value, xv); } -struct rwkv_ctx_size rwkv_att_wkv_size(const size_t n_embed = 0) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - - struct rwkv_ctx_size ctx_size; - /* ww */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* a */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* b */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_embed); - - /* ww */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* aa */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* bb */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_embed); - /* pp */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); - - /* wkv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); +struct rwkv_future_tensor rwkv_future_att_wkv(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor time_first, + const struct rwkv_future_tensor time_decay, + struct rwkv_future_tensor & aa, + struct rwkv_future_tensor & bb, + struct rwkv_future_tensor & pp, + const struct rwkv_future_tensor k, + const struct rwkv_future_tensor v +) { + struct rwkv_future_tensor ww = time_first.combine(ctx, k); + struct rwkv_future_tensor qq = pp.fn(ctx); + struct rwkv_future_tensor e1 = pp.combine(ctx, qq).fn(ctx); + struct rwkv_future_tensor e2 = ww.combine(ctx, qq).fn(ctx); + + struct rwkv_future_tensor a = e1.combine(ctx, aa).consume(ctx, e2.combine(ctx, v)); + struct rwkv_future_tensor b = e1.combine(ctx, bb).consume(ctx, e2); + + ww = pp.combine(ctx, time_decay); + qq = ww.fn(ctx); + e1 = ww.combine(ctx, qq).fn(ctx); + e2 = k.combine(ctx, qq).fn(ctx); + + // aa, bb + aa = e1.combine(ctx, aa).consume(ctx, e2.combine(ctx, v)); + bb = e1.combine(ctx, bb).consume(ctx, e2); + pp = qq; - return ctx_size; + // wkv + return 
a.combine(ctx, b); } -struct ggml_tensor * rwkv_att_wkv(struct ggml_context * ctx, struct ggml_tensor * att_time_first, struct ggml_tensor * att_time_decay, struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor *& aa, struct ggml_tensor *& bb, struct ggml_tensor *& pp) { +struct ggml_tensor * rwkv_att_wkv( + struct ggml_context * ctx, + struct ggml_tensor * att_time_first, + struct ggml_tensor * att_time_decay, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor *& aa, + struct ggml_tensor *& bb, + struct ggml_tensor *& pp +) { // ww = time_first + k struct ggml_tensor * ww = ggml_add(ctx, att_time_first, k); // qq = torch.maximum(pp, ww) @@ -844,24 +964,42 @@ struct ggml_tensor * rwkv_att_wkv(struct ggml_context * ctx, struct ggml_tensor return ggml_div(ctx, a, b); } -struct rwkv_ctx_size rwkv_att_size(const size_t n_embed = 0) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - struct rwkv_ctx_size ctx_size; - /* xx */ rwkv_ctx_size_add(ctx_size, 1, rwkv_xx_size(n_embed)); - /* rkv */ rwkv_ctx_size_add(ctx_size, 1, rwkv_att_rkv_size(n_embed)); - /* wkv */ rwkv_ctx_size_add(ctx_size, 1, rwkv_att_wkv_size(n_embed)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); +struct rwkv_future_tensor rwkv_future_att(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor ln1_weight, + const struct rwkv_future_tensor ln1_bias, + const struct rwkv_future_tensor time_mix_k, + const struct rwkv_future_tensor time_mix_v, + const struct rwkv_future_tensor time_mix_r, + const struct rwkv_future_tensor time_first, + const struct rwkv_future_tensor time_decay, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + const struct rwkv_future_tensor att_output, + struct rwkv_future_tensor x, + struct rwkv_future_tensor & att_xx, + struct rwkv_future_tensor & att_aa, + struct rwkv_future_tensor & att_bb, + struct rwkv_future_tensor & att_pp +) { + struct rwkv_future_tensor x_prev; + rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x, x_prev, att_xx); + + struct rwkv_future_tensor r, k, v; + rwkv_future_att_rkv(ctx, time_mix_k, time_mix_v, time_mix_r, x, x_prev, att_r, att_k, att_v, r, k, v); + + struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, time_first, time_decay, att_aa, att_bb, att_pp, k, v); - return ctx_size; + return att_output.mul_mat(ctx, r.combine(ctx, wkv)); } struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v); struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); @@ -869,74 +1007,133 @@ struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); } -struct rwkv_ctx_size rwkv_ffn_size(const size_t n_embed = 0, const size_t ffn_key = 0, const size_t sequence_len = 1) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - - struct rwkv_ctx_size ctx_size; - /* xx */ rwkv_ctx_size_add(ctx_size, 1, rwkv_xx_size(n_embed, sequence_len)); - - /* xk */ 
rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); +struct rwkv_future_tensor rwkv_future_ffn(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor ln2_weight, + const struct rwkv_future_tensor ln2_bias, + const struct rwkv_future_tensor time_mix_k, + const struct rwkv_future_tensor time_mix_r, + const struct rwkv_future_tensor ffn_k, + const struct rwkv_future_tensor ffn_v, + const struct rwkv_future_tensor ffn_r, + struct rwkv_future_tensor x, + struct rwkv_future_tensor & ffn_xx +) { + struct rwkv_future_tensor x_prev; + rwkv_future_carry_x(ctx, ln2_weight, ln2_bias, x, x_prev, ffn_xx); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* k */ rwkv_ctx_size_add_tensor(ctx_size, 3, 0, GGML_TYPE_F32, ffn_key, sequence_len); + struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); + struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); + struct rwkv_future_tensor r = ffn_r.mul_mat(ctx, xr).fn(ctx); + struct rwkv_future_tensor k = ffn_k.mul_mat(ctx, xk).view(ctx).view(ctx); - return ctx_size; + return r.consume(ctx, ffn_v.mul_mat(ctx, k)); } struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln2_weight, layer.ln2_bias, x0, xx, state.ffn_xx); + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) struct ggml_tensor * xk = ggml_add_inplace( ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ggml_mul(ctx, x, layer.ffn_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) ); // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add_inplace( ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ggml_mul(ctx, x, layer.ffn_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); // k = torch.square(torch.relu(kw @ xk)) - struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); // r * (vw @ k) - return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); + return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); } -struct rwkv_ctx_size rwkv_serial_graph_size(const size_t n_vocab, const size_t n_embed, const size_t n_layer, const 
size_t ffn_key_size) { - struct rwkv_ctx_size ctx_size; - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); +struct rwkv_future_tensor rwkv_future_graph_work(struct rwkv_future_ctx & ctx, + const enum ggml_type type, + const size_t ffn_key_height, + const size_t n_threads, + const size_t sequence_len = 1 +) { +#if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS) + enum ggml_type mul_mat_type = type == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; +#else + enum ggml_type mul_mat_type = ggml_is_quantized(type) ? GGML_TYPE_Q8_1 : type; +#endif + return ctx.alloc(GGML_TYPE_I8, rwkv_future_tensor::size(mul_mat_type, ffn_key_height, sequence_len) * n_threads + 64 * (n_threads - 1)); +} + +struct rwkv_future_tensor rwkv_future_serial_graph(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor tokens, + const size_t n_threads, + + const struct rwkv_future_tensor emb, + const struct rwkv_future_tensor ln0_weight, + const struct rwkv_future_tensor ln0_bias, + + const size_t n_layer, + + const struct rwkv_future_tensor ln1_weight, + const struct rwkv_future_tensor ln1_bias, + const struct rwkv_future_tensor att_time_mix_k, + const struct rwkv_future_tensor att_time_mix_v, + const struct rwkv_future_tensor att_time_mix_r, + const struct rwkv_future_tensor att_time_first, + const struct rwkv_future_tensor att_time_decay, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + const struct rwkv_future_tensor att_output, + struct rwkv_future_tensor & att_xx, + struct rwkv_future_tensor & att_aa, + struct rwkv_future_tensor & att_bb, + struct rwkv_future_tensor & att_pp, + + const struct rwkv_future_tensor ln2_weight, + const struct rwkv_future_tensor ln2_bias, + const struct rwkv_future_tensor ffn_time_mix_k, + const struct rwkv_future_tensor ffn_time_mix_r, + const struct rwkv_future_tensor ffn_k, + const struct rwkv_future_tensor ffn_v, + const struct rwkv_future_tensor ffn_r, + struct rwkv_future_tensor & ffn_xx, + + const struct rwkv_future_tensor ln_out_weight, + const struct rwkv_future_tensor ln_out_bias, + const struct rwkv_future_tensor head +) { + struct rwkv_future_tensor x = emb.get_rows(ctx, tokens).layer_norm(ctx, ln0_weight, ln0_bias); - /* att */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_att_size(n_embed)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); - /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_ffn_size(n_embed, ffn_key_size)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + for (size_t i = 0; i < n_layer; i++) { + x = x.consume(ctx, rwkv_future_att(ctx, + ln1_weight, ln1_bias, att_time_mix_k, att_time_mix_v, att_time_mix_r, att_time_first, att_time_decay, + att_r, att_k, att_v, att_output, x, att_xx, att_aa, att_bb, att_pp)); + + x = x.consume(ctx, rwkv_future_ffn(ctx, + ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx)); + + ffn_xx.view(ctx); + att_xx.view(ctx); + att_aa.view(ctx); + att_bb.view(ctx); + att_pp.view(ctx); + } - /* output */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * 5, GGML_TYPE_F32, n_embed); + x = x.layer_norm(ctx, ln_out_weight, ln_out_bias); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_vocab); + rwkv_future_graph_work(ctx, ffn_k.type, ffn_k.height, n_threads, 
tokens.width); - return ctx_size; + return head.mul_mat(ctx, x).view(ctx); } bool rwkv_build_serial_graph( @@ -946,10 +1143,13 @@ bool rwkv_build_serial_graph( struct rwkv_layer_state * inputs, struct rwkv_layer_state * outputs, struct ggml_tensor * logits, - struct ggml_cgraph * cgraph -) { - size_t n_embed = model.header.n_embed; + struct ggml_cgraph * cgraph, + size_t * const pre_logits_nodes, + size_t * const pre_logits_leafs, + size_t * const post_logits_nodes, + size_t * const post_logits_leafs +) { // x = self.w.emb.weight[token] struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens); @@ -971,40 +1171,93 @@ bool rwkv_build_serial_graph( ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp)); } + *pre_logits_nodes = cgraph->n_nodes; + *pre_logits_leafs = cgraph->n_leafs; + // x = self.layer_norm(x[-1,:], self.w.ln_out) x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits)); + *post_logits_nodes = cgraph->n_nodes; + *post_logits_leafs = cgraph->n_leafs; + return true; } -struct rwkv_ctx_size rwkv_sequence_graph_size(const size_t n_vocab, const size_t n_embed, const size_t n_layer, const size_t ffn_key_size, const size_t sequence_len) { - struct rwkv_ctx_size ctx_size; - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 4, 1, GGML_TYPE_F32, n_embed, sequence_len); +struct rwkv_future_tensor rwkv_future_sequence_graph(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor tokens, + const size_t n_threads, + + const struct rwkv_future_tensor emb, + const struct rwkv_future_tensor ln0_weight, + const struct rwkv_future_tensor ln0_bias, + + const size_t n_layer, + + const struct rwkv_future_tensor ln1_weight, + const struct rwkv_future_tensor ln1_bias, + const struct rwkv_future_tensor att_time_mix_k, + const struct rwkv_future_tensor att_time_mix_v, + const struct rwkv_future_tensor att_time_mix_r, + const struct rwkv_future_tensor att_time_first, + const struct rwkv_future_tensor att_time_decay, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + const struct rwkv_future_tensor att_output, + struct rwkv_future_tensor & att_xx, + struct rwkv_future_tensor & att_aa, + struct rwkv_future_tensor & att_bb, + struct rwkv_future_tensor & att_pp, + + const struct rwkv_future_tensor ln2_weight, + const struct rwkv_future_tensor ln2_bias, + const struct rwkv_future_tensor ffn_time_mix_k, + const struct rwkv_future_tensor ffn_time_mix_r, + const struct rwkv_future_tensor ffn_k, + const struct rwkv_future_tensor ffn_v, + const struct rwkv_future_tensor ffn_r, + struct rwkv_future_tensor & ffn_xx, + + const struct rwkv_future_tensor ln_out_weight, + const struct rwkv_future_tensor ln_out_bias, + const struct rwkv_future_tensor head +) { + struct rwkv_future_tensor x = emb.get_rows(ctx, tokens); + x = x.layer_norm(ctx, ln0_weight.repeat(ctx, x), ln0_bias.repeat(ctx, x)); - /* xx */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_xx_size(n_embed, sequence_len)); - /* rkv */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_att_rkv_size(n_embed, sequence_len)); + for (size_t i = 0; i < n_layer; i++) { + struct rwkv_future_tensor x0 = x, x_prev; + rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x0, x_prev, att_xx); + + struct rwkv_future_tensor r, k, v; + rwkv_future_att_rkv(ctx, 
att_time_mix_k, att_time_mix_v, att_time_mix_r, x0, x_prev, att_r, att_k, att_v, r, k, v); + + for (size_t i = 0; i < tokens.width; i++) { + struct rwkv_future_tensor kt = k.subview(ctx, emb.width); + struct rwkv_future_tensor vt = v.subview(ctx, emb.width); + struct rwkv_future_tensor xt = x_prev.subview(ctx, emb.width); + struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, att_time_first, att_time_decay, att_aa, att_bb, att_pp, k, v); + wkv.view(ctx); + } - /* kt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* vt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* xt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* wkv */ rwkv_ctx_size_add(ctx_size, n_layer * sequence_len, rwkv_att_wkv_size(n_embed)); - /* xt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, n_layer * 2, 0, GGML_TYPE_F32, n_embed, sequence_len); + x = x.consume(ctx, att_output.mul_mat(ctx, r.combine(ctx, x_prev))); + x = x.consume(ctx, rwkv_future_ffn(ctx, ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed, sequence_len); - /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_ffn_size(n_embed, ffn_key_size, sequence_len)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed, sequence_len); + ffn_xx.view(ctx); + att_xx.view(ctx); + att_aa.view(ctx); + att_bb.view(ctx); + att_pp.view(ctx); + } - /* output */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * 5, GGML_TYPE_F32, n_embed); + x = x.subview(ctx, emb.width).layer_norm(ctx, ln_out_weight, ln_out_bias); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 2, GGML_TYPE_F32, n_embed); - /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_vocab); + rwkv_future_graph_work(ctx, ffn_k.type, ffn_k.height, n_threads, tokens.width); - return ctx_size; + return head.mul_mat(ctx, x).view(ctx); } bool rwkv_build_sequence_graph( @@ -1014,7 +1267,12 @@ bool rwkv_build_sequence_graph( struct rwkv_layer_state * inputs, struct rwkv_layer_state * outputs, struct ggml_tensor * logits, - struct ggml_cgraph * cgraph + struct ggml_cgraph * cgraph, + + size_t * const pre_logits_nodes, + size_t * const pre_logits_leafs, + size_t * const post_logits_nodes, + size_t * const post_logits_leafs ) { const uint32_t n_embed = model.header.n_embed; const size_t sequence_len = tokens->ne[0]; @@ -1026,23 +1284,23 @@ bool rwkv_build_sequence_graph( struct rwkv_layer & layer = model.layers[i]; struct rwkv_layer_state state = inputs[i]; - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + struct ggml_tensor * x0 = x, * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx); struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v); ggml_build_forward_expand(cgraph, r); for (uint32_t t = 0; t < sequence_len; t++) { struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * xt = ggml_view_1d(ctx, xx, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); struct 
ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); ggml_build_forward_expand(cgraph, ggml_cpy(ctx, wkv, xt)); } - x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, xx))); + x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev))); x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); struct rwkv_layer_state & output = outputs[i]; @@ -1053,33 +1311,21 @@ bool rwkv_build_sequence_graph( ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp)); } + *pre_logits_nodes = cgraph->n_nodes; + *pre_logits_leafs = cgraph->n_leafs; + // x = self.layer_norm(x[-1,:], self.w.ln_out) x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_len - 1)), model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits)); - return true; -} - -size_t rwkv_estimate_graph_work(const enum ggml_type type, const size_t ffn_key_size, const uint32_t n_threads, const size_t sequence_len = 1) { + *post_logits_nodes = cgraph->n_nodes; + *post_logits_leafs = cgraph->n_leafs; - enum ggml_type mul_mat_type = ggml_is_quantized(type) ? GGML_TYPE_Q8_1 : type; - return rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key_size, sequence_len) * n_threads + 64 * (n_threads - 1)); + return true; } -struct rwkv_file { - FILE * file; - - rwkv_file(FILE * file): file(file) {} - - ~rwkv_file() { - if (file) { - fclose(file); - } - } -}; - void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { bool * ptr = ctx ? &ctx->print_errors : &global_print_errors; *ptr = print_errors; @@ -1096,6 +1342,18 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { return value; } +struct rwkv_file { + FILE * file; + + rwkv_file(FILE * file): file(file) {} + + ~rwkv_file() { + if (file) { + fclose(file); + } + } +}; + bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) { struct stat file_stat; struct rwkv_model model; @@ -1114,14 +1372,14 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst struct rwkv_tensor_header tensor_header; std::string name; - struct rwkv_ctx_size ctx_size; + struct rwkv_future_ctx future_ctx; while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file.file, tensor_header), "Invalid tensor header"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, tensor_header.size(), SEEK_CUR) == 0, "Failed to read tensor data"); - rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header); + future_ctx.alloc(rwkv_type_to_ggml[tensor_header.data_type], tensor_header.width, tensor_header.height); if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { ffn_key_size = tensor_header.height; @@ -1131,7 +1389,7 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key_size, "Model is missing parameter 
blocks.0.ffn.key.weight"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); - ctx = ctx_size; + ctx = future_ctx; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx.ctx, "Failed to allocate model context"); struct ggml_tensor * tensor; @@ -1170,25 +1428,31 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr inputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); // We collect parts of output state here. Each part is (n_embed) vector. - std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); for (size_t i = 0; i < n_layer; i++) { @@ -1209,19 +1473,52 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptrffn_key_size)); - /* work */ rwkv_ctx_size_add(graph_ctx_size, 1, rwkv_estimate_graph_work(rwkv_type_to_ggml[header.data_type], instance->ffn_key_size, n_threads)); + struct rwkv_future_ctx graph_future_ctx; + const struct rwkv_future_tensor future_token = graph_future_ctx.alloc(GGML_TYPE_I32, 1, 1, false); + + const struct rwkv_model & model = instance->model; + const struct rwkv_layer & layer = model.layers[0]; + const struct rwkv_layer_state & state = inputs[0]; + struct rwkv_future_tensor ffn_xx = state.ffn_xx; + struct rwkv_future_tensor att_xx = state.att_xx; + struct rwkv_future_tensor att_aa = state.att_aa; + struct rwkv_future_tensor att_bb = state.att_bb; + struct rwkv_future_tensor att_pp = state.att_pp; + + const struct rwkv_future_tensor future_graph = rwkv_future_serial_graph(graph_future_ctx, future_token, n_threads, + model.emb, + model.ln0_weight, model.ln0_bias, + + n_layer, + layer.ln1_weight, layer.ln1_bias, + layer.att_time_mix_k, layer.att_time_mix_v, layer.att_time_mix_r, + layer.att_time_first, layer.att_time_decay, + layer.att_receptance, layer.att_key, layer.att_value, layer.att_output, + att_xx, att_aa, att_bb, att_pp, + + layer.ln2_weight, layer.ln2_bias, + layer.ffn_time_mix_k, layer.ffn_time_mix_r, + layer.ffn_key, layer.ffn_value, layer.ffn_receptance, + ffn_xx, + + model.ln_out_weight, model.ln_out_weight, + model.head + ); struct rwkv_graph serial_graph; - serial_graph.ctx = graph_ctx_size; + serial_graph.ctx = graph_future_ctx; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, serial_graph.ctx.ctx, "Failed to allocate serial graph context"); serial_graph.tokens = ggml_new_i32(serial_graph.ctx.ctx, 0); serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph"); serial_graph.cgraph->n_threads = n_threads; - RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(serial_graph.ctx.ctx, instance->model, serial_graph.tokens, inputs.get(), outputs.get(), logits, serial_graph.cgraph.get())); + + RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph( + serial_graph.ctx.ctx, instance->model, + serial_graph.tokens, inputs.get(), outputs.get(), logits, + serial_graph.cgraph.get(), + &serial_graph.pre_logits_nodes, &serial_graph.pre_logits_leafs, &serial_graph.post_logits_nodes, &serial_graph.post_logits_leafs + )); std::unique_ptr rwkv_ctx(new(std::nothrow) 
struct rwkv_context()); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx, "Failed to allocate rwkv_context"); @@ -1258,20 +1555,47 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32 return clone; } -bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) { +bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) { +#if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS) + printf("\nRWKV: Attempting offload of %u layers",n_layers); + const auto offload = [&](struct ggml_tensor * tensor) { + // TODO support multi-GPU + tensor->backend = GGML_BACKEND_GPU; + #if defined(GGML_USE_CLBLAST) + ggml_cl_transform_tensor(tensor->data, tensor); + #else + ggml_cuda_transform_tensor(tensor->data, tensor); + #endif + }; + + const size_t n_gpu = std::min(n_layers, ctx->instance->model.header.n_layer); + + if (ctx->gpu_layers < n_gpu) { + for (size_t & i = ctx->gpu_layers; i < n_gpu; i++) { + const struct rwkv_layer & layer = ctx->instance->model.layers[i]; + + // TODO also offload other operations to GPU with ggml_cuda_assign_buffers + offload(layer.att_key); + offload(layer.att_value); + offload(layer.att_receptance); + offload(layer.att_output); + + offload(layer.ffn_key); + offload(layer.ffn_value); + offload(layer.ffn_receptance); + } - return true; + return true; + } +#endif + return false; } void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) { if (state_in) { memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state)); } else { - ggml_set_f32(ctx->input_state, 0.0F); - - for (size_t i = 0; i < ctx->instance->model.header.n_layer; i++) { - ggml_set_f32(ctx->input_layers[i].att_pp, -1e30F); - } + rwkv_init_state(ctx, (float *) ctx->input_state->data); } } @@ -1285,23 +1609,33 @@ void rwkv_get_outputs(const struct rwkv_context * ctx, float * state_out, float } } -bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { - ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE; +bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { + ctx->last_error = RWKV_ERROR_NONE; const struct rwkv_file_header & header = ctx->instance->model.header; const size_t n_vocab = header.n_vocab; - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token (%" PRId32 ") is out of range (0 ..= %zu)", token, n_vocab - 1); + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token (%" PRId32 ") is out of range (0 .. 
%zu)", token, n_vocab - 1); rwkv_set_inputs(ctx, state_in); ggml_set_i32(ctx->serial_graph.tokens, token); + + // Short circuit computation of logits if nobody actually cares + if (!logits_out) { + ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.pre_logits_nodes; + ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.pre_logits_leafs; + } else { + ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.post_logits_nodes; + ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs; + } + ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get()); rwkv_get_outputs(ctx, state_out, logits_out); return true; } -bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * sequence, const size_t sequence_len, const float * state_in, float * state_out, float * logits_out) { - ((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE; +bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, const size_t sequence_len, const float * state_in, float * state_out, float * logits_out) { + ctx->last_error = RWKV_ERROR_NONE; const struct rwkv_file_header & header = ctx->instance->model.header; const size_t n_vocab = header.n_vocab; @@ -1311,34 +1645,78 @@ bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * sequen if (sequence) { for (size_t i = 0; i < sequence_len; i++) { const uint32_t token = sequence[i]; - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Tokens[%zu] (%" PRId32 ") is out of range (0 ..= %zu)", i, token, n_vocab - 1); + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token at index %zu (%" PRId32 ") is out of range (0 .. %zu)", i, token, n_vocab - 1); } } if (ctx->sequence_len != sequence_len) { // Build new sequence graph - struct rwkv_ctx_size ctx_size; - /* tokens */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, sequence_len); - /* graph */ rwkv_ctx_size_add(ctx_size, 1, rwkv_sequence_graph_size(n_vocab, n_embed, n_layer, ctx->instance->ffn_key_size, sequence_len)); - /* work */ rwkv_ctx_size_add(ctx_size, 1, rwkv_estimate_graph_work(rwkv_type_to_ggml[header.data_type], ctx->instance->ffn_key_size, 1, sequence_len)); - - struct rwkv_graph graph; - graph.ctx = ctx_size; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, graph.ctx.ctx, "Failed to allocate sequence graph context"); - graph.tokens = ggml_new_tensor_1d(graph.ctx.ctx, GGML_TYPE_I32, sequence_len); - graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, graph.cgraph, "Failed to allocate sequence graph"); - graph.cgraph->n_threads = 1; - RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(graph.ctx.ctx, ctx->instance->model, graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, graph.cgraph.get())); - - ((struct rwkv_context *) ctx)->sequence_len = sequence_len; - ((struct rwkv_context *) ctx)->sequence_graph = std::move(graph); + + struct rwkv_future_ctx graph_future_ctx; + const struct rwkv_future_tensor future_tokens = graph_future_ctx.alloc(GGML_TYPE_I32, sequence_len); + + const struct rwkv_model & model = ctx->instance->model; + const struct rwkv_layer & layer = model.layers[0]; + const struct rwkv_layer_state & state = ctx->input_layers[0]; + struct rwkv_future_tensor ffn_xx = state.ffn_xx; + struct rwkv_future_tensor att_xx = state.att_xx; + struct rwkv_future_tensor att_aa = state.att_aa; + struct rwkv_future_tensor att_bb = state.att_bb; + struct rwkv_future_tensor att_pp = state.att_pp; + 
+ const struct rwkv_future_tensor future_graph = rwkv_future_sequence_graph(graph_future_ctx, future_tokens, 1, + model.emb, + model.ln0_weight, model.ln0_bias, + + n_layer, + layer.ln1_weight, layer.ln1_bias, + layer.att_time_mix_k, layer.att_time_mix_v, layer.att_time_mix_r, + layer.att_time_first, layer.att_time_decay, + layer.att_receptance, layer.att_key, layer.att_value, layer.att_output, + att_xx, att_aa, att_bb, att_pp, + + layer.ln2_weight, layer.ln2_bias, + layer.ffn_time_mix_k, layer.ffn_time_mix_r, + layer.ffn_key, layer.ffn_value, layer.ffn_receptance, + ffn_xx, + + model.ln_out_weight, model.ln_out_weight, + model.head + ); + + struct rwkv_graph sequence_graph; + sequence_graph.ctx = graph_future_ctx; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, sequence_graph.ctx.ctx, "Failed to allocate sequence graph context"); + sequence_graph.tokens = ggml_new_tensor_1d(sequence_graph.ctx.ctx, GGML_TYPE_I32, sequence_len); + sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph"); + sequence_graph.cgraph->n_threads = 1; + + RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph( + sequence_graph.ctx.ctx, ctx->instance->model, + sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, + sequence_graph.cgraph.get(), + &sequence_graph.pre_logits_nodes, &sequence_graph.pre_logits_leafs, &sequence_graph.post_logits_nodes, &sequence_graph.post_logits_leafs + )); + + ctx->sequence_len = sequence_len; + ctx->sequence_graph = std::move(sequence_graph); } // Allow building the sequence graph without actually evaluating, by specifying sequence = NULL. if (sequence) { rwkv_set_inputs(ctx, state_in); memcpy(ctx->sequence_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t)); + + // Short circuit computation of logits if nobody actually cares + if (!logits_out) { + ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.pre_logits_nodes; + ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.pre_logits_leafs; + } else { + ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.post_logits_nodes; + ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs; + } + ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get()); rwkv_get_outputs(ctx, state_out, logits_out); } @@ -1346,12 +1724,52 @@ bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * sequen return true; } -uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed; +// Provided for compatibility. +extern "C" RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { + return rwkv_get_state_len(ctx); +} + +// Provided for compatibility. 
+extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { + return rwkv_get_logits_len(ctx); +} + +size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_vocab; +} + +size_t rwkv_get_n_embed(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_embed; +} + +size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_layer; +} + +size_t rwkv_get_state_len(const struct rwkv_context * ctx) { + const struct rwkv_file_header & header = ctx->instance->model.header; + return (size_t) header.n_embed * 5 * (size_t) header.n_layer; } -uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { - return ctx->instance->model.header.n_vocab; +size_t rwkv_get_logits_len(const struct rwkv_context * ctx) { + return (size_t) ctx->instance->model.header.n_vocab; +} + +void rwkv_init_state(const struct rwkv_context * ctx, float * state) { + const struct rwkv_file_header & header = ctx->instance->model.header; + const size_t layer_size = (size_t) header.n_embed * 5; + const size_t layer_zero = (size_t) header.n_embed * 4; + const size_t layers_size = (size_t) header.n_layer * layer_size; + + for (size_t start = 0; start < layers_size; start += layer_size) { + for (size_t i = 0; i < layer_zero; i++) { + state[start + i] = 0.0F; + } + + for (size_t i = layer_zero; i < layer_size; i++) { + state[start + i] = -1e30F; + } + } } void rwkv_free(struct rwkv_context * ctx) { @@ -1381,7 +1799,12 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(in_file.file, in_header), "Invalid file header"); enum ggml_type in_type = rwkv_type_to_ggml[in_header.data_type]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, "Unsupported input data type (%s); needs to be f32 or f16", rwkv_type_to_string[rwkv_type_from_ggml[in_type]]); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_FILE, + in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, + "Unsupported input data type (%s); needs to be FP32 or FP16", + rwkv_type_to_string[rwkv_type_from_ggml[in_type]] + ); struct rwkv_file_header out_header = in_header; out_header.version = RWKV_FILE_VERSION; @@ -1392,7 +1815,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const size_t orig_total_size = 0; size_t new_total_size = 0; - // Required to init the fp16 tables + // Required to init the F16 tables // Doesn't crash if ggml_init fails ggml_free(ggml_init({ 0, NULL, true })); @@ -1404,26 +1827,26 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const struct rwkv_tensor_header header; RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, rwkv_fread_tensor_header_and_skip(in_file.file, header)); - size_t in_size = rwkv_tensor_size(header); + size_t in_size = header.size(); if (in_size > max_in_size) { max_in_size = in_size; } // f16 type tensors get relocated to out and then converted into f32 at in - if (header.data_type == TYPE_F16) { + if (header.data_type == TYPE_FP16) { if (in_size > max_out_size) { max_out_size = in_size; } - size_t f32_size = rwkv_tensor_size(GGML_TYPE_F32, header.width, header.height); + size_t f32_size = rwkv_future_tensor::size(GGML_TYPE_F32, header.width, header.height); if (f32_size > max_in_size) { max_in_size = f32_size; } } - size_t out_size = rwkv_tensor_size(out_type, header.width, header.height); + 
size_t out_size = rwkv_future_tensor::size(out_type, header.width, header.height);

         if (out_size > max_out_size) {
             max_out_size = out_size;
@@ -1439,7 +1862,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const
     // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong!
     int64_t hist_all[16] {};

-    std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [max_in_size + max_out_size]);
+    std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t[max_in_size + max_out_size]);
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate buffer");

     uint8_t * in_buf = scratch.get();
@@ -1457,19 +1880,19 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const
         const char * name_str = name.c_str();
         RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]);

-        data = header.data_type == TYPE_F16 ? out_buf : in_buf;
-        size_t orig_size = rwkv_tensor_size(header), new_size = orig_size;
+        data = header.data_type == TYPE_FP16 ? out_buf : in_buf;
+        size_t orig_size = header.size(), new_size = orig_size;

         RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str);

         // Quantize only 2D tensors, except embedding and head matrices.
         // Embedding and head take not too much space, especially in bigger models;
         // but they significantly increase perplexity when quantized.
-        if ((header.data_type == TYPE_F32 || header.data_type == TYPE_F16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") {
+        if ((header.data_type == TYPE_FP32 || header.data_type == TYPE_FP16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") {
             RWKV_MSG("quantizing... ");

             size_t nelements = (size_t) header.width * (size_t) header.height;

-            if (header.data_type == TYPE_F16) {
+            if (header.data_type == TYPE_FP16) {
                 ggml_fp16_to_fp32_row((const ggml_fp16_t *) out_buf, (float *) in_buf, nelements);
             }
diff --git a/otherarch/rwkv_v3.h b/otherarch/rwkv_v3.h
index efadaa7ae23f4..b24812fc2ff98 100644
--- a/otherarch/rwkv_v3.h
+++ b/otherarch/rwkv_v3.h
@@ -84,7 +84,7 @@ extern "C" {
     RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);

     // Loads the model from a file and prepares it for inference.
-    // Returns NULL on any error. Error messages would be printed to stderr.
+    // Returns NULL on any error.
     // - model_file_path: path to model file in ggml format.
     // - n_threads: count of threads to use, must be positive.
     RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
@@ -97,39 +97,64 @@ extern "C" {
     // - n_threads: count of threads to use, must be positive.
     RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);

-    // Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
-    // If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
-    RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);
+    // Offloads the specified number of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS (or CLBlast, when compiled with it).
+    // Returns true if at least one layer was offloaded.
+    // If rwkv.cpp was compiled without cuBLAS or CLBlast support, this function is a no-op and always returns false.
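+    // A minimal usage sketch (illustrative only; "model.bin" and the thread count are placeholder values,
+    // and rwkv_get_n_layer, declared below, is used to request offload of every layer):
+    //
+    //     struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 8);
+    //     if (ctx && !rwkv_gpu_offload_layers(ctx, (uint32_t) rwkv_get_n_layer(ctx))) {
+    //         // Nothing was offloaded: GPU support was not compiled in, or the layers were already offloaded.
+    //     }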
+    RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers);

     // Evaluates the model for a single token.
     // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
-    // Returns false on any error. Error messages would be printed to stderr.
+    // Returns false on any error.
+    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms for every
+    // iteration in which logits are not calculated.
     // - token: next token index, in range 0 <= token < n_vocab.
-    // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
-    // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
-    // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
-    RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
+    // - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
+    // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
+    // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
+    RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);

     // Evaluates the model for a sequence of tokens.
     // Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
     // Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
-    // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed. (Useful for initialization.)
     // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
-    // Returns false on any error. Error messages would be printed to stderr.
+    // Returns false on any error.
+    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms for every
+    // iteration in which logits are not calculated.
+    // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
     // - sequence_len: number of tokens to read from the array.
-    // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count, or NULL if this is a first pass.
-    // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
-    // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
-    RWKV_API bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
+    // - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
+    // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
+    // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
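+    // A minimal usage sketch (illustrative only; `ctx`, `prompt` and `prompt_len` are caller-provided
+    // placeholders, prompt_len > 1, error handling omitted). The prompt is fed without computing logits,
+    // and logits are requested only for the final token:
+    //
+    //     std::vector<float> state(rwkv_get_state_len(ctx));
+    //     std::vector<float> logits(rwkv_get_logits_len(ctx));
+    //     rwkv_eval_sequence(ctx, prompt, prompt_len - 1, NULL, state.data(), NULL);
+    //     rwkv_eval(ctx, prompt[prompt_len - 1], state.data(), state.data(), logits.data());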
+ RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out); + + // Returns the number of tokens in the given model's vocabulary. + // Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536). + RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx); + + // Returns the number of elements in the given model's embedding. + // Useful for reading individual fields of a model's hidden state. + RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx); + + // Returns the number of layers in the given model. + // Useful for always offloading the entire model to GPU. + RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx); + + // Returns the number of float elements in a complete state for the given model. + // This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state. + RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx); - // Returns count of FP32 elements in state buffer. - RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx); + // Returns the number of float elements in the logits output of a given model. + // This is currently always identical to n_vocab. + RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx); - // Returns count of FP32 elements in logits buffer. - RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx); + // Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL. + // Useful in cases where tracking the first call to these functions may be annoying or expensive. + // State must be initialized for behavior to be defined, passing a zeroed state to rwkv.cpp functions will result in NaNs. + // - state: FP32 buffer of size rwkv_get_state_len() to initialize + RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state); // Frees all allocated memory and the context. - // Does not need to be the same thread that created the rwkv_context. + // Does not need to be called on the same thread that created the rwkv_context. RWKV_API void rwkv_free(struct rwkv_context * ctx); // Quantizes FP32 or FP16 model to one of quantized formats.