Merge pull request #739 from gmuffiness/master
add CRAFT training code
Showing 28 changed files with 4,946 additions and 0 deletions.
@@ -0,0 +1,4 @@
__pycache__/
model/__pycache__/
wandb/*
vis_result/*
@@ -0,0 +1,105 @@
# CRAFT-train

On the official CRAFT GitHub repository, many people have asked to be able to train CRAFT models.

However, the training code is not published in the official CRAFT repository.

There are other reproduced implementations, but there is a gap between their performance and the performance reported in the original paper (https://arxiv.org/pdf/1904.01941.pdf).

A model trained with this code reaches a level of performance similar to that of the original paper.

```bash
├── config
│   ├── syn_train.yaml
│   └── custom_data_train.yaml
├── data
│   ├── pseudo_label
│   │   ├── make_charbox.py
│   │   └── watershed.py
│   ├── boxEnlarge.py
│   ├── dataset.py
│   ├── gaussian.py
│   ├── imgaug.py
│   └── imgproc.py
├── loss
│   └── mseloss.py
├── metrics
│   └── eval_det_iou.py
├── model
│   ├── craft.py
│   └── vgg16_bn.py
├── utils
│   ├── craft_utils.py
│   ├── inference_boxes.py
│   └── utils.py
├── trainSynth.py
├── train.py
├── train_distributed.py
├── eval.py
├── data_root_dir (place your dataset folder here)
└── exp (model and experiment result files will be saved here)
```

### Installation

Install using `pip`:

```bash
pip install -r requirements.txt
```

### Training
1. Put your training and test data in the following format:
```
└── data_root_dir (you can change the root dir in the yaml file)
    ├── ch4_training_images
    │   ├── img_1.jpg
    │   └── img_2.jpg
    ├── ch4_training_localization_transcription_gt
    │   ├── gt_img_1.txt
    │   └── gt_img_2.txt
    ├── ch4_test_images
    │   ├── img_1.jpg
    │   └── img_2.jpg
    └── ch4_test_localization_transcription_gt
        ├── gt_img_1.txt
        └── gt_img_2.txt
```
* localization_transcription_gt file format (a minimal parsing sketch follows this list):
```
377,117,463,117,465,130,378,130,Genaxis Theatre
493,115,519,115,519,131,493,131,[06]
374,155,409,155,409,170,374,170,###
```
2. Write a configuration in yaml format (example config files are provided in the `config` folder).
    * To speed up training on multiple GPUs, set `num_workers` > 0.
3. Put the yaml file in the `config` folder.
4. Run the training script as shown below (if you have multiple GPUs, run `train_distributed.py`).
5. Experiment results will be saved to `./exp/[yaml]` by default.
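
Each ground-truth line lists four corner points followed by the transcription, and `###` marks a "do not care" region (see `do_not_care_label` in the config files). The helper below is only an illustrative sketch of how such a line can be parsed; it is not part of the repository:

```python
def parse_gt_line(line: str):
    """Parse one ICDAR-style gt line: x1,y1,x2,y2,x3,y3,x4,y4,transcription."""
    parts = line.strip().split(",")
    coords = list(map(int, parts[:8]))                 # four corner points
    transcription = ",".join(parts[8:])                # the text itself may contain commas
    box = [(coords[i], coords[i + 1]) for i in range(0, 8, 2)]
    ignore = transcription == "###"                    # "do not care" region
    return box, transcription, ignore

box, text, ignore = parse_gt_line("377,117,463,117,465,130,378,130,Genaxis Theatre")
print(box, text, ignore)
# [(377, 117), (463, 117), (465, 130), (378, 130)] Genaxis Theatre False
```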

* Step 1 : To train CRAFT on the SynthText dataset from scratch
    * Note : This step is not necessary if you use <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">this pretrained model</a> as the checkpoint when starting training step 2. Download it, place it at `exp/CRAFT_clr_amp_29500.pth`, and change `ckpt_path` in the config file according to your local setup.
```
CUDA_VISIBLE_DEVICES=0 python3 trainSynth.py --yaml=syn_train
```

* Step 2 : To train CRAFT with [SynthText + IC15] or a custom dataset
```
CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=custom_data_train                 ## if you run on a single GPU
CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=custom_data_train   ## if you run on multiple GPUs
```

### Arguments
* `--yaml` : configuration file name

### Evaluation
* In the official repository issues, the author mentioned that the F1-score of the first-row setting is around 0.75.
* The official paper states that the F1-score of the second-row setting is 0.87.
* If you lower the post-processing parameter `text_threshold` from 0.85 to 0.75, the F1-score reaches 0.856 (a sketch of what this threshold does follows this list).
* It took 14 h to train 25k iterations of weak supervision with 8 RTX 3090 Ti GPUs: half of the GPUs were assigned to training and half to the supervision (pseudo-label generation) step.
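
`text_threshold`, together with `low_text` and `link_threshold` from the config files, controls how the predicted region and affinity score maps are binarized and grouped into word boxes during post-processing. The snippet below is only a simplified, conceptual illustration of that thresholding step; it is not the repository's `utils/craft_utils.py` implementation:

```python
import cv2
import numpy as np

def simple_word_boxes(region_score, link_score,
                      text_threshold=0.75, link_threshold=0.2, low_text=0.5):
    """Toy post-processing: binarize the score maps, merge character regions
    through strong affinity links, and return axis-aligned word boxes."""
    text_mask = region_score >= low_text          # loose mask that defines region extent
    link_mask = link_score >= link_threshold      # strong affinity between characters
    combined = np.logical_or(text_mask, link_mask).astype(np.uint8)

    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(combined, connectivity=4)
    boxes = []
    for k in range(1, n_labels):                  # label 0 is the background
        # keep a component only if it contains at least one confident text pixel
        if region_score[labels == k].max() < text_threshold:
            continue
        x, y, w, h = (stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP],
                      stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT])
        boxes.append((x, y, x + w, y + h))
    return boxes
```

Lowering `text_threshold` keeps weaker components, which tends to raise recall at the cost of precision; that trade-off is behind the 0.85 to 0.75 adjustment mentioned above.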

| Training Dataset | Evaluation Dataset | Precision | Recall | F1-score | Pretrained model |
| ------------- |-----|:-----:|:-----:|:-----:|-----:|
| SynthText | ICDAR2013 | 0.801 | 0.748 | 0.773 | <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">download link</a> |
| SynthText + ICDAR2015 | ICDAR2015 | 0.909 | 0.794 | 0.848 | <a href="https://drive.google.com/file/d/1qUeZIDSFCOuGS9yo8o0fi-zYHLEW6lBP/view">download link</a> |
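
The precision, recall, and F1 numbers above come from IoU-based detection matching. As a rough, simplified sketch of that kind of evaluation (not the repository's `metrics/eval_det_iou.py`, which follows the ICDAR evaluation protocol):

```python
def iou(a, b):
    """IoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / float(union + 1e-8)

def detection_scores(pred_boxes, gt_boxes, iou_threshold=0.5):
    """Greedy one-to-one matching at a fixed IoU threshold -> precision, recall, F1."""
    matched, tp = set(), 0
    for p in pred_boxes:
        best_j, best_iou = -1, iou_threshold
        for j, g in enumerate(gt_boxes):
            if j not in matched and iou(p, g) >= best_iou:
                best_j, best_iou = j, iou(p, g)
        if best_j >= 0:
            matched.add(best_j)
            tp += 1
    precision = tp / max(len(pred_boxes), 1)
    recall = tp / max(len(gt_boxes), 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)
    return precision, recall, f1
```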
Empty file.
@@ -0,0 +1,100 @@
wandb_opt: False

results_dir: "./exp/"
vis_test_dir: "./vis_result/"

data_root_dir: "./data_root_dir/"
score_gt_dir: None # "/data/ICDAR2015_official_supervision"
mode: "weak_supervision"


train:
  backbone: vgg
  use_synthtext: False # If you want to combine SynthText at train time as CRAFT did, you can turn on this option
  synth_data_dir: "/data/SynthText/"
  synth_ratio: 5
  real_dataset: custom
  ckpt_path: "./pretrained_model/CRAFT_clr_amp_29500.pth"
  eval_interval: 1000
  batch_size: 5
  st_iter: 0
  end_iter: 25000
  lr: 0.0001
  lr_decay: 7500
  gamma: 0.2
  weight_decay: 0.00001
  num_workers: 0 # On a single GPU, train.py only works with num_workers = 0; on multi-GPU, set num_workers > 0 to speed up training
  amp: True
  loss: 2
  neg_rto: 0.3
  n_min_neg: 5000
  data:
    vis_opt: False
    pseudo_vis_opt: False
    output_size: 768
    do_not_care_label: ['###', '']
    mean: [0.485, 0.456, 0.406]
    variance: [0.229, 0.224, 0.225]
    enlarge_region: [0.5, 0.5] # x axis, y axis
    enlarge_affinity: [0.5, 0.5]
    gauss_init_size: 200
    gauss_sigma: 40
    watershed:
      version: "skimage"
      sure_fg_th: 0.75
      sure_bg_th: 0.05
    syn_sample: -1
    custom_sample: -1
    syn_aug:
      random_scale:
        range: [1.0, 1.5, 2.0]
        option: False
      random_rotate:
        max_angle: 20
        option: False
      random_crop:
        version: "random_resize_crop_synth"
        option: True
      random_horizontal_flip:
        option: False
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True
    custom_aug:
      random_scale:
        range: [1.0, 1.5, 2.0]
        option: False
      random_rotate:
        max_angle: 20
        option: True
      random_crop:
        version: "random_resize_crop"
        scale: [0.03, 0.4]
        ratio: [0.75, 1.33]
        rnd_threshold: 1.0
        option: True
      random_horizontal_flip:
        option: True
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True

test:
  trained_model: null
  custom_data:
    test_set_size: 500
    test_data_dir: "./data_root_dir/"
    text_threshold: 0.75
    low_text: 0.5
    link_threshold: 0.2
    canvas_size: 2240
    mag_ratio: 1.75
    poly: False
    cuda: True
    vis_opt: False
@@ -0,0 +1,37 @@
import os
import yaml
from functools import reduce

CONFIG_PATH = os.path.dirname(__file__)


def load_yaml(config_name):
    # Config files are looked up next to this module by name, without the ".yaml" extension.
    with open(os.path.join(CONFIG_PATH, config_name) + '.yaml') as file:
        config = yaml.safe_load(file)

    return config


class DotDict(dict):
    """Dictionary that supports attribute access and dotted-key lookup, e.g. cfg['train.batch_size']."""

    def __getattr__(self, k):
        try:
            v = self[k]
        except KeyError:
            raise AttributeError(k)
        if isinstance(v, dict):
            return DotDict(v)
        return v

    def __getitem__(self, k):
        # 'a.b.c' or ('a', 'b', 'c') walks through the nested dictionaries.
        if isinstance(k, str) and '.' in k:
            k = k.split('.')
        if isinstance(k, (list, tuple)):
            return reduce(lambda d, kk: d[kk], k, self)
        return super().__getitem__(k)

    def get(self, k, default=None):
        if isinstance(k, str) and '.' in k:
            try:
                return self[k]
            except KeyError:
                return default
        return super().get(k, default)  # positional: dict.get() takes no keyword arguments
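
A quick usage sketch of `load_yaml` and `DotDict` together. The import path and the wrapping in `DotDict` are assumptions for illustration (they follow the project tree and config names shown above, not a documented entry point):

```python
from config.load_config import load_yaml, DotDict  # assumed module path: config/load_config.py

cfg = DotDict(load_yaml("custom_data_train"))  # reads config/custom_data_train.yaml

# Equivalent ways to reach a nested value:
print(cfg.train.batch_size)           # attribute access through nested DotDicts -> 5
print(cfg["train.batch_size"])        # dotted-key lookup handled by __getitem__ -> 5
print(cfg.get("train.missing", 42))   # dotted get() falls back to the default -> 42
```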
@@ -0,0 +1,68 @@
wandb_opt: False

results_dir: "./exp/"
vis_test_dir: "./vis_result/"
data_dir:
  synthtext: "/data/SynthText/"
  synthtext_gt: NULL

train:
  backbone: vgg
  dataset: ["synthtext"]
  ckpt_path: null
  eval_interval: 1000
  batch_size: 5
  st_iter: 0
  end_iter: 50000
  lr: 0.0001
  lr_decay: 15000
  gamma: 0.2
  weight_decay: 0.00001
  num_workers: 4
  amp: True
  loss: 3
  neg_rto: 1
  n_min_neg: 1000
  data:
    vis_opt: False
    output_size: 768
    mean: [0.485, 0.456, 0.406]
    variance: [0.229, 0.224, 0.225]
    enlarge_region: [0.5, 0.5] # x axis, y axis
    enlarge_affinity: [0.5, 0.5]
    gauss_init_size: 200
    gauss_sigma: 40
    syn_sample: -1
    syn_aug:
      random_scale:
        range: [1.0, 1.5, 2.0]
        option: False
      random_rotate:
        max_angle: 20
        option: False
      random_crop:
        version: "random_resize_crop_synth"
        rnd_threshold: 1.0
        option: True
      random_horizontal_flip:
        option: False
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True

test:
  trained_model: null
  icdar2013:
    test_set_size: 233
    cuda: True
    vis_opt: True
    test_data_dir: "/data/ICDAR2013/"
    text_threshold: 0.85
    low_text: 0.5
    link_threshold: 0.2
    canvas_size: 960
    mag_ratio: 1.5
    poly: False
@@ -0,0 +1,65 @@
import math
import numpy as np


def pointAngle(Apoint, Bpoint):
    # Slope of the line through A and B (small epsilon avoids division by zero).
    angle = (Bpoint[1] - Apoint[1]) / ((Bpoint[0] - Apoint[0]) + 10e-8)
    return angle


def pointDistance(Apoint, Bpoint):
    return math.sqrt((Bpoint[1] - Apoint[1])**2 + (Bpoint[0] - Apoint[0])**2)


def lineBiasAndK(Apoint, Bpoint):
    # Line through A and B expressed as y = K*x + B.
    K = pointAngle(Apoint, Bpoint)
    B = Apoint[1] - K * Apoint[0]
    return K, B


def getX(K, B, Ypoint):
    return int((Ypoint - B) / K)


def sidePoint(Apoint, Bpoint, h, w, placehold, enlarge_size):
    # Push a corner outward along the A-B direction, clipped to the image bounds.
    K, B = lineBiasAndK(Apoint, Bpoint)
    angle = abs(math.atan(pointAngle(Apoint, Bpoint)))
    distance = pointDistance(Apoint, Bpoint)

    x_enlarge_size, y_enlarge_size = enlarge_size

    XaxisIncreaseDistance = abs(math.cos(angle) * x_enlarge_size * distance)
    YaxisIncreaseDistance = abs(math.sin(angle) * y_enlarge_size * distance)

    if placehold == 'leftTop':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = max(0, Apoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightTop':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = max(0, Bpoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightBottom':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = min(h, Bpoint[1] + YaxisIncreaseDistance)
    elif placehold == 'leftBottom':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = min(h, Apoint[1] + YaxisIncreaseDistance)
    return int(x1), int(y1)


def enlargebox(box, h, w, enlarge_size, horizontal_text_bool):
    # Swap the x/y enlargement ratios for vertical text.
    if not horizontal_text_bool:
        enlarge_size = (enlarge_size[1], enlarge_size[0])

    # Roll the corners so the top-left point (smallest x + y) comes first.
    box = np.roll(box, -np.argmin(box.sum(axis=1)), axis=0)

    Apoint, Bpoint, Cpoint, Dpoint = box
    K1, B1 = lineBiasAndK(box[0], box[2])
    K2, B2 = lineBiasAndK(box[3], box[1])
    X = (B2 - B1) / (K1 - K2)
    Y = K1 * X + B1
    center = [X, Y]

    # Move each corner away from the diagonal intersection (the box center).
    x1, y1 = sidePoint(Apoint, center, h, w, 'leftTop', enlarge_size)
    x2, y2 = sidePoint(center, Bpoint, h, w, 'rightTop', enlarge_size)
    x3, y3 = sidePoint(center, Cpoint, h, w, 'rightBottom', enlarge_size)
    x4, y4 = sidePoint(Dpoint, center, h, w, 'leftBottom', enlarge_size)
    newcharbox = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
    return newcharbox
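
A small usage sketch of `enlargebox` (the values and the import path are illustrative; `data/boxEnlarge.py` is taken from the project tree above):

```python
import numpy as np
from data.boxEnlarge import enlargebox  # assuming this hunk is data/boxEnlarge.py

# A 20x10 character box inside a 100x200 (h x w) image, corners ordered clockwise from top-left.
charbox = np.array([[40, 40], [60, 40], [60, 50], [40, 50]])
enlarged = enlargebox(charbox, h=100, w=200, enlarge_size=(0.5, 0.5), horizontal_text_bool=True)
print(enlarged)
# With enlarge_size=(0.5, 0.5) each corner moves away from the box center by half of its
# distance to the center (clipped to the image), e.g. the top-left corner (40, 40) -> (35, 37).
```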