[SOLUTION] How to run inference on Windows 10? #138

Closed · aleksusklim opened this issue Nov 29, 2023 · 11 comments

@aleksusklim commented Nov 29, 2023

UPD: the solution is down below

Is this even working on Windows?
I tried to follow your official guide, but pip failed to install the deepspeed requirement, because it has to be built from source.

I have Microsoft Build Tools, but still couldn't build it (the best I could get was an error about aio.lib).
Then I found this thread where somebody shared an already compiled WHL binary:
microsoft/DeepSpeed#2588 (comment)

The next error I got was from your SwissArmyTransformer, because it has import triton, but triton is only available for Linux.
I commented out all references to triton in sat's source, hoping that nothing from it would actually be needed.

But unfortunately, there are direct references to FastRotaryEmbedding from sat.model.position_embedding.triton_rotary_embeddings, and I assume there is no way to make it work without triton right away.

How many modifications does the code need? Or should I just wait for some quantized version of CogVLM to run with llama.cpp?
Like ggerganov/llama.cpp#4196

@1049451037 (Member)

There is an equivalent implementation of rotary embeddings in sat which does not depend on triton:

from sat.model.position_embedding.rotary_embeddings import RotaryEmbedding, apply_rotary_pos_emb_index_bhs

class RotaryMixin(BaseMixin):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.rotary_emb = RotaryEmbedding(
            hidden_size // num_heads,
            base=10000,
            precision=torch.half,
            learnable=False,
        )

    def attention_forward(self, hidden_states, mask, **kw_args):
        # excerpt: mixed_query_layer / mixed_key_layer / mixed_value_layer come from the
        # query_key_value projection of hidden_states in the surrounding attention code
        origin = self
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        cos, sin = origin.rotary_emb(value_layer, seq_len=kw_args['position_ids'].max()+1)
        query_layer, key_layer = apply_rotary_pos_emb_index_bhs(query_layer, key_layer, cos, sin, kw_args['position_ids'])

This piece of code is equivalent to:

from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding

class RotaryMixin(BaseMixin):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.rotary_emb = FastRotaryEmbedding(hidden_size // num_heads)

    def attention_forward(self, hidden_states, mask, **kw_args):
        origin = self
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        query_layer, key_layer = origin.rotary_emb(query_layer, key_layer, kw_args['position_ids'], max_seqlen=kw_args['position_ids'].max()+1, layer_id=kw_args['layer_id'])
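
(For context: either mixin is attached to the model definition in the usual sat way. A minimal sketch, assuming sat's BaseModel.add_mixin API; the model class and argument names below are only illustrative:)

from sat.model.base_model import BaseModel, BaseMixin

class MyModel(BaseModel):  # illustrative subclass, not CogVLM's actual model class
    def __init__(self, args, **kwargs):
        super().__init__(args, **kwargs)
        # register the triton-free RotaryMixin defined above
        self.add_mixin('rotary', RotaryMixin(args.hidden_size, args.num_attention_heads))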

@zRzRzRzRzRzRzR self-assigned this Dec 1, 2023
@SoYuCry commented Dec 5, 2023

sat probably won't install on Windows, right?

@zRzRzRzRzRzRzR (Member)

You can use the Huggingface version.

@aleksusklim (Author) commented Dec 12, 2023

Finally, I've got it to work with a 12 GB GPU on Windows!

But only the quantized Huggingface transformers version. Here is how:

Installation:

python -m venv venv
venv\Scripts\activate
pip install xformers==0.0.22.post7+cu118 torchvision==0.16.0+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install bitsandbytes==0.41.2.post2 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
pip install transformers==4.36.0 accelerate==0.25.0 gradio==3.41.0 sentencepiece==0.1.99 protobuf==4.23.4 einops==0.7.0

I've modified the official Gradio demo for the transformers version:
(here is a simple textbox that uses the internal Question: xxx? Answer: format, so the user has more control)

import gradio as gr
import os, sys

from transformers import LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM
from PIL import Image
import torch
import inspect

tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogvlm-chat-hf',
    load_in_4bit=True,
    trust_remote_code=True,
).eval()

def main():
    gr.close_all()
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=4.5):
                with gr.Group():
                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)

                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0, minimum=0, step=0.01, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.85, minimum=0, step=0.01, label='Top P')
                    top_k = gr.Slider(maximum=100, value=100, minimum=1, step=1, label='Top K')

                with gr.Row():
examples_text = gr.components.Textbox(lines=4, label='Examples', value='Question: Describe this image Answer:\nQuestion: How many people are there? Short answer:', interactive=False)  # read-only prompt format examples

            with gr.Column(scale=5.5):
                with gr.Row():
input_text = gr.components.Textbox(lines=10, label='Input Text', placeholder='Question: xxx? Answer:\n\n(separate turns with newlines; make sure there are no spaces after the last "Answer:" or "Short answer:" for VQA)')
                with gr.Row():
                        run_button = gr.Button('Generate',variant='primary')
                with gr.Row():
                    result_text = gr.components.Textbox(lines=4,label='Result Text', placeholder='')

        run_button.click(fn=post,inputs=[input_text, temperature, top_p, top_k, image_prompt],outputs=[result_text])

    demo.queue(concurrency_count=1)
    demo.launch()

def post(input_text, temperature, top_p, top_k, image_prompt):
    try:
        with torch.no_grad():
            image = Image.open(image_prompt).convert('RGB') if image_prompt is not None else None
            print(image_prompt)
            print(input_text)
            inputs = model.build_conversation_input_ids(tokenizer, query=input_text, history=[], images=([image] if image else None), template_version='base')
            inputs = {
                'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
                'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
                'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
                'images': [[inputs['images'][0].to('cuda').to(torch.float16)]] if image else None,
            }
            max_length = 2048
            do_sample = (top_p>0) and (top_k>1) and (temperature>0)
            gen_kwargs = {
                "max_length": max_length,
                "do_sample": do_sample
            }
            if do_sample:
                gen_kwargs['top_p'] = top_p
                gen_kwargs['top_k'] = top_k
                gen_kwargs['temperature'] = temperature
            outputs = model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            res = tokenizer.decode(outputs[0])
            print(res)
            return res
    except Exception as e:
        print(e)
        return str(e)

main()

Inference:

Save the above code to a .py file and just run it with Python, as a normal local Gradio web application.
(The first time, it will download around 32 GB of model data to C:\Users\User\.cache\huggingface\hub\models--THUDM--cogvlm-chat-hf)
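
For example (assuming you saved the script as demo.py, an arbitrary filename, and the venv from the installation step is active):

venv\Scripts\activate
python demo.py

Then open the local URL that Gradio prints to the console.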

@aleksusklim changed the title from "How to run inference on Windows 10?" to "[SOLUTION] How to run inference on Windows 10?" on Dec 12, 2023
@FurkanGozukara
@aleksusklim For the THUDM/cogagent-vqa-hf model, what changes do we need?

@aleksusklim (Author)
I haven't tried CogAgent yet; if I do, I'll reply here within a few days.

@aleksusklim (Author)
Running CogAgent with 12 GB of VRAM is not difficult either!

Here is how I did it, in a new venv, since this time you need Streamlit rather than Gradio:

Installation:

(Assuming Python 3.10 and Git for Windows)

  1. Open a console window in a new folder.
  2. Optionally, change the path where models will be downloaded. To do this, set HUGGINGFACE_HUB_CACHE to a folder on a disk with enough space (each of the Chat and VQA models takes 34 GB). For example, to use the H: drive, run
set "HUGGINGFACE_HUB_CACHE=H:\HF"
  3. Clone the repo and create a venv:
git clone https://github.com/THUDM/CogVLM
cd CogVLM
python -m venv venv
venv\Scripts\activate
  4. Now install the pip dependencies:
pip install xformers==0.0.23.post1+cu118 torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install bitsandbytes==0.41.2.post2 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
pip install transformers==4.36.2 sentencepiece==0.1.99 protobuf==4.25.2 einops==0.7.0 timm==0.9.12 accelerate==0.26.1 streamlit==1.30.0
  5. To start the model download, run the CLI demo for the Chat or the VQA model:
python basic_demo\cli_demo_hf.py --bf16 --quant 4 --from_pretrained THUDM/cogagent-chat-hf

python basic_demo\cli_demo_hf.py --bf16 --quant 4 --from_pretrained THUDM/cogagent-vqa-hf

If your video card does not support bfloat16, replace --bf16 with --fp16 (you can check support with torch.cuda.is_bf16_supported() in Python).
If you have trouble with bitsandbytes, remove the --quant 4 part.
If you get CUDA OOM errors, install the newest NVIDIA drivers and enable the Sysmem Fallback Policy.

  6. The web demo must be adjusted to work with only a single quantized model:

Open composite_demo\client.py in a text editor.
Locate the string "if you just use one model, use like this".
Change the commented-out code under it to this:

#if you just use one model, use like this
models_info = {
    'tokenizer': {
        'path': os.environ.get('TOKENIZER_PATH', 'lmsys/vicuna-7b-v1.5'),
    },
    'agent_chat': {
        'path': os.environ.get('MODEL_PATH_AGENT_CHAT', 'THUDM/cogagent-chat-hf'),
        'device': ['cuda:0']
    },
}

Then, locate model = AutoModelForCausalLM.from_pretrained a bit further below.
Replace it with this:

                    model = AutoModelForCausalLM.from_pretrained(
                        model_info['path'],
                        torch_dtype=torch_type,
                        low_cpu_mem_usage=True,
                        trust_remote_code=True,
                        load_in_4bit=True,
                    ).eval()

– I've added load_in_4bit=True and deleted .to(device) before .eval(); if the quantized version does not work for you, do not make this change.

Usage:

This is the command to run the main web demo:
streamlit run composite_demo\main.py
(Note that the executable is not "python" but "streamlit" from your venv.)

The first CogVLM tab will throw errors (unless you change client.py as I showed above).

You can save this .bat file to quickly run the demo from the initial folder:

@echo off
cd /d "%~dp0"
if exist .\CogVLM\ cd CogVLM
call venv\Scripts\activate
set "HUGGINGFACE_HUB_CACHE=H:\HF"
streamlit run composite_demo\main.py

Change the cache directory to match what you used during installation, or delete that line if you did not use it.
To run the CLI demo instead, replace the last line with:

python basic_demo\cli_demo_hf.py --bf16 --quant 4 --from_pretrained THUDM/cogagent-chat-hf

Notes:

  • I did not figure out how to use more than one image with the demo. It bugs out when I change the image for the third time in one dialogue.
  • I also did not figure out how to specify input grounding with the demo; output grounding works fine (ask the model to show the "bounding box" of something).
  • The VQA model feels a lot like CogVLM: it is censored, it tends to make mistakes, and it stands its ground instead of admitting a mistake.
  • The new Chat model is much better: it does not refuse to answer, and it fixes its mistakes in subsequent responses. Though it tends to hallucinate, and it has a strong bias towards describing pictures even when not asked to.
  • Sometimes the agent prints "Plan" and "Next action" instead of answering properly; probably this is some kind of desired format for the main use case that is not handled by the demo.

@FurkanGozukara
@aleksusklim What is your experience with this: is it the best model for captioning images for Stable Diffusion model training, like DALL-E 3?

@FurkanGozukara
@aleksusklim The code works with 4-bit loading, but not with 8-bit loading. Any ideas why?

#369

@aleksusklim (Author)
I tried to set load_in_8bit and it failed at https://huggingface.co/THUDM/cogagent-chat-hf/blob/main/cross_visual.py

if self.q_proj.weight.dtype == torch.uint8:
    import bitsandbytes as bnb
    q = bnb.matmul_4bit(x, self.q_proj.weight.t(), bias=self.q_bias, quant_state=self.q_proj.weight.quant_state)
    k = bnb.matmul_4bit(x, self.k_proj.weight.t(), bias=None, quant_state=self.k_proj.weight.quant_state)
    v = bnb.matmul_4bit(x, self.v_proj.weight.t(), bias=self.v_bias, quant_state=self.v_proj.weight.quant_state)
else:
    q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
    k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
    v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

– the condition was false, and the else branch was executed.
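
As a side note, here is a minimal sketch (assuming bitsandbytes >= 0.41 and a CUDA device) of why that uint8 check only matches the 4-bit path: a 4-bit Linear4bit layer stores its packed weights as torch.uint8, while an 8-bit Linear8bitLt layer stores them as torch.int8, so with load_in_8bit the code falls through to F.linear:

import bitsandbytes as bnb

lin4 = bnb.nn.Linear4bit(16, 16).cuda()                            # 4-bit quantized layer
lin8 = bnb.nn.Linear8bitLt(16, 16, has_fp16_weights=False).cuda()  # 8-bit (LLM.int8()) layer

print(lin4.weight.dtype)  # torch.uint8 -> matches the bnb.matmul_4bit branch
print(lin8.weight.dtype)  # torch.int8  -> falls into the F.linear else branch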

Does this code even support 8-bit?
As per https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig

load_in_8bit (bool, optional, defaults to False) — This flag is used to enable 8-bit quantization with LLM.int8().

Where LLM.int8() is explained at https://huggingface.co/blog/hf-bitsandbytes-integration

import torch.nn as nn
import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt

int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False)
)

It looks like this method is not used in CogAgent's code?

Here it is very important to add the flag has_fp16_weights. By default, this is set to True which is used to train in mixed Int8/FP16 precision. However, we are interested in memory efficient inference for which we need to use has_fp16_weights=False.

I don't see any has_fp16_weights there.
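
For reference, the standard way to request LLM.int8() through transformers is a BitsAndBytesConfig, roughly like the sketch below; for this model it still ends up in the same cross_visual.py else branch shown above (the 8-bit weights are torch.int8, not torch.uint8), so it is not expected to work here:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_8bit=True)  # LLM.int8() quantization request
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogagent-chat-hf',
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).eval()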

@FurkanGozukara
@aleksusklim Yeah, it isn't supported.

The model itself does not support 8-bit.
