Feature(MInference): support LLaMA-3-70B-1M and multi-gpu PP #59

Merged 1 commit on Aug 1, 2024
README.md: 2 changes (1 addition, 1 deletion)
@@ -62,7 +62,7 @@ get_support_models()

Currently, we support the following LLMs:
- LLaMA-3.1: [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
- - LLaMA-3: [gradientai/Llama-3-8B-Instruct-262k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-262k), [gradientai/Llama-3-8B-Instruct-Gradient-1048k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k), [gradientai/Llama-3-8B-Instruct-Gradient-4194k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-4194k)
+ - LLaMA-3: [gradientai/Llama-3-8B-Instruct-262k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-262k), [gradientai/Llama-3-8B-Instruct-Gradient-1048k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k), [gradientai/Llama-3-8B-Instruct-Gradient-4194k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-4194k), [gradientai/Llama-3-70B-Instruct-Gradient-262k](https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-262k), [gradientai/Llama-3-70B-Instruct-Gradient-1048k](https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k)
- GLM-4: [THUDM/glm-4-9b-chat-1m](https://huggingface.co/THUDM/glm-4-9b-chat-1m)
- Yi: [01-ai/Yi-9B-200K](https://huggingface.co/01-ai/Yi-9B-200K)
- Phi-3: [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)
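For context, the two new 70B checkpoints are used the same way as the existing entries. Below is a minimal sketch following the usage pattern shown earlier in this README; the loading kwargs and the assumption that `get_support_models()` returns the model names listed above are illustrative rather than definitive:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference, get_support_models

model_name = "gradientai/Llama-3-70B-Instruct-Gradient-1048k"
print(get_support_models())  # assumed to return the supported model names listed above

# 70B weights will not fit on one GPU; device_map="auto" shards them across
# visible devices (see the run_infinitebench.py change below).
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Patch the attention modules with MInference's sparse-attention path.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)
```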
experiments/infinite_bench/run_infinitebench.py: 2 changes (1 addition, 1 deletion)
@@ -166,7 +166,7 @@ def load_model(
model_name,
config=config,
torch_dtype="auto",
device_map="cuda",
device_map="auto",
resume_download=None,
trust_remote_code=trust_remote_code,
)
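For background, `device_map="auto"` hands weight placement to Accelerate, which splits the layers of a model that cannot fit on one GPU across all visible devices (the pipeline-style placement this PR targets), whereas `device_map="cuda"` forces everything onto a single GPU. A small, self-contained sketch of what the benchmark script now does; the model name is illustrative:

```python
from transformers import AutoConfig, AutoModelForCausalLM

name = "gradientai/Llama-3-70B-Instruct-Gradient-1048k"  # illustrative
config = AutoConfig.from_pretrained(name, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype="auto",
    device_map="auto",  # let Accelerate shard layers across all visible GPUs
    trust_remote_code=True,
)

# Inspect where each block landed, e.g. {"model.layers.0": 0, ..., "lm_head": 3}
print(model.hf_device_map)
```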

Large diffs are not rendered by default.

minference/configs/model2path.py: 6 changes (6 additions, 0 deletions)
@@ -29,6 +29,12 @@
"meta-llama/Meta-Llama-3.1-8B-Instruct": os.path.join(
BASE_DIR, "Llama_3.1_8B_Instruct_128k_kv_out_v32_fit_o_best_pattern.json"
),
"gradientai/Llama-3-70B-Instruct-Gradient-262k": os.path.join(
BASE_DIR, "Llama_3_70B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json"
),
"gradientai/Llama-3-70B-Instruct-Gradient-1048k": os.path.join(
BASE_DIR, "Llama_3_70B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json"
),
}
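For orientation, these entries map Hugging Face model IDs to the offline-searched sparse-attention pattern files shipped with MInference; note that both new 70B entries point at the same 262k pattern JSON. A small lookup sketch follows; the dictionary name `MODEL2PATH` is assumed from this module, and the helper itself is illustrative rather than library code:

```python
from minference.configs.model2path import MODEL2PATH  # name assumed from this module

def best_pattern_path(model_name: str) -> str:
    """Return the pre-searched head-pattern JSON for a supported model (illustrative)."""
    try:
        return MODEL2PATH[model_name]
    except KeyError:
        raise KeyError(
            f"{model_name} has no pre-searched pattern; run the pattern search first"
        ) from None

print(best_pattern_path("gradientai/Llama-3-70B-Instruct-Gradient-1048k"))
# -> .../Llama_3_70B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json
```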


minference/modules/inf_llm.py: 35 changes (22 additions, 13 deletions)
@@ -149,6 +149,7 @@ def append(self, tensor: torch.Tensor):
self.append_cache()

self.data[self.length : self.length + append_l, ...].copy_(tensor)
+ self.data = self.data.to(tensor.device)

self.length += append_l

@@ -567,9 +568,14 @@ def _append(self, local_q, local_k, local_v, global_q):

# calc local result first to overlap host-device communication
attn = self.Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
- attn.append(
-     local_h_q, local_h_k, local_h_v, get_score=True, sliding_window=self.n_local
- )
+ with torch.cuda.device(local_h_k.device):
+     attn.append(
+         local_h_q,
+         local_h_k,
+         local_h_v,
+         get_score=True,
+         sliding_window=self.n_local,
+     )

# calc topk global repr k and load cache
with torch.cuda.stream(GLOBAL_STREAM):
@@ -612,15 +618,16 @@ def _append(self, local_q, local_k, local_v, global_q):
torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

# calc global result
- attn.append(
-     global_h_q,
-     global_h_k,
-     global_h_v,
-     end=True,
-     get_score=self.calc_block_score,
-     sliding_window=global_sliding_window,
-     complement_sliding_window=True,
- )
+ with torch.cuda.device(global_h_q.device):
+     attn.append(
+         global_h_q,
+         global_h_k,
+         global_h_v,
+         end=True,
+         get_score=self.calc_block_score,
+         sliding_window=global_sliding_window,
+         complement_sliding_window=True,
+     )

o, score_list = attn.get_result()
loc_score = score_list[0]
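For background, `torch.cuda.device(dev)` is a context manager that makes `dev` the current CUDA device, so kernels, streams, and the `wait_stream` synchronization above are issued against the GPU that actually holds the activations instead of the default `cuda:0`. A tiny illustration (assumes at least two visible GPUs):

```python
import torch

def matmul_on_owning_gpu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Make the tensors' GPU current so any streams or workspace allocations
    # created inside belong to that GPU rather than cuda:0.
    with torch.cuda.device(a.device):
        return a @ b

x = torch.randn(128, 128, device="cuda:1")
y = torch.randn(128, 128, device="cuda:1")
out = matmul_on_owning_gpu(x, y)  # runs entirely on cuda:1
```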
@@ -1238,7 +1245,9 @@ def _decode(
if word.item() in end_token_ids or i == max_length:
break

- input_ids = torch.cat((input_ids, word.view(1, 1)), dim=-1)
+ input_ids = torch.cat(
+     (input_ids, word.view(1, 1).to(input_ids.device)), dim=-1
+ )
attention_mask = torch.cat(
(
attention_mask,
minference/modules/minference_forward.py: 17 changes (9 additions, 8 deletions)
@@ -187,10 +187,10 @@ def block_sparse(topk_ratio, slash_size=None):
attention_mask = torch.full((q_len, q_len), torch.finfo(q.dtype).min, device="cuda")
mask_cond = torch.arange(attention_mask.size(-1), device="cuda")
attention_mask.masked_fill_(mask_cond < (mask_cond + 1).view(attention_mask.size(-1), 1), 0)
- attention_mask = attention_mask[None, None, :]
+ attention_mask = attention_mask[None, None, :].to(q.device)
SEARCH_MASK = attention_mask
else:
- attention_mask = SEARCH_MASK
+ attention_mask = SEARCH_MASK.to(q.device)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
best_s, best_v, best_score, best_ty = 0, 0, 0, ""
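The two edits above only move the cached mask onto `q`'s device; the mask construction itself is unchanged. A self-contained sketch of the cache-once, move-on-use pattern (the module-level name is illustrative, mirroring `SEARCH_MASK`):

```python
import torch

_MASK_CACHE = None  # illustrative module-level cache, in the spirit of SEARCH_MASK

def causal_additive_mask(q: torch.Tensor) -> torch.Tensor:
    """Build an additive causal mask once, then reuse it on q's device."""
    global _MASK_CACHE
    q_len = q.shape[-2]
    if _MASK_CACHE is None or _MASK_CACHE.shape[-1] != q_len:
        mask = torch.full((q_len, q_len), torch.finfo(q.dtype).min, device=q.device)
        cond = torch.arange(q_len, device=q.device)
        mask.masked_fill_(cond < (cond + 1).view(q_len, 1), 0)  # zero where j <= i
        _MASK_CACHE = mask[None, None, :]
    # Under multi-GPU placement the cached mask may sit on another GPU, so
    # always return it on the query's device.
    return _MASK_CACHE.to(q.device)
```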
@@ -531,7 +531,7 @@ def forward(
if self.is_search:
if os.path.exists(self.config_path):
config_list = json.load(open(self.config_path))
- if self.layer_idx < len(config_list):
+ if self.config.num_hidden_layers == len(config_list):
assert False, f"Search completed. The config is located in {self.config_path}."
else:
config_list = []
@@ -543,7 +543,7 @@
q = query_states[:, head, :, :].unsqueeze(1)
k = key_states[:, head, :, :].unsqueeze(1)
v = value_states[:, head, :, :].unsqueeze(1)
- if self.is_search:
+ if self.is_search and self.layer_idx >= len(config_list):
config[head] = search_pattern(q, k, head)
if self.layer_idx >= self.starting_layer and not self.is_search:
attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
@@ -553,13 +553,14 @@
attn_output = gather_qkv(q, k, v, attention_mask)
output[:, head:head + 1] = attn_output
if self.is_search:
- config_list.append(config)
+ if len(config):
+     config_list.append(config)
with open(self.config_path, 'w') as json_file:
json.dump(config_list, json_file)
else:
output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, query_states.size(1), q_len, self.head_dim)
attn_output = output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
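Taken together, these edits make the offline head-pattern search resumable: the run now aborts only when a pattern exists for every layer, skips layers already present in `config_list`, and appends a layer's result only if that layer was actually searched. A small illustrative helper capturing the same bookkeeping (not MInference's own code):

```python
import json
import os

def load_resumable_search(config_path: str, num_hidden_layers: int) -> list:
    """Load any partially completed per-layer pattern search (illustrative)."""
    config_list = []
    if os.path.exists(config_path):
        with open(config_path) as f:
            config_list = json.load(f)
    if len(config_list) == num_hidden_layers:
        raise RuntimeError(f"Search completed. The config is located in {config_path}.")
    # Layers with index < len(config_list) were searched in a previous run and
    # are skipped; the search resumes at layer len(config_list).
    return config_list
```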
@@ -741,7 +742,7 @@ def forward(
output[:, head:head + 1] = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)

attn_output = output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

@@ -757,7 +758,7 @@
attn_output = gather_qkv(q, k, v, attention_mask)
output[:, head:head + 1] = attn_output
attn_output = output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
minference/patch.py: 6 changes (6 additions, 0 deletions)
@@ -75,6 +75,9 @@ def apply_rotary_pos_emb(self, x, length, right, cos, sin):
cos = cos[:, :, right - length : right, :]
sin = sin[:, :, right - length : right, :]

+ if cos.device != x.device:
+     cos, sin = cos.to(x.device), sin.to(x.device)

return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

def _update_cos_sin_tables(self, x, seq_dim):
@@ -144,6 +147,9 @@ def apply_rotary_pos_emb_one_angle(self, x: torch.Tensor, index):
cos = cos[:, :, index - 1 : index, :]
sin = sin[:, :, index - 1 : index, :]

+ if cos.device != x.device:
+     cos, sin = cos.to(x.device), sin.to(x.device)

return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

def forward(