This example requires the torchvision package: https://github.com/pytorch/vision/ . Please note that SOL does not support the use of model.eval() or model.train(). SOL always assumes model.eval() when running inference and model.train() when running training.
import torch
import sol
import torchvision.models as models
''' Training in PyTorch requires the use of a loss function at the end of the network
that is normally not part of the model structure. To add the loss function to the SOL
model, you can embed it in a wrapper model like this. '''
class TrainingModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.m_model = model
        self.m_loss  = torch.nn.L1Loss()

    def forward(self, x, target):
        output = self.m_model(x)
        loss   = self.m_loss(output, target)
        return sol.no_grad(output), loss

py_model = models.__dict__["alexnet"]()
input    = torch.rand(32, 3, 224, 224)
target   = torch.rand(32, 1000)
opt      = sol.optimize(TrainingModel(py_model), input, target)
''' Run training '''
opt.train()
for batch in ...:
    input, target = ...
    output, loss = opt(input, target)
    loss.backward()
    ...
''' Run validation '''
opt.eval()
with torch.no_grad():
    for batch in ...:
        input, target = ...
        output = opt(input, target)[0]
        ...
''' Deployment needs to be called on the PyTorch model. Direct deployment of SOL models is not supported yet. '''
py_model.load_state_dict(opt.state_dict(), strict=False)
sol.deploy(py_model, sol.input([1, 3, 224, 224]), sol.deployment.SharedLib, sol.device.x86, libName="libmynetwork", funcName="predict", deployPath=".")
IMPORTANT: SOL does not provide the “num_batches_tracked” value for BatchNormalization; therefore, loading the state dict with load_state_dict(…, strict=True) will fail in these cases!
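The following sketch is a minimal illustration of this behavior using plain PyTorch (no SOL involved); it assumes that the num_batches_tracked buffers are the only entries missing from the state dict, as described above. With strict=False, load_state_dict reports the unmatched keys instead of raising an error.

import torch
import torchvision.models as models

# Simulate a state dict that lacks the "num_batches_tracked" buffers,
# as is the case for state dicts produced by SOL.
model = models.resnet18()
state = {k: v for k, v in model.state_dict().items() if "num_batches_tracked" not in k}

# strict=False reports the missing buffers instead of raising a RuntimeError.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # only the num_batches_tracked entries
print(result.unexpected_keys)  # empty list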
Please refer to https://pytorch.org/docs/stable/ for how these functions are used. This documentation only lists which layers, functions, and tensor functionality are currently implemented within SOL.
Please see the following list of all supported operators. Be advised that the torch.nn.XXX modules, e.g. Conv2d, are implemented through the corresponding torch.nn.functional.XXX functions, as the sketch below illustrates.
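For illustration, the short sketch below uses plain PyTorch (it is not SOL-specific) to show what this mapping means: a torch.nn.Conv2d module computes the same result as torch.nn.functional.conv2d called with the module's own weight and bias, so supporting the functional form covers the module form as well.

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x    = torch.rand(1, 3, 224, 224)

# The module call and the functional call with the same parameters are equivalent.
a = conv(x)
b = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)
print(torch.allclose(a, b))  # True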