Serving a Trained TensorFlow Model as a REST API Using Flask
Introduction
This article walks through deploying a trained TensorFlow model as a REST API using Flask, a popular Python web framework. Once the model sits behind an HTTP endpoint, web applications and other systems can request predictions from it without loading TensorFlow themselves.
Steps
- Install Required Libraries

pip install tensorflow flask

  The image classification example further down also imports PIL, which is provided by the Pillow package (pip install pillow).
- Prepare Your Trained Model
  - Ensure your TensorFlow model is trained and saved to disk.
  - You can use either a SavedModel or a Keras model. In recent TensorFlow releases that bundle Keras 3, tf.keras.models.load_model expects a .keras or .h5 file; a SavedModel directory has to be loaded another way, for example with tf.saved_model.load.
- Create the Flask Application

from flask import Flask, request, jsonify
import numpy as np
import tensorflow as tf

app = Flask(__name__)

# Load your trained model once at startup, not on every request
model = tf.keras.models.load_model('path/to/your/model')


@app.route('/', methods=['GET'])
def index():
    return "Welcome to the TensorFlow Model API"


@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    # Preprocess the data
    # ...
    # model.predict needs an array, not raw JSON. This line assumes the
    # request body is a nested list of numbers, one row per sample.
    inputs = np.asarray(data, dtype=np.float32)
    prediction = model.predict(inputs)
    # Postprocess the prediction
    # ...
    return jsonify({'prediction': prediction.tolist()})


if __name__ == '__main__':
    # debug=True turns on the reloader and debugger; use it for local development only
    app.run(debug=True)

- Test the API
  - Run the Flask application:

python app.py

  - Use a tool like Postman or curl to send a POST request to

http://localhost:5000/predict

    with the input data in JSON format.
Example Code: Image Classification

import io

import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify
from PIL import Image

app = Flask(__name__)

# Load your trained model
model = tf.keras.models.load_model('path/to/your/model')


@app.route('/predict', methods=['POST'])
def predict():
    # Get the image from the multipart form field named "image"
    image_data = request.files['image'].read()
    image = Image.open(io.BytesIO(image_data)).convert('RGB')
    # Resize and scale to [0, 1]; match the size and scaling used in training
    image = image.resize((224, 224))
    image = np.array(image) / 255.0
    # Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
    image = np.expand_dims(image, axis=0)
    # Make a prediction
    prediction = model.predict(image)
    class_names = ['cat', 'dog']  # Example class names
    predicted_class = class_names[int(np.argmax(prediction[0]))]
    # Return the prediction as a JSON response
    return jsonify({'predicted_class': predicted_class})


if __name__ == '__main__':
    app.run(debug=True)

Output

{ "predicted_class": "cat" }

Deployment
- Local Deployment: Run the Flask app locally.
- Cloud Deployment: Use services like AWS Lambda, Google Cloud Functions, or Heroku to deploy your API to the cloud.
- Docker: Package your app and its dependencies into a Docker image and run it as a container.
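When the app moves from a development machine to a hosted environment, the hard-coded app.run(debug=True) call usually has to change: the debugger should be off, and several platforms (Heroku among them) tell the process which port to bind through a PORT environment variable. The sketch below shows one possible way to read those settings. The helper name server_settings and the HOST and APP_DEBUG variable names are hypothetical, chosen for illustration; they are not Flask settings.

```python
import os


def server_settings(env=None):
    """Return (host, port, debug) for app.run, read from environment variables.

    Hypothetical helper: PORT follows a convention used by several hosting
    platforms; HOST and APP_DEBUG are illustrative names only.
    """
    env = os.environ if env is None else env
    # Stay on localhost by default; a container would set HOST=0.0.0.0
    # so that traffic from outside the container can reach the process.
    host = env.get("HOST", "127.0.0.1")
    # 5000 is the default port of Flask's development server.
    port = int(env.get("PORT", "5000"))
    # The debugger stays off unless it is explicitly switched on.
    debug = env.get("APP_DEBUG", "0") == "1"
    return host, port, debug


if __name__ == "__main__":
    host, port, debug = server_settings()
    print(f"would call app.run(host={host!r}, port={port}, debug={debug})")
```

In app.py, the last line would then become app.run(host=host, port=port, debug=debug). Flask's built-in server is intended for development, so a deployment that takes real traffic would typically put the app behind a production WSGI server such as Gunicorn.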
Conclusion
By following these steps, you can serve a trained TensorFlow model through a REST API using Flask, so that any application able to send an HTTP request can use the model's predictions.