Skip to content

Commit

Permalink
changed env names and added auth header
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrandstaetter committed May 22, 2024
1 parent 8edcc6a commit 708bac4
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 62 deletions.
2 changes: 2 additions & 0 deletions serving-service/.env_example
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
PERSISTENCE_SERVICE_URI=http://persistence-service.mlaas.svc.cluster.local:5000
TENANT=Auth_type <token>
2 changes: 0 additions & 2 deletions serving-service/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ services:
context: .
dockerfile: ./Dockerfile
image: mlaas-service-serving:latest
# environment:
# - PERSISTENCE_SERVICE_URL=host.docker.internal:5000
ports:
- "5001:5001"
networks:
Expand Down
84 changes: 51 additions & 33 deletions serving-service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
# Create a Flask app
app = Flask(__name__)

persistence_service_url = None
persistence_service_uri = None
auth_header = None
model = None
config = None

#---------------------------------------------------------------------
#-----------------------------functions-------------------------------
#---------------------------------------------------------------------
# ---------------------------------------------------------------------
# -----------------------------functions-------------------------------
# ---------------------------------------------------------------------


def setup():
# Define the required environment variables
REQUIRED_ENV_VARS = ["PERSISTENCE_SERVICE_URL"]
REQUIRED_ENV_VARS = ["PERSISTENCE_SERVICE_URI", "TENANT"]

# Set up logging
logging.basicConfig(
Expand All @@ -40,72 +42,85 @@ def setup():
sys.exit(f"Error: Environment variable {var} is not set")

# Set Persistence Service URL
global persistence_service_url
persistence_service_url = os.getenv('PERSISTENCE_SERVICE_URL')
global persistence_service_uri
persistence_service_uri = os.getenv("PERSISTENCE_SERVICE_URI")

# Set auth header
global auth_header
auth_header = os.getenv("TENANT")


def load_model():
try:
response = requests.get(f"http://{persistence_service_url}/model", stream=True)
headers = {"Authorization": auth_header}
response = requests.get(persistence_service_uri, headers=headers, stream=True)
if response.status_code == 200:
# Use BytesIO for in-memory bytes buffer to store the zip content
# Use BytesIO for in-memory bytes buffer to store the zip content
zip_file_bytes = io.BytesIO(response.content)

# Extract the zip file contents
with zipfile.ZipFile(zip_file_bytes, 'r') as zip_ref:
with zipfile.ZipFile(zip_file_bytes, "r") as zip_ref:
# Optionally specify a path where to extract
extract_path = "./model"
zip_ref.extractall(extract_path)

# Load the TensorFlow model
global model
model_directory = os.path.join(extract_path, 'my_model.keras')
model_directory = os.path.join(extract_path, "my_model.keras")
model = tf.keras.models.load_model(model_directory)

# Load the configuration JSON
config_path = os.path.join(extract_path, 'config.json')
with open(config_path, 'r') as json_file:
config_path = os.path.join(extract_path, "config.json")
with open(config_path, "r") as json_file:
global config
config = json.load(json_file)
else:
logging.error(f"Unexpected response from persictence service: {str(response.status_code)}")
logging.error(
f"Unexpected response from persictence service: {str(response.status_code)}"
)
sys.exit(f"Unexpected response from persictence service")

except Exception as e:
logging.error(f"Unexpected error occurred: {str(e)}")
sys.exit(f"Unexpected error occurred when loading model")


def _inference(image):
try:
predictions = model.predict(image)
score = tf.nn.softmax(predictions[0])
return (
"This image most likely belongs to {} with a {:.2f} percent confidence."
.format(config['class_names'][np.argmax(score)], 100 * np.max(score))
), 200
"This image most likely belongs to {} with a {:.2f} percent confidence.".format(
config["class_names"][np.argmax(score)], 100 * np.max(score)
)
), 200
except Exception as e:
logging.error(f"Unexpected error occurred: {str(e)}")
return jsonify({"error": "An unexpected error occurred"}), 500


def _parse_and_infer(request):
if 'file' not in request.files:
if "file" not in request.files:
logging.error("No file part in infer request")
return jsonify({"error": "No file part"}), 400
file = request.files['file']
if file.filename == '':
file = request.files["file"]
if file.filename == "":
logging.error("Infer request has empty file")
return jsonify({"error": "Infer request has empty file"}), 400

try:
# Save file
filename = secure_filename(file.filename)
save_path = os.path.join('uploads', filename)
save_path = os.path.join("uploads", filename)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
file.save(save_path)

# Load and preprocess image
image = tf.keras.preprocessing.image.load_img(save_path, target_size=(config['img_height'], config['img_width']))
image = tf.keras.preprocessing.image.load_img(
save_path, target_size=(config["img_height"], config["img_width"])
)
img_array = tf.keras.utils.img_to_array(image)
img_array = tf.expand_dims(img_array, 0) # Create a batch
img_array = tf.expand_dims(img_array, 0) # Create a batch

# Remove file
if os.path.exists(save_path):
Expand All @@ -117,26 +132,29 @@ def _parse_and_infer(request):
return jsonify({"error": "Error processing image"}), 500


#---------------------------------------------------------------------
#---------------------------------API---------------------------------
#---------------------------------------------------------------------
# ---------------------------------------------------------------------
# ---------------------------------API---------------------------------
# ---------------------------------------------------------------------


@app.route('/infer', methods=['POST'])
@app.route("/infer", methods=["POST"])
def inference():
return _parse_and_infer(request)

@app.route('/', methods=['GET'])

@app.route("/", methods=["GET"])
def hello_world():
return "Hello, World!"


def create_app(config):
setup()
load_model()
app.config.from_object(config)
return app

if __name__ == '__main__':

if __name__ == "__main__":
setup()
load_model()
app.run(host='0.0.0.0', port=5001, debug=True)
app.run(host="0.0.0.0", port=5001, debug=True)
3 changes: 2 additions & 1 deletion serving-service/test/docker-compose-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ services:
dockerfile: ./test/Dockerfile.test
image: serving-test
environment:
- PERSISTENCE_SERVICE_URL=127.0.0.1:5000
- PERSISTENCE_SERVICE_URI=http://persistence-service.mlaas.svc.cluster.local:5000
- TENANT=Auth_type <token>
volumes:
- ./test-results:/serving/test/test-results
networks:
Expand Down
60 changes: 35 additions & 25 deletions serving-service/test/serving-test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,50 @@
import os
from ..main import create_app

TEST_FILE_PATH = os.path.join('data', 'dog.jpg')
MOCK_MODEL_PATH = os.path.join('data', "model_package.zip")
TEST_FILE_PATH = os.path.join("data", "dog.jpg")
MOCK_MODEL_PATH = os.path.join("data", "model_package.zip")


@pytest.fixture(scope="module")
def mocked_requests():
with open(MOCK_MODEL_PATH, 'rb') as zip_file:
zip_file_content = zip_file.read()

with requests_mock.Mocker() as m:
# Mock the GET request to return the zip file
m.get('http://127.0.0.1:5000/model',
content=zip_file_content,
headers={'Content-Type': 'application/zip'},
status_code=200)
yield m
with open(MOCK_MODEL_PATH, "rb") as zip_file:
zip_file_content = zip_file.read()

with requests_mock.Mocker() as m:
# Mock the GET request to return the zip file
m.get(
"http://persistence-service.mlaas.svc.cluster.local:5000",
content=zip_file_content,
headers={"Content-Type": "application/zip"},
status_code=200,
)
yield m


@pytest.fixture(scope="module")
def client(mocked_requests):
app = create_app({"TESTING": True})
with app.test_client() as client:
yield client
app = create_app({"TESTING": True})
with app.test_client() as client:
yield client


# -------------------Test cases for the endpoints-------------------

#-------------------Test cases for the endpoints-------------------

def test_should_status_code_ok(client):
response = client.get('/')
assert response.status_code == 200
response = client.get("/")
assert response.status_code == 200


def test_should_return_hello_world(client):
response = client.get('/')
data = response.data.decode()
assert data == 'Hello, World!'

response = client.get("/")
data = response.data.decode()
assert data == "Hello, World!"


def test_should_return_inference(client):
response = client.post('/infer', data={"file": open(TEST_FILE_PATH, "rb")})
assert response.status_code == 200
assert 'This image most likely belongs to' in response.data.decode(), "Response does not contain the expected string"
response = client.post("/infer", data={"file": open(TEST_FILE_PATH, "rb")})
assert response.status_code == 200
assert (
"This image most likely belongs to" in response.data.decode()
), "Response does not contain the expected string"
2 changes: 1 addition & 1 deletion serving-service/test/test-results/pytest_results.xml
Original file line number Diff line number Diff line change
@@ -1 +1 @@
<?xml version="1.0" encoding="utf-8"?><testsuites><testsuite name="pytest" errors="0" failures="0" skipped="0" tests="3" time="4.130" timestamp="2024-05-14T21:17:40.371898" hostname="df11e520772b"><testcase classname="serving-test" name="test_should_status_code_ok" time="0.823" /><testcase classname="serving-test" name="test_should_return_hello_world" time="0.001" /><testcase classname="serving-test" name="test_should_return_inference" time="0.165" /></testsuite></testsuites>
<?xml version="1.0" encoding="utf-8"?><testsuites><testsuite name="pytest" errors="0" failures="0" skipped="0" tests="3" time="3.442" timestamp="2024-05-22T15:59:23.366104" hostname="fe12f5d4efc8"><testcase classname="serving-test" name="test_should_status_code_ok" time="0.648" /><testcase classname="serving-test" name="test_should_return_hello_world" time="0.001" /><testcase classname="serving-test" name="test_should_return_inference" time="0.141" /></testsuite></testsuites>

0 comments on commit 708bac4

Please sign in to comment.