add CRAFT training code #739

Merged · 6 commits · May 30, 2022
4 changes: 4 additions & 0 deletions trainer/craft/.gitignore
@@ -0,0 +1,4 @@
__pycache__/
model/__pycache__/
wandb/*
vis_result/*
105 changes: 105 additions & 0 deletions trainer/craft/README.md
@@ -0,0 +1,105 @@
# CRAFT-train
Many people on the official CRAFT GitHub have asked for a way to train CRAFT models.

However, the training code is not published in the official CRAFT repository.

There are other reproductions, but there is a gap between their performance and the performance reported in the original paper (https://arxiv.org/pdf/1904.01941.pdf).

The model trained with this code reaches a level of performance similar to that of the original paper.

```bash
├── config
│   ├── syn_train.yaml
│   └── custom_data_train.yaml
├── data
│   ├── pseudo_label
│   │   ├── make_charbox.py
│   │   └── watershed.py
│   ├── boxEnlarge.py
│   ├── dataset.py
│   ├── gaussian.py
│   ├── imgaug.py
│   └── imgproc.py
├── loss
│   └── mseloss.py
├── metrics
│   └── eval_det_iou.py
├── model
│   ├── craft.py
│   └── vgg16_bn.py
├── utils
│   ├── craft_utils.py
│   ├── inference_boxes.py
│   └── utils.py
├── trainSynth.py
├── train.py
├── train_distributed.py
├── eval.py
├── data_root_dir (place dataset folder here)
└── exp (model and experiment result files will be saved here)
```

### Installation

Install the requirements using `pip`:

``` bash
pip install -r requirements.txt
```


### Training
1. Put your training and test data in the following format
```
└── data_root_dir (you can change the root dir in the yaml file)
    ├── ch4_training_images
    │   ├── img_1.jpg
    │   └── img_2.jpg
    ├── ch4_training_localization_transcription_gt
    │   ├── gt_img_1.txt
    │   └── gt_img_2.txt
    ├── ch4_test_images
    │   ├── img_1.jpg
    │   └── img_2.jpg
    └── ch4_test_localization_transcription_gt
        ├── gt_img_1.txt
        └── gt_img_2.txt
```
* `localization_transcription_gt` file format (a short parsing sketch follows the example):
```
377,117,463,117,465,130,378,130,Genaxis Theatre
493,115,519,115,519,131,493,131,[06]
374,155,409,155,409,170,374,170,###
```
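For illustration, here is a minimal parsing sketch (the helper name and the split-on-the-first-eight-commas convention are assumptions for this README, not part of this PR); entries labelled `###` are the don't-care regions matched by `do_not_care_label` in the config:
```python
def parse_gt_line(line: str):
    """Parse one gt line: 8 corner coordinates followed by the transcription."""
    parts = line.strip().split(",", 8)      # the transcription itself may contain commas
    coords = list(map(int, parts[:8]))      # x1,y1,...,x4,y4, clockwise from the top-left corner
    box = [(coords[i], coords[i + 1]) for i in range(0, 8, 2)]
    transcription = parts[8]
    ignore = transcription in ("###", "")   # mirrors do_not_care_label in the yaml configs
    return box, transcription, ignore

# parse_gt_line("377,117,463,117,465,130,378,130,Genaxis Theatre")
# -> ([(377, 117), (463, 117), (465, 130), (378, 130)], "Genaxis Theatre", False)
```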
2. Write a configuration file in yaml format (example config files are provided in the `config` folder).
* To speed up training with multiple GPUs, set `num_workers` > 0.
3. Put the yaml file in the `config` folder.
4. Run the training script as shown below (if you have multiple GPUs, run `train_distributed.py`).
5. Experiment results will be saved to ```./exp/[yaml]``` by default.

* Step 1 : Train CRAFT with the SynthText dataset from scratch
* Note : This step is not necessary if you use <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">this pretrained model</a> as a checkpoint when starting training step 2. You can download it, put it at `exp/CRAFT_clr_amp_29500.pth`, and change `ckpt_path` in the config file according to your local setup.
```
CUDA_VISIBLE_DEVICES=0 python3 trainSynth.py --yaml=syn_train
```

* Step 2 : Train CRAFT with [SynthText + IC15] or a custom dataset
```
CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=custom_data_train ## if you run on single GPU
CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=custom_data_train ## if you run on multi GPU
```

### Arguments
* ```--yaml``` : configuration file name

### Evaluation
* In the official repository issues, the author mentioned that the F1-score for the first-row setting is around 0.75.
* The official paper states that the F1-score for the second-row setting is 0.87.
* If you adjust the post-process parameter `text_threshold` from 0.85 to 0.75, the F1-score reaches 0.856.
* Training 25k iterations of weak supervision took 14 h on 8 RTX 3090 Ti GPUs.
* Half of the GPUs were assigned to training and half to the supervision (pseudo-label generation) setting.

| Training Dataset | Evaluation Dataset | Precision | Recall | F1-score | pretrained model |
| ------------- |-----|:-----:|:-----:|:-----:|-----:|
| SynthText | ICDAR2013 | 0.801 | 0.748 | 0.773| <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">download link</a>|
| SynthText + ICDAR2015 | ICDAR2015 | 0.909 | 0.794 | 0.848| <a href="https://drive.google.com/file/d/1qUeZIDSFCOuGS9yo8o0fi-zYHLEW6lBP/view">download link</a>|
Empty file.
100 changes: 100 additions & 0 deletions trainer/craft/config/custom_data_train.yaml
@@ -0,0 +1,100 @@
wandb_opt: False

results_dir: "./exp/"
vis_test_dir: "./vis_result/"

data_root_dir: "./data_root_dir/"
score_gt_dir: None # "/data/ICDAR2015_official_supervision"
mode: "weak_supervision"


train:
  backbone : vgg
  use_synthtext: False # If you want to combine SynthText in train time as CRAFT did, you can turn on this option
  synth_data_dir: "/data/SynthText/"
  synth_ratio: 5
  real_dataset: custom
  ckpt_path: "./pretrained_model/CRAFT_clr_amp_29500.pth"
  eval_interval: 1000
  batch_size: 5
  st_iter: 0
  end_iter: 25000
  lr: 0.0001
  lr_decay: 7500
  gamma: 0.2
  weight_decay: 0.00001
  num_workers: 0 # On single gpu, train.py execution only works when num worker = 0 / On multi-gpu, you can set num_worker > 0 to speed up
  amp: True
  loss: 2
  neg_rto: 0.3
  n_min_neg: 5000
  data:
    vis_opt: False
    pseudo_vis_opt: False
    output_size: 768
    do_not_care_label: ['###', '']
    mean: [0.485, 0.456, 0.406]
    variance: [0.229, 0.224, 0.225]
    enlarge_region : [0.5, 0.5] # x axis, y axis
    enlarge_affinity: [0.5, 0.5]
    gauss_init_size: 200
    gauss_sigma: 40
    watershed:
      version: "skimage"
      sure_fg_th: 0.75
      sure_bg_th: 0.05
    syn_sample: -1
    custom_sample: -1
    syn_aug:
      random_scale:
        range: [1.0, 1.5, 2.0]
        option: False
      random_rotate:
        max_angle: 20
        option: False
      random_crop:
        version: "random_resize_crop_synth"
        option: True
      random_horizontal_flip:
        option: False
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True
    custom_aug:
      random_scale:
        range: [ 1.0, 1.5, 2.0 ]
        option: False
      random_rotate:
        max_angle: 20
        option: True
      random_crop:
        version: "random_resize_crop"
        scale: [0.03, 0.4]
        ratio: [0.75, 1.33]
        rnd_threshold: 1.0
        option: True
      random_horizontal_flip:
        option: True
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True

test:
  trained_model : null
  custom_data:
    test_set_size: 500
    test_data_dir: "./data_root_dir/"
    text_threshold: 0.75
    low_text: 0.5
    link_threshold: 0.2
    canvas_size: 2240
    mag_ratio: 1.75
    poly: False
    cuda: True
    vis_opt: False
37 changes: 37 additions & 0 deletions trainer/craft/config/load_config.py
@@ -0,0 +1,37 @@
import os
import yaml
from functools import reduce

CONFIG_PATH = os.path.dirname(__file__)

def load_yaml(config_name):

    with open(os.path.join(CONFIG_PATH, config_name) + '.yaml') as file:
        config = yaml.safe_load(file)

    return config

class DotDict(dict):
    def __getattr__(self, k):
        try:
            v = self[k]
        except:
            return super().__getattr__(k)
        if isinstance(v, dict):
            return DotDict(v)
        return v

    def __getitem__(self, k):
        if isinstance(k, str) and '.' in k:
            k = k.split('.')
        if isinstance(k, (list, tuple)):
            return reduce(lambda d, kk: d[kk], k, self)
        return super().__getitem__(k)

    def get(self, k, default=None):
        if isinstance(k, str) and '.' in k:
            try:
                return self[k]
            except KeyError:
                return default
        return super().get(k, default=default)
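A small usage sketch of the two helpers above (the printed values assume the `custom_data_train.yaml` nesting shown earlier in this PR; nothing here is part of the diff itself):

```python
# Hypothetical usage, run from the trainer/craft directory.
from config.load_config import DotDict, load_yaml

cfg = DotDict(load_yaml("custom_data_train"))   # reads config/custom_data_train.yaml

# Attribute access works because __getattr__ re-wraps nested dicts in DotDict.
print(cfg.train.lr)                  # 0.0001
print(cfg.train.data.output_size)    # 768

# Dotted string keys are resolved by the custom __getitem__ / get.
print(cfg["train.lr_decay"])             # 7500
print(cfg.get("train.no_such_key", 42))  # missing keys fall back to the default -> 42
```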
68 changes: 68 additions & 0 deletions trainer/craft/config/syn_train.yaml
@@ -0,0 +1,68 @@
wandb_opt: False

results_dir: "./exp/"
vis_test_dir: "./vis_result/"
data_dir:
  synthtext: "/data/SynthText/"
  synthtext_gt: NULL

train:
  backbone : vgg
  dataset: ["synthtext"]
  ckpt_path: null
  eval_interval: 1000
  batch_size: 5
  st_iter: 0
  end_iter: 50000
  lr: 0.0001
  lr_decay: 15000
  gamma: 0.2
  weight_decay: 0.00001
  num_workers: 4
  amp: True
  loss: 3
  neg_rto: 1
  n_min_neg: 1000
  data:
    vis_opt: False
    output_size: 768
    mean: [0.485, 0.456, 0.406]
    variance: [0.229, 0.224, 0.225]
    enlarge_region : [0.5, 0.5] # x axis, y axis
    enlarge_affinity: [0.5, 0.5]
    gauss_init_size: 200
    gauss_sigma: 40
    syn_sample : -1
    syn_aug:
      random_scale:
        range: [1.0, 1.5, 2.0]
        option: False
      random_rotate:
        max_angle: 20
        option: False
      random_crop:
        version: "random_resize_crop_synth"
        rnd_threshold : 1.0
        option: True
      random_horizontal_flip:
        option: False
      random_colorjitter:
        brightness: 0.2
        contrast: 0.2
        saturation: 0.2
        hue: 0.2
        option: True

test:
  trained_model: null
  icdar2013:
    test_set_size: 233
    cuda: True
    vis_opt: True
    test_data_dir : "/data/ICDAR2013/"
    text_threshold: 0.85
    low_text: 0.5
    link_threshold: 0.2
    canvas_size: 960
    mag_ratio: 1.5
    poly: False
65 changes: 65 additions & 0 deletions trainer/craft/data/boxEnlarge.py
@@ -0,0 +1,65 @@
import math
import numpy as np


def pointAngle(Apoint, Bpoint):
    angle = (Bpoint[1] - Apoint[1]) / ((Bpoint[0] - Apoint[0]) + 10e-8)
    return angle

def pointDistance(Apoint, Bpoint):
    return math.sqrt((Bpoint[1] - Apoint[1])**2 + (Bpoint[0] - Apoint[0])**2)

def lineBiasAndK(Apoint, Bpoint):

    K = pointAngle(Apoint, Bpoint)
    B = Apoint[1] - K*Apoint[0]
    return K, B

def getX(K, B, Ypoint):
    return int((Ypoint-B)/K)

def sidePoint(Apoint, Bpoint, h, w, placehold, enlarge_size):

    K, B = lineBiasAndK(Apoint, Bpoint)
    angle = abs(math.atan(pointAngle(Apoint, Bpoint)))
    distance = pointDistance(Apoint, Bpoint)

    x_enlarge_size, y_enlarge_size = enlarge_size

    XaxisIncreaseDistance = abs(math.cos(angle) * x_enlarge_size * distance)
    YaxisIncreaseDistance = abs(math.sin(angle) * y_enlarge_size * distance)

    if placehold == 'leftTop':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = max(0, Apoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightTop':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = max(0, Bpoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightBottom':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = min(h, Bpoint[1] + YaxisIncreaseDistance)
    elif placehold == 'leftBottom':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = min(h, Apoint[1] + YaxisIncreaseDistance)
    return int(x1), int(y1)

def enlargebox(box, h, w, enlarge_size, horizontal_text_bool):

    if not horizontal_text_bool:
        enlarge_size = (enlarge_size[1], enlarge_size[0])

    box = np.roll(box, -np.argmin(box.sum(axis=1)), axis=0)

    Apoint, Bpoint, Cpoint, Dpoint = box
    K1, B1 = lineBiasAndK(box[0], box[2])
    K2, B2 = lineBiasAndK(box[3], box[1])
    X = (B2 - B1)/(K1 - K2)
    Y = K1 * X + B1
    center = [X, Y]

    x1, y1 = sidePoint(Apoint, center, h, w, 'leftTop', enlarge_size)
    x2, y2 = sidePoint(center, Bpoint, h, w, 'rightTop', enlarge_size)
    x3, y3 = sidePoint(center, Cpoint, h, w, 'rightBottom', enlarge_size)
    x4, y4 = sidePoint(Dpoint, center, h, w, 'leftBottom', enlarge_size)
    newcharbox = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
    return newcharbox
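A quick usage sketch of `enlargebox` (the box coordinates and image size are made up for illustration; `enlarge_size` mirrors the `enlarge_region` value from the configs above, and the import path is assumed):

```python
import numpy as np

from data.boxEnlarge import enlargebox  # assumed import path, run from trainer/craft

# A horizontal character box inside a 100 x 200 (h x w) image, clockwise from top-left.
char_box = np.array([[50, 40], [80, 40], [80, 60], [50, 60]])

enlarged = enlargebox(
    char_box,
    h=100,
    w=200,
    enlarge_size=(0.5, 0.5),        # same scale as enlarge_region in the yaml configs
    horizontal_text_bool=True,
)
print(enlarged)  # corners pushed outward from the box center and clipped to the image bounds
```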