REST API inference in PyTorch

A REST API lets you send input data to a machine learning model over HTTP and receive predictions back, typically as JSON.
```python
from flask import Flask, request, jsonify
import torch

app = Flask(__name__)

# Load the trained model once at startup so every request reuses it.
# Note: unpickling a full model this way requires the model's class to be
# importable, and weights_only=False is needed on PyTorch >= 2.6.
model = torch.load('model.pth', weights_only=False)
model.eval()  # switch to evaluation mode (disables dropout, batch norm updates)

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json  # parse the JSON request body
    # Convert the input list to a tensor and add a batch dimension
    input_tensor = torch.tensor(data['input'], dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():  # no gradient tracking during inference
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()  # index of the highest-scoring class
    return jsonify({'prediction': prediction})

if __name__ == '__main__':
    app.run(debug=True)
```
This example uses Flask, a lightweight Python web framework. The model is loaded once at startup and reused for every request, which avoids the cost of reloading it on each prediction.
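Saving the whole pickled model, as above, only works if the model's class is importable wherever the file is loaded. A more portable pattern is to save just the `state_dict` and rebuild the model on the serving side. A minimal sketch (the `SimpleModel` class and the `model_weights.pth` filename here are illustrative, not from the original example):

```python
import torch
import torch.nn as nn

# Illustrative model class; it must match the architecture used in training.
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
# Save only the parameters; no class definition gets pickled into the file.
torch.save(model.state_dict(), 'model_weights.pth')

# On the serving side: instantiate the class, then load the weights into it.
restored = SimpleModel()
restored.load_state_dict(torch.load('model_weights.pth'))
restored.eval()
```

Because the file contains only tensors, it also loads cleanly under PyTorch's stricter `weights_only=True` default.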
You can test the endpoint with curl:

```shell
curl -X POST http://localhost:5000/predict \
     -H "Content-Type: application/json" \
     -d '{"input": [0.5, 1.2, 3.3]}'
```
Or call it from Python with the requests library:

```python
import requests

response = requests.post(
    'http://localhost:5000/predict',
    json={'input': [0.5, 1.2, 3.3]},
)
print(response.json())
```
This program creates a simple PyTorch model, saves it, then loads it in a Flask app. The app listens for POST requests at /predict, runs the model on input data, and returns the predicted class.
```python
from flask import Flask, request, jsonify
import torch
import torch.nn as nn

# Define a simple model for demonstration
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)  # 3 input features, 2 classes

    def forward(self, x):
        return self.linear(x)

app = Flask(__name__)

# Create and save a dummy model (a stand-in for a real trained model)
model = SimpleModel()
torch.save(model, 'model.pth')

# Load the model once at startup; SimpleModel must be defined for this to work,
# and weights_only=False is needed on PyTorch >= 2.6 to unpickle a full model.
model = torch.load('model.pth', weights_only=False)
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    # Convert the input list to a tensor with a batch dimension of 1
    input_tensor = torch.tensor(data['input'], dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()
    return jsonify({'prediction': prediction})

if __name__ == '__main__':
    app.run(debug=False)
```
Always set your model to evaluation mode with model.eval() before inference.
Use torch.no_grad() to avoid tracking gradients during prediction, which saves memory.
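These two tips can be seen directly in a small experiment: with a dropout layer, `model.eval()` makes the output deterministic, and `torch.no_grad()` keeps autograd from building a graph. The toy network below is illustrative only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.eval()               # disables dropout, so repeated calls agree
with torch.no_grad():    # no autograd graph is built for these ops
    out1 = net(x)
    out2 = net(x)

print(torch.equal(out1, out2))   # identical outputs in eval mode
print(out1.requires_grad)        # no gradient tracking under no_grad
```

In training mode without `no_grad()`, the two outputs would generally differ and would carry gradient history.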
For production, use a production-grade WSGI server instead of Flask's built-in development server.
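One common choice is gunicorn. Assuming the Flask app above is saved as `app.py` with the application object named `app` (both names are assumptions here, not part of the original example), it could be served like this:

```shell
# Install gunicorn and run 4 worker processes listening on port 5000.
# Each worker loads its own copy of the model at startup.
pip install gunicorn
gunicorn --workers 4 --bind 0.0.0.0:5000 app:app
```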
A REST API lets you send data and get predictions from your ML model over the web.
Flask is a simple way to create such an API in Python.
Remember to load your model once at startup, and to use model.eval() and torch.no_grad() for efficient inference.