
Using Algorithm 2 in SpecInfer paper, I get wrong outputs. #1302

Open
dutsc opened this issue Feb 19, 2024 · 3 comments
Labels: bug (Something isn't working), inference (Features and fixes related to the inference project.)

dutsc commented Feb 19, 2024

I'm having trouble understanding the SpecInfer source code.

The pseudocode from the SpecInfer paper for the verify model checking the draft model's output is as follows:
(image: Algorithm 2 pseudocode from the SpecInfer paper)

I implemented this pseudocode in Python, but the output I get is garbled and does not look like a correct completion.

my prompt: Could you please give me some advice on programming?
my generated text: Without knowing what you want to develop I cannot help you very much. You have to study the specific topics you want to learn. I assume you want to learn programming thinking about programs. I'm very vague about that since I work more in scientific computing than in programming.
no probsdpocket weed and we got stoned now i'm tripping balls high as hell now that's WHO I

I set max_length=100 and a draft model lookahead of 4 inference steps; the verify model is opt-6.7b and the draft model is opt-125m.
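For context, the overall draft-then-verify loop described above can be sketched as follows. This is a toy sketch only: `draft_step` and `verify_step` are hypothetical stand-ins for the opt-125m and opt-6.7b model calls, not FlexFlow APIs, and the "models" here just emit random token ids.

```python
import random

def draft_step(prefix, lookahead=4):
    # stand-in for the opt-125m draft model: propose `lookahead` tokens
    return [random.randrange(100) for _ in range(lookahead)]

def verify_step(prefix, proposal):
    # stand-in for the opt-6.7b verify model: accept a (possibly empty)
    # prefix of the proposal, then always emit one corrected/bonus token,
    # so every iteration commits at least one token
    accepted = random.randrange(len(proposal) + 1)
    return proposal[:accepted] + [random.randrange(100)]

def speculative_decode(prompt_ids, max_length=100, lookahead=4):
    out = list(prompt_ids)
    while len(out) < max_length:
        proposal = draft_step(out, lookahead)   # SSM drafts lookahead tokens
        verified = verify_step(out, proposal)   # LLM verifies; >= 1 token kept
        out.extend(verified)
    return out[:max_length]
```

Because the verifier always commits at least one token per iteration, the loop is guaranteed to terminate at `max_length`.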

I hoped to resolve this by referring to the SpecInfer source code, but it is difficult for me to follow. I guess the part where the verify model checks the draft model's output is in the prepare_next_batch_init and traverse_verify_tree functions in request_manager.cc, but I can't quite understand them.

Here is the above pseudocode implemented in Python:

import random
import torch

@torch.no_grad()  # the bare torch.no_grad() call in the original was a no-op
def verify_stochastic(root, llm_logits, temperature=0):  # temperature is unused here
    # NOTE: despite the names, both `llm_logits` and each node's SSM `logits`
    # must be probabilities (post-softmax); the acceptance ratio and the
    # residual distribution below are only valid on normalized distributions.
    V = []  # list of verified tokens
    u = root  # start with the root node
    # index_mapping maps each tree node to its position in the flattened
    # sequence that the LLM scored
    index_mapping = {root: -1}
    create_index_mapping(root, index_mapping)
    # while u is a non-leaf node do
    while len(u.children) > 0:
        # H = child(u)
        H = list(u.children.values())
        # while H is not empty do
        while len(H) > 0:
            # s ~ rand(H): pick a child uniformly at random
            s = random.choice(H)
            # r ~ U(0, 1)
            r = random.random()
            # x_s: token of the node
            x_s = s.token_logits_pair.token
            # if r <= P(x_s | u, LLM) / P(x_s | u, SSMs) then accept
            ssmp_s = s.token_logits_pair.logits[:, x_s].item() + 1e-9
            llmp_s = llm_logits[:, index_mapping[s.parent] + 1, x_s].item() + 1e-9
            if r <= llmp_s / ssmp_s:
                V.append(x_s)
                u = s
                break
            else:
                # P(x | u, LLM) := norm(max(0, P(x | u, LLM) - P(x | u, SSMs)))
                llmp = llm_logits[:, index_mapping[s.parent] + 1, :]
                ssmp = s.token_logits_pair.logits[:, :]
                new_dist = torch.clamp(llmp - ssmp, min=0)
                new_dist = new_dist / (new_dist.sum(dim=-1, keepdim=True) + 1e-9)
                llm_logits[:, index_mapping[s.parent] + 1, :] = new_dist
                H.remove(s)
        # if H is empty, every child of u was rejected: stop descending
        if len(H) == 0:
            break
    # x_next ~ P(x | u, LLM): sample the bonus token from the distribution at
    # u's own position (indexing via s.parent here is wrong when u is a leaf
    # reached by acceptance, since s.parent is then u's parent, not u)
    llmp = llm_logits[:, index_mapping[u] + 1, :]
    x_next = torch.multinomial(llmp, num_samples=1)
    V.append(x_next)
    return V

Here is the definition of TreeNode:

class TokenLogitsPair:
    def __init__(self, token, logits):
        self.token = token
        self.logits = logits
    def to(self, device, non_blocking=False):
        self.token = self.token.to(device, non_blocking=non_blocking)
        self.logits = self.logits.to(device, non_blocking=non_blocking)
        return self

class TreeNode:
    def __init__(self, token_logits_pair=None, parent=None):
        self.token_logits_pair = token_logits_pair
        self.parent = parent
        self.children = {}

I hope someone can help me.

@jiazhihao jiazhihao added the bug Something isn't working label Feb 19, 2024
@jiazhihao jiazhihao added the inference Features and fixes related to the inference project. label Feb 19, 2024

dutsc commented Feb 20, 2024

I checked the traverse_verify_tree() function at /FlexFlow/src/runtime/request_manager.cc and found that it only checks whether tokens are equal. Does this mean that the default implementation of SpecInfer is the VERIFYGREEDY function of Algorithm 2 in the paper?

Algorithm 2 in the paper:
(image: Algorithm 2 from the SpecInfer paper)

This pseudocode finds a path from the root to a leaf in the token tree along which every draft token matches the verify model's output.

traverse_verify_tree() code snippet:

for (int i = 0; i < outputSerializedTree.size(); i++) {
    auto input = inputSerializedTree.at(i);
    auto output = outputSerializedTree.at(i);

    if (i == 0) {
      verifiedTree.push_back(output); 
      new_committed_tokens.push_back(std::make_pair(
          input.second,
          committed_tokens.at(guid).at(i).second)); // <input_abs_depth,
                                                    // input_index_in_batch>
      // std::cout << committed_tokens.at(guid).at(i).first << ", "
      //           << committed_tokens.at(guid).at(i).second << std::endl;
      // std::cout << input.first << ", " << input.second << std::endl;

      assert(committed_tokens.at(guid).at(i).first == input.second);
      continue;
    }

    if (input.first == verifiedTree.back().first &&
        input.second == verifiedTree.back().second) {  //  input == verifiedTree.back()
      verifiedTree.push_back(output);
      new_committed_tokens.push_back(std::make_pair(
          input.second,
          committed_tokens.at(guid).at(i).second)); // <input_abs_depth,
                                                    // input_index_in_batch>
      assert(committed_tokens.at(guid).at(i).first == input.second);
    }
  }

traverse_verify_tree() is only about 100 lines in total; apart from the snippet above, it is mostly logging.

jiazhihao (Collaborator) commented:

The current implementation performs greedy decoding, and we are working on a PR for multi-step stochastic sampling and verification. Are the incorrect outputs generated using greedy decoding or stochastic?


dutsc commented Feb 21, 2024

The incorrect outputs are generated with stochastic decoding, following Algorithm 2 in the SpecInfer paper. When I use the greedy verify from Algorithm 2 instead, the same prompt produces the same result as SpecInfer:

my prompt: please introduce Kobe Bryant, who played basketball in NBA.

SpecInfer outputs:

I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or

my implementation greedy verify outputs:

I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or not, but Kobe Bryant is a basketball player.
I'm not sure if you're being sarcastic or
