Add select filter
turboderp committed Sep 30, 2023
1 parent dd6a302 commit 4ce5b66
Showing 8 changed files with 291 additions and 15 deletions.
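At a high level, this commit adds constrained generation to the generator stack: a filter base class (ExLlamaV2Filter), a select filter (ExLlamaV2SelectFilter) that restricts output to one of a fixed set of strings, and hooks in the sampler and generators (begin_filters, feed_filters and an end-of-sequence flag) that apply them. Below is a minimal sketch of how the new select filter might be used with generate_simple; the model path is a placeholder, and the exact setup calls and argument names may differ slightly at this commit.

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
from exllamav2.generator.filters import ExLlamaV2SelectFilter

# Load a model (path is hypothetical)
config = ExLlamaV2Config()
config.model_dir = "/path/to/model"
config.prepare()

model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Constrain the continuation to one of a few fixed strings
settings = ExLlamaV2Sampler.Settings()
settings.filters = [ExLlamaV2SelectFilter(model, tokenizer, ["Yes", "No", "Maybe"], case_insensitive = True)]

output = generator.generate_simple("Is water wet? Answer: ", settings, num_tokens = 10)
print(output)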
8 changes: 7 additions & 1 deletion exllamav2/generator/base.py
@@ -75,15 +75,21 @@ def generate_simple(self, prompt: str or list,

self._gen_begin_base(ids, mask)

# Begin filters

gen_settings.begin_filters(self.tokenizer.get_id_to_piece_list()[unhealed_token] if unhealed_token is not None else None)

# Generate tokens

for i in range(num_tokens):

logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, input_mask = mask).float().cpu()
token, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
token, _, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
gen_settings.feed_filters(token)

unhealed_token = None
if eos: break

# Decode

4 changes: 4 additions & 0 deletions exllamav2/generator/filters/__init__.py
@@ -0,0 +1,4 @@
from exllamav2.version import __version__

from exllamav2.generator.filters.base import ExLlamaV2Filter
from exllamav2.generator.filters.select import ExLlamaV2SelectFilter
39 changes: 39 additions & 0 deletions exllamav2/generator/filters/base.py
@@ -0,0 +1,39 @@
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Tokenizer
)

class ExLlamaV2Filter:

# Internal state

model: ExLlamaV2
tokenizer: ExLlamaV2Tokenizer
sequence_str: str


def __init__(self, model, tokenizer):

self.model = model
self.tokenizer = tokenizer
self.sequence_str = ""


def clone(self):

c = ExLlamaV2Filter(self.model, self.tokenizer)
c.sequence_str = self.sequence_str
return c


def begin(self, prefix_str):
pass


def feed(self, token):
pass


def next(self):
pass

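The base class above defines the filter contract: begin() resets the filter at the start of generation (optionally with a prefix string from token healing), feed() advances it by one sampled token, and next() returns two sets of token IDs — tokens the sampler may pick next, and tokens that would end the constrained span. As a rough illustration of that contract (not part of this commit), a hypothetical subclass that only allows digit tokens could look something like this:

from exllamav2.generator.filters.base import ExLlamaV2Filter

class DigitOnlyFilter(ExLlamaV2Filter):

    # Hypothetical example filter: only allow tokens whose piece consists of digits,
    # and allow the sequence to end once max_digits tokens have been produced

    def __init__(self, model, tokenizer, max_digits = 4):
        super().__init__(model, tokenizer)
        self.max_digits = max_digits
        self.num_fed = 0

    def clone(self):
        c = DigitOnlyFilter(self.model, self.tokenizer, self.max_digits)
        c.sequence_str = self.sequence_str
        c.num_fed = self.num_fed
        return c

    def begin(self, prefix_str = ""):
        self.sequence_str = ""
        self.num_fed = 0

    def feed(self, token):
        piece = self.tokenizer.get_id_to_piece_list()[int(token)]
        self.sequence_str += piece
        self.num_fed += 1

    def next(self):
        id_to_piece = self.tokenizer.get_id_to_piece_list()
        pass_tokens = set()
        end_tokens = set()
        for idx, piece in enumerate(id_to_piece):
            if piece and piece.isdigit():
                pass_tokens.add(idx)
                if self.num_fed + 1 >= self.max_digits:
                    end_tokens.add(idx)
        return pass_tokens, end_tokens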
107 changes: 107 additions & 0 deletions exllamav2/generator/filters/select.py
@@ -0,0 +1,107 @@
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Tokenizer
)

from exllamav2.generator.filters.base import ExLlamaV2Filter

class ExLlamaV2SelectFilter(ExLlamaV2Filter):

options: list
offset: int
prefix: str
case_insensitive: bool

def __init__(self, model, tokenizer, options, case_insensitive = False):
super().__init__(model, tokenizer)

self.options = options if not case_insensitive else [o.lower() for o in options]
self.case_insensitive = case_insensitive
self.offset = 0
self.prefix = ""
self.sequence_str_cmp = ""


def begin(self, prefix_str = ""):

self.sequence_str = ""
self.sequence_str_cmp = ""
self.prefix = prefix_str
self.offset = 0


def feed(self, token):

id_to_piece = self.tokenizer.get_id_to_piece_list()
piece = id_to_piece[token]
self.sequence_str += piece
if self.case_insensitive:
split = max(len(self.prefix) - self.offset, 0)
piece_l = piece[:split]
piece_r = piece[split:].lower()
self.sequence_str_cmp += piece_l + piece_r
else:
self.sequence_str_cmp += piece
self.offset += len(piece)


# TODO: Evaluate overhead and maybe move to extension

def next(self):

# prefix_to_ids = self.tokenizer.get_prefix_to_ids_dict()
id_to_piece = self.tokenizer.get_id_to_piece_list()
pass_str = set()
end_str = set()

char_trie = self.tokenizer.get_char_trie_ci() if self.case_insensitive else self.tokenizer.get_char_trie()
pass_tokens = set()
end_tokens = set()

for o in self.options:

option = (self.prefix + o)
if option[:self.offset] != self.sequence_str_cmp: continue

option = option[self.offset:]

if self.case_insensitive:
option_cased = option
option = option.lower()
else:
option_cased = None
if option_cased == option: option_cased = None

w = char_trie
while option != "":

c = option[0]
option = option[1:]

if c in w.children: w = w.children[c]
else: break

if len(w.leaf) > 0:

# Add tokens to pass set

if option_cased is None:
pass_tokens.update(w.leaf)
pass_str.update([id_to_piece[l] for l in w.leaf])
if option == "":
end_tokens.update(w.leaf)
end_str.update([id_to_piece[l] for l in w.leaf])

# Special case if prefix is cased but continuation is case-insensitive

else:
for l in list(w.leaf):
s = id_to_piece[l]
if option_cased.startswith(s):
pass_tokens.add(l)
pass_str.add(s)
if option == "":
end_tokens.add(l)
end_str.add(s)

return pass_tokens, end_tokens
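The select filter keeps the string generated so far (sequence_str_cmp, lower-cased past the prefix when case_insensitive is set) and, on each call to next(), walks the tokenizer's character trie along every still-matching option to collect the tokens that continue one of them (the pass set) and the tokens that complete one (the end set). Driving the filter by hand, assuming a loaded model and tokenizer as in the example above (the options and the feeding step are illustrative):

from exllamav2.generator.filters import ExLlamaV2SelectFilter

f = ExLlamaV2SelectFilter(model, tokenizer, ["Paris", "London", "Berlin"], case_insensitive = True)

f.begin(prefix_str = "")                 # reset state before generation
pass_tokens, end_tokens = f.next()       # token IDs that can start any option

# After the sampler picks one of the allowed tokens, feed it back so the
# filter can narrow down the remaining options for the following step:
# f.feed(sampled_token)
# pass_tokens, end_tokens = f.next()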
33 changes: 29 additions & 4 deletions exllamav2/generator/sampler.py
@@ -43,13 +43,24 @@ def disallow_tokens(self, tokenizer, tokens):
self.token_bias[tokens] = float("-inf")


def begin_filters(self, prefix_str = ""):

for f in self.filters: f.begin(prefix_str)


def feed_filters(self, feed_token):

for f in self.filters: f.feed(feed_token)


@staticmethod
def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor, random: float, tokenizer: ExLlamaV2Tokenizer, prefix_token = None):

batch_size, _, vocab_size = logits.shape

assert logits.shape[1] == 1, "Logits tensor is incorrect shape, must be (bsz, 1, vocab_size)"
assert prefix_token is None or prefix_token.shape == (batch_size, 1), "Prefix token list doesn't match batch shape"
assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"

logits = logits.clone().squeeze(1)
logit_filter = torch.ones((batch_size, vocab_size), dtype = torch.bool)
@@ -70,8 +81,18 @@ def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor,

# Evaluate filters

# for filter in settings.filters:
# pass
if len(settings.filters) > 0:

pass_tokens = None
end_tokens = None
for f in settings.filters:

pt, et = f.next()
pass_tokens = pt if pass_tokens is None else pass_tokens & pt
end_tokens = et if end_tokens is None else end_tokens | et

assert pass_tokens, "Filter excluded all tokens"
ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])

# Healing

@@ -85,7 +106,6 @@ def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor,

ext_c.logit_filter_exclusive(logit_filter, valid_token_lists)


# Sampling

batch_size = logits.shape[0]
@@ -102,7 +122,12 @@ def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor,
output_probs,
logit_filter)

return output_tokens, output_probs
# Stop condition from filters

end_filter = False
if len(settings.filters) > 0 and output_tokens[0].item() in end_tokens: end_filter = True

return output_tokens, output_probs, end_filter



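Within ExLlamaV2Sampler.sample(), each active filter contributes its pass and end sets; the pass sets are intersected (a candidate token has to be acceptable to every filter) and the end sets are unioned (any filter can end its constrained span), after which logit_filter_exclusive masks out all logits outside the combined pass set. The combination step in isolation, with made-up integer token IDs standing in for vocabulary indices:

# (pass_tokens, end_tokens) as two hypothetical filters might return them
filter_a = ({10, 11, 12, 13}, {13})
filter_b = ({11, 12, 13, 14}, {12, 14})

pass_tokens = None
end_tokens = None
for pt, et in (filter_a, filter_b):
    pass_tokens = pt if pass_tokens is None else pass_tokens & pt
    end_tokens = et if end_tokens is None else end_tokens | et

assert pass_tokens, "Filter excluded all tokens"
print(sorted(pass_tokens))   # [11, 12, 13] -> only these logits remain unmasked
print(sorted(end_tokens))    # [12, 13, 14] -> sampling any of these sets the eos flag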
39 changes: 34 additions & 5 deletions exllamav2/generator/streaming.py
@@ -24,6 +24,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
stop_tokens: list = []
no_tokens: torch.Tensor = None

first_token = False
heal_next_token = False

def __init__(self, model, cache, tokenizer):
@@ -72,21 +73,42 @@ def stream(self) -> (str, bool, torch.Tensor):
self.sequence_ids = self.sequence_ids[:, :-1]
self.cache.current_seq_len -= 1

# Start filters

if self.first_token:

self.settings.begin_filters(self.tokenizer.get_id_to_piece_list()[last_token])
self.first_token = False

# Regenerate the last token again, with prefix

healed_token = self._gen_single_token(self.settings, prefix_token = last_token)
new_tail = self.tokenizer.decode(self.sequence_ids[:, -(self.tail_decode_tokens):])[0]
healed_token, eos = self._gen_single_token(self.settings, prefix_token = last_token)
new_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
self.held_text += new_tail[len(old_tail):]

self.heal_next_token = False

# In case we only needed the healed token

if eos: return self.held_text, True, self.no_tokens

# Start filters when not healing

else:

if self.first_token:

self.settings.begin_filters()
self.first_token = False


# Decode the current tail end of the sequence

old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]

# Generate a single token and append to the sequence

next_token = self._gen_single_token(self.settings)
next_token, eos = self._gen_single_token(self.settings)

# End immediately if it was a stop token

@@ -99,6 +121,10 @@ def stream(self) -> (str, bool, torch.Tensor):
self.held_text += new_tail[len(old_tail):]
self.held_tokens = torch.cat([self.held_tokens, next_token], dim = -1)

# Return now if newly added token ends a filter

if eos: return self.held_text, True, self.held_tokens

# Hold text as long as it contains part of a stop string

partial_ss = False
@@ -137,6 +163,8 @@ def _gen_begin(self, in_tokens, gen_settings):
self.cache.current_seq_len = 0
self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True)

self.first_token = True


def _gen_begin_reuse(self, in_tokens, gen_settings):

@@ -173,6 +201,7 @@ def _gen_feed_tokens(self, in_tokens, gen_settings):
def _gen_single_token(self, gen_settings, prefix_token = None):

logits = self.model.forward(self.sequence_ids[:, -1:], self.cache).float().cpu()
token, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token)
token, _, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token)
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
return token
gen_settings.feed_filters(token)
return token, eos
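The streaming generator drives the same machinery incrementally: filters are started on the first generated token (with the healed token's piece as prefix when token healing is active), every sampled token is fed back through feed_filters, and stream() returns with the stop flag set as soon as a sampled token falls in a filter's end set. A rough sketch of a streaming loop constrained by the select filter, assuming the same model setup as in the first example; the begin_stream call and exact return values are assumptions and may differ slightly at this commit.

from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from exllamav2.generator.filters import ExLlamaV2SelectFilter

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.filters = [ExLlamaV2SelectFilter(model, tokenizer, ["accept", "reject"])]

ids = tokenizer.encode("Decision: ")
generator.begin_stream(ids, settings)

response = ""
while True:
    chunk, eos, _ = generator.stream()
    response += chunk
    if eos: break                        # set when a filter end token or stop condition is hit

print(response)                          # expected to be "accept" or "reject"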