Skip to content

Commit

Permalink
Clean main, suppr USE, TF/TF HUB
Browse files Browse the repository at this point in the history
  • Loading branch information
BiGHeaDMaX committed Nov 7, 2023
1 parent de56b75 commit 7d2830b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 34 deletions.
68 changes: 37 additions & 31 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,46 @@
###########
# IMPORTS #
###########

# Imports spécifiques à l'API
from fastapi import FastAPI
from fastapi.responses import HTMLResponse

# Imports pour notre modèle de prédiction
import pickle
import tensorflow_hub as hub

# Pour transformer les données reçu en array
import numpy as np

# Instanciation de notre API
##############################
# Instanciation de notre API #
##############################

app = FastAPI()

# Chargement du modèle choisi précédemment enregistré
##################################################
# Chargement du modèle de prédiction préentraîné #
##################################################

# Contient le modèle KNN et le MultiLabelBinarizer
with open('KNeighborsClassifier_and_bin.pkl', 'rb') as fichier:
model_and_bin = pickle.load(fichier)

# Chargement du modèle USE pré-entraîné (s'il n'avait pas été chargé précédemment)
#embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
##########################
# Fonction de prédiction #
##########################

# Fonction de prédiction
def tags_predict(document):

ddd = document.replace('[', '').replace(']', '')
ddd = ddd.split()
ddd = np.array([float(item) for item in ddd])


# Encodage du document avec USE
#document_USE = embed([document])
#document_USE = ddd
# Prédiction avec le modèle entraîné
#prediction = model_and_bin[0].predict(ddd)
# Décodage de la prédiction pour avoir les tags sous forme textuelle
#tags = ', '.join(model_and_bin[1].inverse_transform(prediction)[0])


prediction = model_and_bin[0].predict([ddd])
# Conversion des données réçue en arrray
# qui pourra alors être passé dans le modèle
doc = document.replace('[', '').replace(']', '')
doc = doc.split()
doc = np.array([float(item) for item in doc])

# Prédiction sur les données reçues
prediction = model_and_bin[0].predict([doc])
# Inverse transform des prédictions du modèle
# pour passer de données encodées avec MultiLabelBinarizer
# à des tags sous forme textuelle
tags = ', '.join(model_and_bin[1].inverse_transform(prediction)[0])

return tags
Expand All @@ -44,6 +49,8 @@ def tags_predict(document):
# Page d'accueil #
##################
# Page d'accueil avec un formulaire pour entrer un document
# Pour la version dans le cloud, pour des raisons de limitation techniques,
# le document devra être encodé avec USE en local avant d'être transmis.
@app.get("/", response_class=HTMLResponse)
def prediction_form():
"""
Expand Down Expand Up @@ -88,6 +95,9 @@ def prediction_form():
#########################
# Page de résultats web #
#########################
# Page de résultats avec affichage des prédictions
# Pour la version dans le cloud, pour des raisons de limitation techniques,
# le document devra être encodé avec USE en local avant d'être transmis.
@app.get("/predict_web", response_class=HTMLResponse)
def prediction_result_web(document: str):
"""
Expand Down Expand Up @@ -130,19 +140,15 @@ def prediction_result_web(document: str):
###############################
# Renvoi des résultats en STR #
###############################
# Pour un accès direct, sans passer par formulaire HTML.
# Pour la version dans le cloud, pour des raisons de limitation techniques,
# le document devra être encodé avec USE en local avant d'être transmis.
@app.get("/predict")
def prediction_result(document):
#async def prediction_result(document: str):
"""
- Fonction de prédiction qui retourne uniquement
les tags prédits sous forme de string.
- À utiliser comme API depuis un autre programme.
"""
predicted_tags = tags_predict(document)
return predicted_tags

# Si le fichier est exécuté en tant que
# programme principal et non importé
#if __name__ == '__main__':
# uvicorn.run('main:app')

return predicted_tags
4 changes: 1 addition & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
fastapi==0.103.1
uvicorn==0.23.2
scikit-learn==1.2.2
numpy
tensorflow==2.14.0
tensorflow-hub==0.14.0
numpy

0 comments on commit 7d2830b

Please sign in to comment.