[P1] Possible to do batch inference? #105

Open
thistleknot opened this issue Jun 8, 2024 · 3 comments
Labels
question Further information is requested


@thistleknot

I'm doing this at the moment:

import torch
import pyreft
from nltk.tokenize import sent_tokenize
from tqdm import tqdm

# tokenizer, reft_model, reft_config, device, prompt_no_input_template,
# first_n, last_n, share_weights and terminators are defined elsewhere (not shown)
quotes_fol_ = []

for q_ in tqdm(rando):
    quotes_fol = []
    quotes_nodes_edges = []
    sentences = sent_tokenize(q_)
    for q in sentences:
        # tokenize and prepare the input
        prompt = prompt_no_input_template % q
        prompt = tokenizer(prompt, return_tensors="pt").to(device)

        # Intervention positions for this single prompt:
        # [num_interventions, positions] -> [num_interventions, 1 (batch), positions]
        unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
            last_position=prompt["input_ids"].shape[-1],
            first_n=first_n,
            last_n=last_n,
            pad_mode="last",
            num_interventions=len(reft_config.representations),
            share_weights=share_weights
        )]).permute(1, 0, 2).tolist()

        # Generate with beam-search sampling (do_sample=True together with num_beams=5)
        _, reft_response = reft_model.generate(
            prompt,
            unit_locations={"sources->base": (None, unit_locations)},
            intervene_on_prompt=True,
            max_new_tokens=537,
            do_sample=True,
            top_k=50,
            temperature=0.7,
            num_beams=5,
            eos_token_id=terminators,
            early_stopping=True
        )
        response = tokenizer.decode(reft_response[0], skip_special_tokens=True)
        #out = lcel_chain.invoke({"input": response})
        quotes_fol.append(response)
        #quotes_nodes_edges.append(out)
    quotes_fol_.append(quotes_fol)
    #quotes_nodes_edges_.append(quotes_nodes_edges)

But I'd like to get rid of the per-sentence loop, and I'm not sure how to format unit_locations for a batch. Normally one would call model.generate(**inputs), but since pyreft wraps the model in its own class, I'm not sure that's supported (I haven't dug into the class for this specific feature).

Thought I'd ask first as well as for visibility for others who might be interested.
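One way to drop the nested loop is to flatten every sentence of every quote into a single list tagged with its quote index, generate over that flat list in batches, and regroup the outputs afterwards. A minimal sketch, assuming batched generation returns one decoded string per input in input order; flatten_sentences and regroup are hypothetical helpers, and the split argument stands in for nltk's sent_tokenize.

```python
from typing import Callable, List, Tuple


def flatten_sentences(
    quotes: List[str], split: Callable[[str], List[str]]
) -> List[Tuple[int, str]]:
    """Pair every sentence with the index of the quote it came from."""
    return [(qi, s) for qi, quote in enumerate(quotes) for s in split(quote)]


def regroup(
    pairs: List[Tuple[int, str]], outputs: List[str], n_quotes: int
) -> List[List[str]]:
    """Put per-sentence outputs back under their source quote, keeping order."""
    if len(pairs) != len(outputs):
        raise ValueError("expected exactly one output per sentence")
    grouped: List[List[str]] = [[] for _ in range(n_quotes)]
    for (qi, _), out in zip(pairs, outputs):
        grouped[qi].append(out)
    return grouped


if __name__ == "__main__":
    quotes = ["A b. C d.", "E f."]
    pairs = flatten_sentences(quotes, lambda q: q.split(". "))
    # Stand-in for batched generation: one output per sentence, same order
    outputs = [s.upper() for _, s in pairs]
    print(regroup(pairs, outputs, len(quotes)))
```

This prints [['A B', 'C D.'], ['E F.']], the same per-quote nesting the loop above builds in quotes_fol_.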

@frankaging frankaging changed the title Possible to do batch inference? [P1] Possible to do batch inference? Jun 9, 2024
@frankaging frankaging self-assigned this Jun 9, 2024
@frankaging frankaging added the question Further information is requested label Jun 9, 2024
@frankaging
Collaborator

@thistleknot Yes, it supports batched inference calls.

You can take a look at this function for batching:
https://github.com/stanfordnlp/pyreft/blob/main/examples/loreft/compute_metrics.py#L111

In a nutshell, you need to apply left padding to your tokenizer and calculate the batched intervention locations accordingly.
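In the single-prompt loop, wrapping the location list in one more list and calling .permute(1, 0, 2) turns [num_interventions][positions] into [num_interventions][1][positions], so the middle axis is the batch. A batched call presumably needs that axis to hold one entry per prompt. The sketch below does that reordering in plain Python; batch_unit_locations is a hypothetical helper, and it assumes every example has the same number of interventions and the same number of positions (a tensor would require this too).

```python
from typing import List

# One example's locations: [num_interventions][positions]
Locations = List[List[int]]


def batch_unit_locations(per_example: List[Locations]) -> List[List[List[int]]]:
    """Reorder [batch][intervention][position] into [intervention][batch][position].

    Plain-Python counterpart of
    torch.IntTensor(per_example).permute(1, 0, 2).tolist().
    """
    if not per_example:
        return []
    num_interventions = len(per_example[0])
    width = len(per_example[0][0]) if num_interventions else 0
    for ex in per_example:
        if len(ex) != num_interventions or any(len(p) != width for p in ex):
            raise ValueError("all examples need the same intervention/position shape")
    # Outer index: intervention; middle index: example in the batch
    return [[list(ex[k]) for ex in per_example] for k in range(num_interventions)]
```

The result would go where the single-prompt code already puts it: unit_locations={"sources->base": (None, batch_unit_locations(per_example))}.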

@thistleknot
Author

'calculate the batched intervention locations accordingly.'

That doesn't sound easy.

I'm not sure whether I can reuse the same -1 (last-token) position for each prompt as I did before, or whether the positions have to refer to where each prompt sits within the padded batch tensor.
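The positions are most likely indices into the padded batch tensor, so they depend on where each prompt sits in its row. With left padding every prompt ends at the last column, so the last-n positions are identical for all rows, while the first-n positions shift right by that row's padding count. A minimal sketch of that arithmetic, assuming a left-padded attention mask (0 for padding, 1 for real tokens) and prompts at least first_n + last_n tokens long; shifted_positions and left_pad_count are hypothetical helpers, not the code in the linked compute_metrics.py, and they do not reproduce pyreft's handling of short prompts.

```python
from typing import List


def left_pad_count(mask_row: List[int]) -> int:
    """Number of leading padding positions (mask == 0) in a left-padded row."""
    n = 0
    while n < len(mask_row) and mask_row[n] == 0:
        n += 1
    return n


def shifted_positions(
    attention_mask: List[List[int]], first_n: int, last_n: int
) -> List[List[int]]:
    """First-n and last-n token positions of each row, in padded-batch coordinates."""
    out = []
    for row in attention_mask:
        pad = left_pad_count(row)
        length = len(row) - pad  # real (unpadded) prompt length
        if length < first_n + last_n:
            raise ValueError("prompt shorter than first_n + last_n tokens")
        first = [pad + i for i in range(first_n)]         # shifted by the padding
        last = list(range(len(row) - last_n, len(row)))   # aligned at the end
        out.append(first + last)
    return out
```

For two rows of length 8, the first with 2 padding tokens, shifted_positions(mask, 2, 2) gives [[2, 3, 6, 7], [0, 1, 6, 7]]: the last positions match, the first ones do not. With share_weights=True, each row's list would presumably be repeated once per intervention before being reordered into the [intervention][batch][position] layout.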

@thistleknot
Author

thistleknot commented Jun 9, 2024

Are you able to help a brother out?


import random

from datasets import load_dataset

dataset = load_dataset("Abirate/english_quotes")
# Keep quotes between 24 and 139 characters long
quotes = [q for q in dataset['train']['quote'] if 23 < len(q) < 140]

# Strip curly quotation marks
cleaned_quotes = [q.replace('“', '').replace('”', '') for q in quotes]

# Sample 100 quotes (random.choices samples with replacement)
rando = random.choices(cleaned_quotes, k=100)

import pyreft
import transformers

# Define constants
max_tokens = 115                                 # maximum prompt length in tokens
desired_token_limit = 8192                       # rough token budget per batch
batch_size = desired_token_limit // max_tokens   # 71 prompts per batch

# Define the tokenizer with left-side padding
# (attn_implementation is a model argument, so it belongs in
# AutoModelForCausalLM.from_pretrained rather than in the tokenizer call)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=max_tokens,
    padding_side="left", use_fast=True,
)
tokenizer.pad_token = tokenizer.eos_token

# Position info about the interventions
share_weights = True # Whether the prefix and suffix interventions share weights.
positions = "f3+l3"  # Intervene on the first 3 prompt tokens (f3) and the last 3 prompt tokens (l3).
first_n, last_n = pyreft.parse_positions(positions)

terminators = [tokenizer.eos_token_id]

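def get_intervention_locations(last_position, first_n, last_n, pad_mode, num_interventions, share_weights):
    # Placeholder function for getting intervention locations, replace with actual logic.
    # NOTE: as written this returns `last_position` zeros per intervention, so every
    # intervention lands on token 0 instead of the first/last prompt tokens;
    # pyreft.get_intervention_locations (used in the single-prompt loop above)
    # computes the real positions.
    return [[0] * last_position for _ in range(num_interventions)]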

from nltk.tokenize import sent_tokenize

tokenized_prompts = []
# Preprocess: split each quote into sentences and tokenize each sentence
for q_, quote in enumerate(rando):
    for s_ in sent_tokenize(quote):
        original_prompt = prompt_no_input_template % s_
        last_position = len(tokenizer.encode(original_prompt))  # actual length before padding
        # Left-pads every prompt to max_tokens; with the default right-side truncation,
        # a prompt longer than max_tokens would lose its end
        tokenized_prompt = tokenizer(original_prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=max_tokens)
        tokenized_prompts.append((q_, tokenized_prompt.to(device), last_position))

# Incorrect batch attempt: generate responses in batches
import torch
from tqdm import tqdm

quotes_fol_ = []

for r in tqdm(range(0, len(tokenized_prompts), batch_size)):
    batch_prompts = tokenized_prompts[r: r + batch_size]

    # Every prompt was padded to max_tokens, so these stack into [batch, max_tokens]
    input_ids = torch.cat([bp[1]['input_ids'] for bp in batch_prompts], dim=0)
    attention_masks = torch.cat([bp[1]['attention_mask'] for bp in batch_prompts], dim=0)

    # NOTE: this builds a single location list and wraps it as a batch of one.
    # input_ids.shape[0] // len(batch_prompts) == 1, so .repeat() is a no-op and the
    # result is [num_interventions, 1, positions] rather than
    # [num_interventions, batch, positions]; per-row left padding is also ignored.
    unit_locations = torch.IntTensor([get_intervention_locations(
        last_position=max_tokens,
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=share_weights
    )]).repeat(input_ids.shape[0] // len(batch_prompts), 1, 1).permute(1, 0, 2).tolist()

    # Generate with beam-search sampling (do_sample=True together with num_beams=5)
    generation_args = {
        "base": {"input_ids": input_ids, "attention_mask": attention_masks},
        "unit_locations": {"sources->base": (None, unit_locations)},
        "intervene_on_prompt": True,
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "top_k": 50,
        "temperature": 0.7,
        "num_beams": 5,
        "eos_token_id": terminators,
        "early_stopping": True
    }

    _, reft_response = reft_model.generate(**generation_args)

    # For a decoder-only model the decoded sequences include the prompt tokens too
    responses = tokenizer.batch_decode(reft_response, skip_special_tokens=True)

    # Record one (quote index, response) pair per prompt. (Appending a shared,
    # growing list on every iteration would make every entry alias the same list.)
    for (original_index, _, _), response in zip(batch_prompts, responses):
        quotes_fol_.append([original_index, response])

# Output the final results
print(quotes_fol_)


That's what I have at the moment, but it's not applying the control vector correctly.
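Reading the attempt above, three things likely keep the intervention from landing where intended: the placeholder get_intervention_locations returns all-zero positions, the wrapping produces a batch axis of size 1 instead of one entry per prompt, and no row's left padding is taken into account. Beam search adds a further wrinkle: Hugging Face generate() expands each prompt into num_beams consecutive rows, and the linked compute_metrics.py appears to repeat the locations per beam to match. Assuming that holds for this setup too (not verified here), the expansion is a repeat-interleave along the batch axis; repeat_for_beams is a hypothetical plain-Python version.

```python
from typing import List


def repeat_for_beams(
    unit_locations: List[List[List[int]]], num_beams: int
) -> List[List[List[int]]]:
    """Repeat each example's positions num_beams times along the batch axis.

    Layout is [intervention][batch][position]. Example i becomes rows
    i * num_beams through i * num_beams + num_beams - 1, i.e. consecutive
    copies, which is the repeat_interleave-style expansion used for beams.
    """
    if num_beams < 1:
        raise ValueError("num_beams must be >= 1")
    return [
        # Copy each position list so the repeated rows do not alias one object
        [list(positions) for positions in per_intervention for _ in range(num_beams)]
        for per_intervention in unit_locations
    ]
```

In torch terms this matches torch.IntTensor(locs).repeat_interleave(num_beams, dim=1) on a [intervention][batch][position] tensor. Consecutive copies, rather than tiling the whole batch, keep each beam next to the prompt it came from.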
