
REST API inference in PyTorch - ML Experiment: Train & Evaluate

Experiment - REST API inference
Problem: You have a trained PyTorch image classification model. You want to serve it so users can send images via a REST API and get predictions back.
Current Metrics: Model accuracy on test set: 88%. No API implemented yet.
Issue: No REST API exists to serve the model for real-time inference.
Your Task
Create a REST API using FastAPI that loads the trained PyTorch model and returns predictions for input images. The API should accept image files and respond with predicted class labels.
Use FastAPI for the REST API.
Load the PyTorch model once at startup.
Accept images as file uploads.
Return JSON with predicted class label.
Do not retrain or change the model.
Solution
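import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from torchvision import transforms

app = FastAPI()

# Model architecture; it must match the one used to create model.pth.
class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, 1)  # 3x3 kernel, stride 1, no padding
        self.fc = torch.nn.Linear(16 * 30 * 30, 10)  # 32x32 input -> 30x30 feature maps

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, 16*30*30)
        x = self.fc(x)
        return x

# Load the trained weights once, at import time (i.e. when the server starts).
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = SimpleCNN()
state_dict = torch.load('model.pth', map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode

# Image preprocessing; it must match the transforms used during training.
# The mean/std below are ImageNet statistics; replace them if training used others.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Placeholder labels; replace with the real class names in training index order.
class_names = [f'class_{i}' for i in range(10)]

# A plain `def` endpoint runs in FastAPI's thread pool, so the blocking
# model call does not stall the event loop.
@app.post('/predict')
def predict(file: UploadFile = File(...)):
    try:
        image = Image.open(io.BytesIO(file.file.read())).convert('RGB')
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail='Uploaded file is not a valid image')
    input_tensor = transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        outputs = model(input_tensor)
        _, predicted = torch.max(outputs, 1)
    predicted_class = class_names[predicted.item()]
    return {'predicted_class': predicted_class}  # FastAPI serializes the dict to JSON

# To run (file uploads also need the python-multipart package installed):
# uvicorn filename:app --reload   (replace filename with the module name, e.g. main for main.py)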
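Created a FastAPI app to serve the model.
Loaded the PyTorch model once at startup, with map_location set to the CPU so that weights saved on a GPU also load on a CPU-only server.
Added a '/predict' endpoint that accepts image uploads.
Preprocessed images with torchvision transforms (resize to 32x32, convert to a tensor, normalize) to produce the input the model expects.
Returned a JSON response containing the predicted class label.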
Results Interpretation

Before: No API, model only usable offline.

After: REST API accepts images and returns predictions in JSON, enabling real-time inference.

Serving a PyTorch model via a REST API allows easy integration of ML models into applications for real-time predictions.
Bonus Experiment
Add a batch prediction endpoint that accepts multiple images and returns predictions for all.
💡 Hint
Modify the API to accept a list of files, preprocess all, stack tensors, and return a list of predicted classes.