
How to Use torch.jit.script in PyTorch for Model Scripting

Use torch.jit.script to compile a PyTorch model or function into TorchScript, a statically typed subset of Python that PyTorch can optimize and serialize for deployment. Pass your model or function to torch.jit.script, and it returns a scripted version that can be saved and later run without a Python interpreter, for example from C++ with LibTorch.

Syntax

Pass an nn.Module instance or a Python function to torch.jit.script. It returns the compiled equivalent: a ScriptModule for a module, or a ScriptFunction for a function.

The result can be saved with its save method and loaded again with torch.jit.load, including from a C++ program that has no Python interpreter.

python
scripted_model = torch.jit.script(model_or_function)
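
The same call works on a plain function. The sketch below is an illustration only: running_total is a hypothetical helper that is not part of PyTorch. It scripts a function containing a loop, calls it, and prints the TorchScript source that was generated.

```python
import torch

# Hypothetical helper used only for illustration.
def running_total(x: torch.Tensor, steps: int) -> torch.Tensor:
    total = torch.zeros_like(x)
    # The loop is compiled as a real loop, not unrolled for one value of steps.
    for i in range(steps):
        total = total + x
    return total

scripted_fn = torch.jit.script(running_total)

result = scripted_fn(torch.tensor([1.0, 2.0]), 3)
print(result)            # tensor([3., 6.])

# The generated TorchScript source is available as a string.
print(scripted_fn.code)
```

Type annotations on the arguments are worth adding, because TorchScript assumes unannotated arguments are tensors.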

Example

This example scripts a simple PyTorch model with torch.jit.script. The scripted model returns the same outputs as the original for the same input, and it can additionally be saved to disk and loaded without the Python class definition.

python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x) + 1

model = SimpleModel()
scripted_model = torch.jit.script(model)

# Test input
input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# Original model output
original_output = model(input_tensor)

# Scripted model output
scripted_output = scripted_model(input_tensor)

print("Original output:", original_output)
print("Scripted output:", scripted_output)
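Output (illustrative values; the exact numbers differ between runs because nn.Linear is randomly initialized, but the two tensors always match each other)
Original output: tensor([[1.1234, 1.5678],
        [2.3456, 2.7890]], grad_fn=<AddBackward0>)
Scripted output: tensor([[1.1234, 1.5678],
        [2.3456, 2.7890]], grad_fn=<AddBackward0>)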
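
Saving and loading can be sketched as a round trip through a temporary file. The file name simple_model.pt and the use of tempfile are assumptions made for this illustration. SimpleModel is repeated so that the block runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x) + 1

scripted_model = torch.jit.script(SimpleModel().eval())
batch = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical file name, chosen only for this example.
    path = os.path.join(tmp_dir, "simple_model.pt")
    scripted_model.save(path)

    # torch.jit.load rebuilds the model from the file alone;
    # the SimpleModel class is not needed at load time.
    loaded_model = torch.jit.load(path)
    round_trip_ok = torch.allclose(scripted_model(batch), loaded_model(batch))

print("Loaded model matches:", round_trip_ok)
```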

Common Pitfalls

Common mistakes when using torch.jit.script include:

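  • Trying to script code that uses Python features TorchScript does not support, such as variables that change type, generators, or calls into libraries like NumPy. Control flow such as if statements and loops is supported.
  • Assuming only forward is compiled. Scripting a module also compiles its submodules and every method that forward calls, so unsupported code anywhere in that call chain makes scripting fail.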
  • Confusing torch.jit.script with torch.jit.trace; scripting analyzes code logic, while tracing records operations on example inputs.

Always test scripted models to ensure outputs match the original.
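
One way to run that check is sketched below, assuming the SimpleModel class from the earlier example (repeated here so the block is self-contained). It uses torch.testing.assert_close, which raises an AssertionError if the two outputs differ beyond a small tolerance.

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x) + 1

torch.manual_seed(0)  # fixed seed so the check is reproducible
model = SimpleModel().eval()
scripted_model = torch.jit.script(model)

batch = torch.randn(4, 2)
with torch.no_grad():
    expected = model(batch)
    actual = scripted_model(batch)

# Raises AssertionError if the outputs are not numerically close.
torch.testing.assert_close(actual, expected)
outputs_match = torch.allclose(actual, expected)
print("Outputs match:", outputs_match)
```

Checking several inputs that exercise different branches gives more confidence than a single batch.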

The function below branches on a tensor value. This is supported control flow, so scripting it succeeds:

python
import torch

def double_or_halve(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: torch.jit.script compiles both paths
    if x.sum() > 0:
        return x * 2
    else:
        return x / 2

# Scripting keeps the if/else, so the branch is chosen per input at run time
scripted_func = torch.jit.script(double_or_halve)
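
To see why the difference between scripting and tracing matters, the sketch below applies both to a function with the same branch as the one above (named double_or_halve here). Tracing with a positive example input records only the x * 2 path, so the traced version gives the wrong answer for a negative input.

```python
import warnings

import torch

def double_or_halve(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:
        return x * 2
    else:
        return x / 2

scripted = torch.jit.script(double_or_halve)

with warnings.catch_warnings():
    # Tracing warns that converting a tensor to a Python bool
    # may make the trace incorrect; silenced here for readability.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(double_or_halve, torch.ones(3))

negative = -torch.ones(3)
scripted_out = scripted(negative)  # takes the else branch
traced_out = traced(negative)      # replays the recorded x * 2 path

print(scripted_out)  # tensor([-0.5000, -0.5000, -0.5000])
print(traced_out)    # tensor([-2., -2., -2.])
```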

Quick Reference

  • torch.jit.script(model_or_function): Converts model or function to TorchScript.
  • scripted_model.save(path): Saves scripted model to disk.
  • torch.jit.load(path): Loads scripted model from disk.
  • Use scripting for models with dynamic control flow.
  • Use tracing (torch.jit.trace) for simple, static models.

Key Takeaways

Use torch.jit.script to convert PyTorch models/functions into optimized TorchScript code.
Scripting compiles the code itself, so it preserves data-dependent control flow, whereas tracing records only the operations executed for one example input.
Always test scripted models to confirm they produce the same outputs as original models.
Scripted models can be saved to disk and loaded without the original Python class, including in environments with no Python interpreter.
Avoid unsupported Python features inside models when scripting.