
How to Use torch.jit.trace in PyTorch for Model Tracing

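Use torch.jit.trace by passing a PyTorch model and an example input tensor. It runs the model once and records the operations that execute as a TorchScript graph. The resulting traced model can be saved and later loaded for deployment without a Python runtime (for example, from C++ with LibTorch), and the TorchScript runtime may apply graph optimizations to it.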

Syntax

The basic syntax of torch.jit.trace is:

  • torch.jit.trace(model, example_inputs): Runs model once on example_inputs and records the tensor operations that execute.
  • model: Your PyTorch nn.Module or callable.
  • example_inputs: A tensor or tuple of tensors that represent typical inputs to the model.

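When model is an nn.Module, the result is a ScriptModule (tracing a plain function returns a ScriptFunction). You can save it to a file and load it without Python, for example from C++ with LibTorch.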

python
traced_model = torch.jit.trace(model, example_inputs)
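When forward() takes more than one tensor, pass example_inputs as a tuple in the same order as the forward() arguments. Here is a minimal sketch that uses a hypothetical AddScaled module, invented only for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical two-input module, used only to illustrate tuple example_inputs
class AddScaled(nn.Module):
    def forward(self, a, b):
        return a + 0.5 * b

model = AddScaled()
a = torch.ones(2, 3)
b = torch.full((2, 3), 4.0)

# Multiple positional inputs go in a tuple, in the order forward() expects them
traced = torch.jit.trace(model, (a, b))

print(traced(a, b))  # every element is 1 + 0.5 * 4 = 3.0
```

The traced module is then called with the same number of positional tensors it was traced with.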

Example

This example shows how to trace a simple neural network and run the traced model.

python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)
    def forward(self, x):
        return self.linear(x)

model = SimpleModel().eval()  # inference mode before tracing (matters for Dropout/BatchNorm layers)

# Example input tensor
example_input = torch.randn(1, 3)

# Trace the model
traced_model = torch.jit.trace(model, example_input)

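# Run the traced model (no_grad skips autograd bookkeeping during inference)
with torch.no_grad():
    output = traced_model(example_input)
print(output)
Output (illustrative; the values differ on every run because the weights and input are random)
tensor([[ 0.1234, -0.5678]])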
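A traced model can be saved and reloaded without the original Python class. The sketch below rebuilds the SimpleModel from the example. It prints traced_model.code to show the TorchScript that tracing recorded, then round-trips the module through torch.jit.save and torch.jit.load. For convenience it assumes an in-memory io.BytesIO buffer. A file path such as "model.pt" works the same way, as does traced_model.save(path).

```python
import io

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel().eval()
example_input = torch.randn(1, 3)
traced_model = torch.jit.trace(model, example_input)

# Show the TorchScript code that tracing recorded
print(traced_model.code)

# Save to an in-memory buffer, then rewind it so it can be read back
buffer = io.BytesIO()
torch.jit.save(traced_model, buffer)
buffer.seek(0)

# Loading does not require the SimpleModel class definition
loaded = torch.jit.load(buffer)

with torch.no_grad():
    same = torch.allclose(traced_model(example_input), loaded(example_input))
print(same)  # True: the loaded module has the same weights and graph
```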

Common Pitfalls

Common mistakes when using torch.jit.trace include:

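  • Using example inputs that are not representative (for example, a different shape from real inputs). Tracing records only the operations executed for that one input, so shape-dependent Python logic such as loops gets baked in.
  • Tracing models with data-dependent control flow (such as if statements on tensor values). The branch taken during tracing is recorded as if it were unconditional.
  • Changing the model after tracing without retracing (its layers, its forward code, or by assigning new parameter tensors). The traced module is a snapshot taken at trace time.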
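Here is one way an unrepresentative input goes wrong. A Python loop whose length depends on the input's shape is unrolled to the length seen at trace time. The LoopSum and TensorSum modules below are hypothetical. They were written only to show the effect and a tensor-op rewrite that traces correctly:

```python
import torch
import torch.nn as nn

# Hypothetical module that sums rows with a Python loop over the first dimension
class LoopSum(nn.Module):
    def forward(self, x):
        total = x[0]
        for i in range(1, x.size(0)):  # Python loop: tracing records only the iterations that ran
            total = total + x[i]
        return total

# Hypothetical rewrite using a single tensor operation
class TensorSum(nn.Module):
    def forward(self, x):
        return x.sum(dim=0)  # one op, so the trace works for any number of rows

model = LoopSum()
traced = torch.jit.trace(model, torch.ones(2, 3))  # 2 rows, so the loop body runs once
traced_fixed = torch.jit.trace(TensorSum(), torch.ones(2, 3))

longer = torch.ones(4, 3)
print(model(longer))         # eager: tensor([4., 4., 4.])
print(traced(longer))        # traced: tensor([2., 2., 2.]), still adds only rows 0 and 1
print(traced_fixed(longer))  # tensor([4., 4., 4.])
```

Expressing the computation as tensor operations keeps the trace valid for other shapes. When that is not possible, scripting is the usual alternative.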

For models with dynamic control flow, use torch.jit.script instead.

python
import torch
import torch.nn as nn

# Model with control flow
class ControlFlowModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return x - 2

model = ControlFlowModel()
example_input = torch.tensor([1.0])

# Incorrect: tracing only one path
traced = torch.jit.trace(model, example_input)

# This will always follow the traced path, ignoring other conditions
print(traced(torch.tensor([-1.0])))  # Wrong output

# Correct: use scripting
scripted = torch.jit.script(model)
print(scripted(torch.tensor([-1.0])))  # Correct output
Output
tensor([-2.])
tensor([-3.])
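Tracing the ControlFlowModel above typically emits a TracerWarning about converting a tensor to a Python boolean. If you want to flag such traces automatically, the standard warnings module can record them during tracing. This is a sketch under that assumption, not a complete validation strategy:

```python
import warnings

import torch
import torch.nn as nn

class ControlFlowModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return x - 2

model = ControlFlowModel()

# Record every warning raised while tracing instead of letting it scroll by
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(model, torch.tensor([1.0]))

tracer_warnings = [w for w in caught if issubclass(w.category, torch.jit.TracerWarning)]
for w in tracer_warnings:
    print(w.message)

# Treat any TracerWarning as a sign the trace may only be valid for the example input
if tracer_warnings:
    print("Trace may be input-specific; consider torch.jit.script")
```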

Quick Reference

| Function | Description |
| --- | --- |
| torch.jit.trace(model, example_inputs) | Trace model operations with example inputs to create a TorchScript module. |
| traced_model.save(path) | Save the traced model to a file. |
| torch.jit.script(model) | Compile models with dynamic control flow (alternative to trace). |
| traced_model(input) | Run the traced model with input tensor(s). |

Key Takeaways

  • Use torch.jit.trace with example inputs to record model operations as a static graph.
  • Tracing does not capture data-dependent control flow; use torch.jit.script for that.
  • Provide representative example inputs that match the shapes and types the model will see in use.
  • Traced models can be saved and loaded for deployment without Python.
  • Retrace the model after changing its layers or forward code, or after assigning new parameter tensors.