Introduction
TorchScript export serializes a PyTorch model into a format that can be optimized and loaded without a Python interpreter. This is useful when you want to deploy the model inside an application or share it easily.
Jump into concepts and practice — no test required.
import torch

# Quick reference: exporting a model to TorchScript.
# NOTE(review): `your_model` (an nn.Module) and `example_input` (a sample
# input accepted by its forward method) are placeholders -- define them
# before running this snippet.

# Option 1: convert a PyTorch model to TorchScript by compiling its code.
# Scripting keeps Python control flow such as `if` statements.
scripted_model = torch.jit.script(your_model)

# Option 2 (alternative to option 1): trace the model with an example input.
# Tracing records only the operations executed for `example_input`, so any
# data-dependent branch is frozen to the path that input happened to take.
scripted_model = torch.jit.trace(your_model, example_input)

# Save the TorchScript model (whichever of the assignments above ran last).
scripted_model.save('model_scripted.pt')
# Key statements from the quick reference above; they assume `model` and
# `example_input` already exist.
scripted_model = torch.jit.script(model)
scripted_model = torch.jit.trace(model, example_input)
scripted_model.save('model_scripted.pt')

# Complete example: script a model whose forward() contains an `if`,
# save it, load it back, and run inference.
import torch
import torch.nn as nn


# Define a simple model
class SimpleModel(nn.Module):
    """Linear layer applied to x or -x depending on the sign of x.sum()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        # x.sum() reduces over every element (the whole batch), so one branch
        # is chosen per call, not per sample.
        if x.sum() > 0:
            return self.linear(x)
        else:
            return self.linear(-x)


model = SimpleModel()

# Create example input
example_input = torch.randn(1, 3)

# Convert model to TorchScript using scripting (because of if condition):
# tracing would record only the branch taken for `example_input`.
scripted_model = torch.jit.script(model)

# Save the scripted model
scripted_model.save('simple_model_scripted.pt')

# Load the model back
loaded_model = torch.jit.load('simple_model_scripted.pt')

# Run inference
output = loaded_model(example_input)
print(f'Input: {example_input}')
print(f'Output: {output}')
Question: what does `print(traced_model(torch.tensor([2.0])))` output in the following code?
import torch


class SimpleModel(torch.nn.Module):
    """Multiply the input by 3 (no control flow, so tracing is safe)."""

    def forward(self, x):
        return x * 3


model = SimpleModel()
example_input = torch.tensor([1.0])
# forward() has no data-dependent branches, so the trace recorded for
# example_input is valid for other inputs of the same kind.
traced_model = torch.jit.trace(model, example_input)
print(traced_model(torch.tensor([2.0])))

# Start of the next example.
import torch
class MyModel(torch.nn.Module):
    """Toy model whose output depends on the sign of ``x.sum()``."""

    def forward(self, x):
        # Data-dependent control flow: which branch runs depends on the
        # values in x, not just on its shape.
        if x.sum() > 0:
            return x * 2
        else:
            return x - 2


model = MyModel()
# Use scripting, not tracing: torch.jit.trace(model, torch.tensor([1.0]))
# would record only the `x * 2` branch taken for that example input, so the
# exported model would return x * 2 even when x.sum() <= 0.
scripted_model = torch.jit.script(model)

# Quiz: MyModel has an `if` statement in its forward method. Which approach
# should you use to export it with TorchScript, and why?