generated from Borda/kaggle_SandBox
-
Notifications
You must be signed in to change notification settings - Fork 12
/
streamlit-app.py
73 lines (54 loc) · 2.2 KB
/
streamlit-app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Simple StreamLit app for plant classification.
>> streamlit run streamlit-app.py
"""
import os
import gdown
import numpy as np
import streamlit as st
import torch
from PIL import Image
from kaggle_imgclassif.plant_pathology.augment import TORCHVISION_VALID_TRANSFORM
from kaggle_imgclassif.plant_pathology.data import PlantPathologyDM
from kaggle_imgclassif.plant_pathology.models import LitPlantPathology, MultiPlantPathology
# Remote (Google Drive) location of the pretrained checkpoint, and the local
# file name that ``get_model`` loads from (and downloads to when it is missing).
MODEL_PATH_GDRIVE = "https://drive.google.com/uc?id=1bynbFW0FpIt7fnqzImu2UIM1PHb9-yjw"
MODEL_PATH_LOCAL = "fgvc8_resnet50.pt"
# All class names the classifier can report, in declaration order.
UNIQUE_LABELS = (
    "scab",
    "rust",
    "complex",
    "frog_eye_leaf_spot",
    "powdery_mildew",
    "cider_apple_rust",
    "healthy",
)
# index -> label lookup handed to ``PlantPathologyDM.binary_mapping``;
# indices follow the alphabetical order of the label names.
LUT_LABELS = {index: label for index, label in enumerate(sorted(UNIQUE_LABELS))}
# Prefer the modern resource cache; fall back to the legacy ``st.cache`` API on
# Streamlit versions that predate ``st.cache_resource``.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def get_model(model_path: str = MODEL_PATH_LOCAL) -> LitPlantPathology:
    """Load the plant-pathology network and wrap it for inference.

    The checkpoint is downloaded from ``MODEL_PATH_GDRIVE`` when ``model_path``
    does not exist locally. Streamlit caches the result so app reruns reuse it.

    Args:
        model_path: local path of the serialized torch checkpoint.

    Returns:
        ``MultiPlantPathology`` wrapper around the loaded network, in eval mode.

    Raises:
        FileNotFoundError: the checkpoint is missing and the download did not
            produce it.
    """
    import inspect

    if not os.path.isfile(model_path):
        # download the model if it is missing locally
        gdown.download(MODEL_PATH_GDRIVE, model_path, quiet=False)
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model checkpoint could not be downloaded to {model_path!r}")
    # ``process_image`` feeds CPU tensors and calls ``.numpy()`` on the output,
    # so the network must end up on the CPU whatever device it was saved from.
    load_kwargs = {"map_location": "cpu"}
    # The loaded object is passed as ``model=`` below, so it appears to be a fully
    # pickled module rather than a state dict; torch>=2.6 defaults to
    # ``weights_only=True``, which cannot load that. Older torch lacks the flag.
    # NOTE(review): this unpickles a file fetched from the network - only point
    # MODEL_PATH_GDRIVE at a trusted checkpoint.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    net = torch.load(model_path, **load_kwargs)
    model = MultiPlantPathology(model=net)
    return model.eval()
def process_image(
    model: LitPlantPathology,
    img_path: str = "tests/_data/plant-pathology/test_images/8a0d7cad7053f18d.jpg",
    streamlit_app: bool = False,
):
    """Classify one leaf image and report the predicted labels.

    Args:
        model: network returned by ``get_model``; called on a batch of one image.
        img_path: path or file-like object (e.g. a Streamlit upload) accepted by
            ``PIL.Image.open``. A falsy value (nothing uploaded yet) is a no-op.
        streamlit_app: render the image and labels in the Streamlit page
            instead of printing the result to stdout.
    """
    if not img_path:
        return
    # Use a context manager so the file handle is released. Convert to RGB
    # because uploaded PNGs may be RGBA or greyscale, while the backbone (a
    # ResNet-50, per the checkpoint name) expects three input channels.
    with Image.open(img_path) as raw:
        if streamlit_app:
            st.image(raw)
        img = raw.convert("RGB")
    img = TORCHVISION_VALID_TRANSFORM(img)
    with torch.no_grad():
        encode = model(img.unsqueeze(0))[0]
    # process classification outputs
    binary = np.round(encode.detach().numpy(), decimals=2)
    labels = PlantPathologyDM.binary_mapping(encode, LUT_LABELS)
    if streamlit_app:
        st.write(", ".join(labels))
    else:
        print(f"Binary: {binary} >> {labels}")
if __name__ == "__main__":
    from streamlit.errors import StreamlitAPIException

    # Legacy option that silenced the old file-uploader encoding warning. Newer
    # Streamlit releases removed it and ``set_option`` rejects unknown keys, so
    # setting it is best-effort only.
    try:
        st.set_option("deprecation.showfileUploaderEncoding", False)
    except StreamlitAPIException:
        pass
    # Upload an image and set some options for demo purposes
    st.header("Plant Pathology Demo")
    img_file = st.sidebar.file_uploader(label="Upload an image", type=["png", "jpg"])
    # load the model; the cached version is reused across reruns to speed things up
    model = get_model()
    # run the app (process_image is a no-op until a file is uploaded)
    process_image(model, img_file, streamlit_app=True)
    # process_image(model)  # dry run with the bundled local test image