
【Hackathon 7th】Fundable project 5. Support GOT-OCR-2.0 inference and training #885

Merged — 6 commits, Dec 17, 2024
439 changes: 439 additions & 0 deletions paddlemix/datasets/got_dataset.py

Large diffs are not rendered by default.

Binary file added paddlemix/demo_images/hospital.jpeg
53 changes: 53 additions & 0 deletions paddlemix/examples/GOT_OCR_2_0/README.md
@@ -0,0 +1,53 @@
# GOT-OCR2.0

## 1. Model Introduction

[GOT-OCR2.0](https://arxiv.org/abs/2409.01704) is a groundbreaking general-purpose OCR model designed to overcome the limitations of both traditional OCR systems (OCR-1.0) and current large vision-language models (LVLMs) on OCR tasks. This repository provides the Paddle implementation of `GOT-OCR2.0`.


## 2. Requirements
- **python >= 3.10**
- **paddlepaddle-gpu 3.0.0b2 or a develop build**
```bash
# example install of the develop build (CUDA 11.8 wheel)
python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
```

- **paddlenlp == 3.0.0b2**

> Note: flash_attn is enabled by default and requires an A100/A800 or H20 GPU. On a V100, run inference in float16.
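The hardware note above ties bfloat16 (and flash_attn) to A100/A800/H20 GPUs, with float16 on V100. A minimal sketch of that dtype choice, keyed on the GPU's compute-capability (SM) major version — the helper `pick_dtype` is ours, not part of this PR:

```python
def pick_dtype(sm_major: int) -> str:
    """Map a GPU compute-capability major version to an inference dtype.

    bfloat16 (and flash_attn) needs Ampere-class or newer hardware
    (SM >= 8, e.g. A100/A800/H20); Volta cards such as the V100 are
    SM 7 and should fall back to float16, matching the note above.
    """
    return "bfloat16" if sm_major >= 8 else "float16"


print(pick_dtype(8))  # A100 -> bfloat16
print(pick_dtype(7))  # V100 -> float16
```

The returned string can be passed to `paddle.set_default_dtype` or used to select the `dtype` argument of `from_pretrained`.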


## 3. Inference

### 3.1 Plain-text OCR
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \
    --model_name_or_path stepfun-ai/GOT-OCR2_0 \
    --image_file paddlemix/demo_images/hospital.jpeg \
    --ocr_type ocr
```

### 3.2 Formatted-text OCR
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py \
    --model_name_or_path stepfun-ai/GOT-OCR2_0 \
    --image_file paddlemix/demo_images/hospital.jpeg \
    --ocr_type format
```

## 4. Training
```bash
sh paddlemix/examples/GOT_OCR_2_0/run_train.sh
```


## References
```BibTeX
@article{wei2024general,
title={General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model},
author={Wei, Haoran and Liu, Chenglong and Chen, Jinyue and Wang, Jia and Kong, Lingyu and Xu, Yanming and Ge, Zheng and Zhao, Liang and Sun, Jianjian and Peng, Yuang and others},
journal={arXiv preprint arXiv:2409.01704},
year={2024}
}
```
6 changes: 6 additions & 0 deletions paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json
@@ -0,0 +1,6 @@
{
"synthdog_en": {
"images": "synthdog_en/",
"annotations": "synthdog_en/synthdog_en_29765_ocr_1k.json"
}
}
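`train_GOT.py` receives this file through `--meta_path` (see `run_train.sh` below); its actual loader is in the 439-line `got_dataset.py` not rendered here. A minimal sketch of reading such a meta file — the helper `load_meta` is ours, for illustration only:

```python
import json


def load_meta(meta_path: str) -> dict:
    """Parse a demo_dataset.json-style meta file.

    Each top-level key names a dataset and points at an image
    directory plus an annotation JSON, as in the config above.
    """
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    return {name: (spec["images"], spec["annotations"])
            for name, spec in meta.items()}
```

This layout lets several datasets be mixed in one training run by adding more top-level keys.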
78 changes: 78 additions & 0 deletions paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py
@@ -0,0 +1,78 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import paddle
from paddlenlp.transformers import QWenTokenizer

from paddlemix.models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="stepfun-ai/GOT-OCR2_0", help="pretrained ckpt and tokenizer")
parser.add_argument("--image_file", type=str, default="paddlemix/demo_images/hospital.jpeg")
parser.add_argument("--multi_crop", action="store_true")
parser.add_argument("--ocr_type", type=str, default="ocr", choices=["ocr", "format"])
parser.add_argument("--box", type=str, default="")
parser.add_argument("--color", type=str, default="")
parser.add_argument("--render", action="store_true")
args = parser.parse_args()
model_name_or_path = args.model_name_or_path

tokenizer = QWenTokenizer.from_pretrained(model_name_or_path)
model = GOTQwenForCausalLM.from_pretrained(
model_name_or_path, dtype=paddle.bfloat16, pad_token_id=tokenizer.eos_token_id
).eval()

# input test image
image_file = args.image_file
with paddle.no_grad():
if args.multi_crop:
# multi-crop OCR:
res = model.chat_crop(
tokenizer, image_file, ocr_type=args.ocr_type, render=args.render, save_render_file="./demo.html"
)
else:
# plain texts OCR
# format texts OCR
# fine-grained OCR
# render the formatted OCR results
res = model.chat(
tokenizer,
image_file,
ocr_type=args.ocr_type,
ocr_box=args.box,
ocr_color=args.color,
render=args.render,
save_render_file="./demo.html",
)

# plain texts OCR
# res = model.chat(tokenizer, image_file, ocr_type='ocr')

# format texts OCR:
# res = model.chat(tokenizer, image_file, ocr_type='format')

# fine-grained OCR:
# res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_box='')
# res = model.chat(tokenizer, image_file, ocr_type='format', ocr_box='')
# res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_color='')
# res = model.chat(tokenizer, image_file, ocr_type='format', ocr_color='')

# multi-crop OCR:
# res = model.chat_crop(tokenizer, image_file, ocr_type='ocr')
# res = model.chat_crop(tokenizer, image_file, ocr_type='format')

# render the formatted OCR results:
# res = model.chat(tokenizer, image_file, ocr_type='format', render=True, save_render_file = './demo.html')

print(res)
78 changes: 78 additions & 0 deletions paddlemix/examples/GOT_OCR_2_0/run_train.sh
@@ -0,0 +1,78 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -x

GPUS=${GPUS:-8}
BATCH_SIZE=${BATCH_SIZE:-8}
PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-1}

GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS))
tensor_parallel_degree=${tensor_parallel_degree:-1}
sharding_parallel_degree=$((GPUS / tensor_parallel_degree))

export PYTHONPATH="${PYTHONPATH}:$(pwd)"
export MASTER_PORT=34229
export TF_CPP_MIN_LOG_LEVEL=3

OUTPUT_DIR='work_dirs/got_ocr_20'

# meta='pdf-ocr+scence'

if [ ! -d "$OUTPUT_DIR" ]; then
mkdir -p "$OUTPUT_DIR"
fi

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes 1 --nproc_per_node ${GPUS} --rank 0 --ips ${TRAINER_INSTANCES} --run_mode=collective"
${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \
paddlemix/examples/GOT_OCR_2_0/train_GOT.py \
--do_train \
--model_name_or_path "stepfun-ai/GOT-OCR2_0" \
--output_dir ${OUTPUT_DIR} \
--logging_dir ${OUTPUT_DIR}/logs \
--meta_path paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json \
--overwrite_output_dir True \
--dataloader_num_workers 8 \
--bf16 True \
--fp16 False \
--fp16_opt_level "O2" \
--num_train_epochs 1 \
--per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
--gradient_accumulation_steps ${GRADIENT_ACC} \
--freeze_vision_tower False \
--use_im_start_end True \
--max_seq_length 8192 \
--recompute False \
--max_grad_norm 1.0 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.001 \
--optim "adamw" \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "visualdl" \
--tensor_parallel_degree=${tensor_parallel_degree} \
--sharding_parallel_degree=${sharding_parallel_degree} \
--pipeline_parallel_degree=1 \
--sep_parallel_degree=1 \
--sharding="stage1" \
2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
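The script derives its parallelism settings from a few environment variables: the global batch is split across GPUs and per-device micro-batches, the remainder becomes gradient accumulation, and sharding uses whatever GPUs tensor parallelism leaves over. The same arithmetic as a Python sketch (the helper `derive_parallel_config` is ours, for illustration):

```python
def derive_parallel_config(gpus: int = 8,
                           batch_size: int = 8,
                           per_device_batch_size: int = 1,
                           tensor_parallel_degree: int = 1) -> dict:
    """Mirror the shell arithmetic in run_train.sh.

    GRADIENT_ACC           = BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS
    sharding_parallel_degree = GPUS / tensor_parallel_degree
    """
    gradient_acc = batch_size // per_device_batch_size // gpus
    sharding_degree = gpus // tensor_parallel_degree
    return {"gradient_accumulation_steps": gradient_acc,
            "sharding_parallel_degree": sharding_degree}


# Defaults: 8 GPUs, global batch 8, micro-batch 1 -> no accumulation needed.
print(derive_parallel_config())
```

With the defaults this yields one accumulation step and stage-1 sharding across all 8 GPUs; raising `BATCH_SIZE` (or lowering `GPUS`) increases `GRADIENT_ACC` proportionally.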