
REST API inference in PyTorch

Introduction

A REST API lets you send data to a machine learning model and get predictions back easily over the internet.

You want to let a mobile app get predictions from your ML model.
You want to share your model with others without giving them the code.
You want to use your model in a website to show live results.
You want to automate predictions from other software.
You want to test your model remotely without running it locally.
Syntax
PyTorch
from flask import Flask, request, jsonify
import torch

app = Flask(__name__)

# Load your trained PyTorch model
model = torch.load('model.pth', weights_only=False)  # weights_only=False is required to unpickle a full model object (PyTorch 2.6+ defaults to weights_only=True)
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json  # Get JSON data from POST request
    input_tensor = torch.tensor(data['input'], dtype=torch.float32).unsqueeze(0)  # Convert input list to tensor with batch dimension
    with torch.no_grad():
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()  # Get predicted class
    return jsonify({'prediction': prediction})

if __name__ == '__main__':
    app.run(debug=True)

This example uses Flask, a simple web framework for Python.

The model is loaded once and used for all requests to save time.
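Saving the whole model object with torch.save pickles the class itself, which breaks if the class definition moves or changes. The commonly recommended alternative is to save only the state_dict and rebuild the architecture at load time. A minimal sketch (SimpleNet here is a stand-in for your own model class, and 'model_state.pth' is an assumed filename):

```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)
    def forward(self, x):
        return self.linear(x)

# Save only the parameters, not the pickled object
model = SimpleNet()
torch.save(model.state_dict(), 'model_state.pth')

# At load time, rebuild the architecture and fill in the weights
loaded = SimpleNet()
loaded.load_state_dict(torch.load('model_state.pth'))
loaded.eval()
```

Loading a state_dict also works with PyTorch's default weights_only loading, so no extra flags are needed.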

Examples
This command sends input data to the API and gets the prediction back.
PyTorch
curl -X POST http://localhost:5000/predict -H "Content-Type: application/json" -d '{"input": [0.5, 1.2, 3.3]}'
This Python code sends data to the REST API and prints the prediction result.
PyTorch
import requests

response = requests.post('http://localhost:5000/predict', json={'input': [0.5, 1.2, 3.3]})
print(response.json())
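During development you can also exercise the endpoint without starting a server at all, using Flask's built-in test client. A self-contained sketch, with a tiny untrained nn.Linear standing in for a real loaded model:

```python
from flask import Flask, request, jsonify
import torch
import torch.nn as nn

app = Flask(__name__)
model = nn.Linear(3, 2)  # stand-in for a real trained model
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    input_tensor = torch.tensor(data['input'], dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)
    return jsonify({'prediction': output.argmax(dim=1).item()})

# Exercise the route in-process; no running server or HTTP needed
client = app.test_client()
resp = client.post('/predict', json={'input': [0.5, 1.2, 3.3]})
print(resp.get_json())
```

This is also a convenient basis for automated tests of the API.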
Sample Model

This program creates a simple PyTorch model, saves it, then loads it in a Flask app. The app listens for POST requests at /predict, runs the model on input data, and returns the predicted class.

PyTorch
from flask import Flask, request, jsonify
import torch
import torch.nn as nn

# Define a simple model for demonstration
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)  # 3 inputs, 2 classes
    def forward(self, x):
        return self.linear(x)

app = Flask(__name__)

# Create and save a dummy model
model = SimpleModel()
torch.save(model, 'model.pth')

# Load the model
model = torch.load('model.pth', weights_only=False)  # weights_only=False is required to unpickle a full model object (PyTorch 2.6+ defaults to weights_only=True)
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    input_tensor = torch.tensor(data['input'], dtype=torch.float32).unsqueeze(0)  # batch size 1
    with torch.no_grad():
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()
    return jsonify({'prediction': prediction})

if __name__ == '__main__':
    app.run(debug=False)
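The endpoint above handles one sample per request, but the model itself already accepts a batch dimension, so the same forward pass can score several samples at once. A sketch using the same SimpleModel, with a batch of two inputs:

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)  # 3 inputs, 2 classes
    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
model.eval()

# A batch of two samples instead of a single unsqueezed one
batch = torch.tensor([[0.5, 1.2, 3.3],
                      [1.0, 0.0, 2.0]], dtype=torch.float32)
with torch.no_grad():
    output = model(batch)                    # shape (2, 2): one row per sample
predictions = output.argmax(dim=1).tolist()  # one class index per row
```

In the Flask handler this would mean accepting a list of lists in the JSON body and returning a list of predictions.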
Important Notes

Always set your model to evaluation mode with model.eval() before inference.

Use torch.no_grad() to avoid tracking gradients during prediction, which saves memory.
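The effect of torch.no_grad() is easy to see directly: tensors computed inside the block carry no autograd history, so PyTorch skips recording the graph. A small sketch:

```python
import torch

x = torch.ones(3, requires_grad=True)

y = x * 2            # normal mode: gradient history is recorded
with torch.no_grad():
    z = x * 2        # inside the block: no history, less memory used

print(y.requires_grad)  # True
print(z.requires_grad)  # False
```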

For production, use a production-grade WSGI server such as Gunicorn or uWSGI instead of Flask's built-in development server, which is single-threaded and not hardened for public traffic.

Summary

A REST API lets you send data and get predictions from your ML model over the web.

Flask is a simple way to create such an API in Python.

Remember to load your model once and use model.eval() and torch.no_grad() for efficient inference.