try out the new rwkv but it seems worse, may revert
LostRuins committed Jul 1, 2023
1 parent 632bf27 commit e1a7042
Showing 4 changed files with 830 additions and 375 deletions.
25 changes: 16 additions & 9 deletions gpttype_adapter.cpp
@@ -431,6 +431,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     else //rwkv_2
     {
         rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
+
+        if(inputs.gpulayers>0)
+        {
+            rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
+        }
+
         const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
         const size_t n_vocab = header.n_vocab;
         printf("\nDetected Vocab: %d",n_vocab);
@@ -811,7 +817,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         params.top_k = 120; //to disable top_k we actually need to increase this value to a very high number
     }
-    if (params.seed <= 0)
+    if (params.seed <= 0 || params.seed==0xFFFFFFFF)
     {
         params.seed = time(NULL);
     }
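The added comparison treats 0xFFFFFFFF as a "randomize" sentinel alongside non-positive seeds; the same check is applied twice in otherarch/llama_v2.cpp below. A plausible motivation, which is my assumption rather than anything stated in the commit: a caller's -1 sentinel that round-trips through an unsigned 32-bit field arrives as 0xFFFFFFFF, and once the value is widened a signed "seed <= 0" test alone no longer catches it. A small self-contained illustration, with resolve_seed as a hypothetical helper:

#include <cstdint>
#include <cstdio>
#include <ctime>

// Hypothetical helper mirroring the patched checks: treat zero, negative
// values, and the unsigned -1 sentinel (0xFFFFFFFF) as "pick a random seed".
static int64_t resolve_seed(int64_t seed) {
    if (seed <= 0 || seed == 0xFFFFFFFF) {
        return (int64_t)time(NULL);
    }
    return seed;
}

int main() {
    uint32_t from_api = (uint32_t)-1; // -1 after an unsigned round-trip: 4294967295
    printf("%lld\n", (long long)resolve_seed(from_api)); // randomized
    printf("%lld\n", (long long)resolve_seed(1234));     // kept as-is
    return 0;
}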
@@ -1060,14 +1066,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else
         {
-            if(embd.size()>1)
-            {
-                evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
-            }
-            else
-            {
-                evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
-            }
+            // if(embd.size()>1)
+            // {
+            //     evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+            // }
+            // else
+            // {
+            bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
+            evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
+            //}

             memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
             rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
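This is the change the commit title hedges on: batched rwkv_eval_sequence is commented out (it "seems worse") in favor of per-token rwkv_eval, and while the prompt is still being consumed (before sampling starts, with more than two prompt tokens left), a nullptr logits buffer lets rwkv.cpp skip the output projection. A sketch of that prompt-feeding pattern, assuming the rwkv.cpp API of the time; feed_prompt is a hypothetical helper, and sizing the state/logits buffers via the library's length accessors is left to the caller:

#include "rwkv.h"
#include <cstdint>
#include <vector>

// Feed a prompt one token at a time, materializing logits only for the
// final token; intermediate calls pass nullptr so the output head can be
// skipped. A sketch, not the exact KoboldCpp control flow.
bool feed_prompt(struct rwkv_context * ctx,
                 const std::vector<uint32_t> & prompt,
                 std::vector<float> & state,    // caller-sized state buffer
                 std::vector<float> & logits) { // caller-sized logits buffer
    for (size_t i = 0; i < prompt.size(); ++i) {
        const bool last = (i + 1 == prompt.size());
        if (!rwkv_eval(ctx, prompt[i],
                       i == 0 ? nullptr : state.data(), // nullptr = fresh state
                       state.data(),                    // state handoff between calls
                       last ? logits.data() : nullptr)) {
            return false;
        }
    }
    return true;
}

Note that in the patched code the memcpy after the hunk still runs on every step, so skipped steps apparently copy stale logits; that looks harmless, since sampling only begins once logits are computed again.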
4 changes: 2 additions & 2 deletions otherarch/llama_v2.cpp
@@ -2204,7 +2204,7 @@ struct llama_v2_context * llama_v2_init_from_file(
 
     llama_v2_context * ctx = new llama_v2_context;
 
-    if (params.seed < 0) {
+    if (params.seed < 0 || params.seed==0xFFFFFFFF) {
         params.seed = time(NULL);
     }
 
@@ -2552,7 +2552,7 @@ int llama_v2_get_kv_cache_token_count(const struct llama_v2_context * ctx) {
 #define LLAMA_V2_MAX_RNG_STATE (64*1024)
 
 void llama_v2_set_rng_seed(struct llama_v2_context * ctx, int seed) {
-    if (seed < 0) {
+    if (seed < 0 || seed==0xFFFFFFFF) {
         seed = time(NULL);
     }
     ctx->rng.seed(seed);