PyTorch

This example requires the torchvision package (https://github.com/pytorch/vision/). Please note that SOL does not support the use of model.eval() or model.train(). SOL always assumes model.eval() when running inference and model.train() when running training.
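The mode matters because some PyTorch layers, such as Dropout and BatchNorm, behave differently during training and inference. The following minimal sketch uses only standard PyTorch (no SOL) to show the difference for Dropout. It illustrates what the train/eval distinction means. It does not show how SOL handles it internally.

```python
import torch

# Standard PyTorch behavior (not SOL-specific).
torch.manual_seed(0)
layer = torch.nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: elements are randomly zeroed, survivors are scaled by 1 / (1 - p).
layer.train()
train_out = layer(x)

# Inference mode: Dropout passes the input through unchanged.
layer.eval()
eval_out = layer(x)

print(set(train_out.tolist()))   # a subset of {0.0, 2.0}
print(torch.equal(eval_out, x))  # True
```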

import torch
import sol
import torchvision.models as models

''' Training in PyTorch requires a loss function at the end of the network
that is normally not part of the model itself. To include the loss function in
the SOL model, embed it in a wrapper model like this. '''

class TrainingModel(torch.nn.Module):
	def __init__(self, model):
		super().__init__()
		self.m_model = model
		self.m_loss  = torch.nn.L1Loss()

	def forward(self, input, target):
		output = self.m_model(input)
		loss = self.m_loss(output, target)
		# return the prediction (wrapped in sol.no_grad) and the loss used for backward()
		return sol.no_grad(output), loss

py_model = models.__dict__["alexnet"]()
input	= torch.rand(32, 3, 224, 224)
target	= torch.rand(32, 1000)

opt = sol.optimize(TrainingModel(py_model), input, target)

''' Run training '''
opt.train()
for batch in ...:
	input, target = ...
	output, loss = opt(input, target)
	loss.backward()
	...

''' Run validation '''
opt.eval()
with torch.no_grad():
	for batch in ...:
		input, target = ...
		output = opt(input, target)[0]
		...
		
''' Deployment needs to be called on the PyTorch model. Direct deployment of SOL models is not supported yet. '''
py_model.load_state_dict(opt.state_dict(), strict=False)
sol.deploy(py_model, sol.input([1, 3, 224, 224]), sol.deployment.SharedLib, sol.device.x86, libName="libmynetwork", funcName="predict", deployPath=".")
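To show what the TrainingModel wrapper does independently of SOL, here is a minimal plain-PyTorch sketch of the same pattern with a tiny linear model and one SGD step. It assumes that output.detach() plays a role comparable to sol.no_grad(output). This is an illustrative stand-in, not SOL's documented behavior. The model size, learning rate, and data are arbitrary.

```python
import torch

class TrainingModel(torch.nn.Module):
    """Wraps a model and its loss so that one forward call returns both."""

    def __init__(self, model):
        super().__init__()
        self.m_model = model
        self.m_loss = torch.nn.L1Loss()

    def forward(self, x, target):
        output = self.m_model(x)
        loss = self.m_loss(output, target)
        # Plain-PyTorch stand-in (assumption) for sol.no_grad(output):
        # the returned prediction is cut off from autograd.
        return output.detach(), loss

torch.manual_seed(0)
model = TrainingModel(torch.nn.Linear(8, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.rand(16, 8)
target = torch.rand(16, 4)

# One training step: forward through the wrapper, backpropagate the loss, update.
model.train()
optimizer.zero_grad()
output, loss = model(inputs, target)
loss.backward()
optimizer.step()

# Validation: only the first return value (the prediction) is used.
model.eval()
with torch.no_grad():
    val_output = model(inputs, target)[0]
```

Returning the loss from forward() keeps the whole training graph, including the loss, inside one module, which is what lets a single optimized model cover it.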

IMPORTANT: SOL does not provide the “num_batches_tracked” value for BatchNormalization layers. Therefore, loading the state dict with load_state_dict(…, strict=True) fails for models that contain BatchNormalization layers!
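When loading with strict=False, you may still want to confirm that only the BatchNorm counters were skipped. The following is a minimal plain-PyTorch sketch. load_without_bn_counters is a hypothetical helper, not part of SOL or PyTorch. The state dict without counters is simulated by deleting keys rather than produced by SOL.

```python
import torch

def load_without_bn_counters(model, state_dict):
    """Hypothetical helper: load with strict=False, then verify that the only
    missing entries are BatchNorm 'num_batches_tracked' buffers."""
    result = model.load_state_dict(state_dict, strict=False)
    other_missing = [k for k in result.missing_keys
                     if not k.endswith("num_batches_tracked")]
    if other_missing or result.unexpected_keys:
        raise RuntimeError(
            f"missing: {other_missing}, unexpected: {result.unexpected_keys}")
    return result

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))

# Simulate a state dict without the BatchNorm counter. Deleting from the
# OrderedDict returned by state_dict() keeps its version metadata; without
# that metadata, BatchNorm would silently fill in a default counter.
state = model.state_dict()
for key in [k for k in state if k.endswith("num_batches_tracked")]:
    del state[key]

result = load_without_bn_counters(model, state)
print(result.missing_keys)  # ['1.num_batches_tracked']
```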

Tested Networks

  • TorchVision
    • AlexNet
    • SqueezeNet (1.0, 1.1)
    • VGG (11, 13, 16, 19, w/ and w/o batchnorm)
    • ResNet (18, 34, 50, 101, 152)
    • DenseNet (121, 161, 169, 201)
    • Inception v3
    • MobileNet v2
    • MNASNet (0.5, 0.75, 1.0, 1.3)
    • ShuffleNet v2 (0.5, 1.0, 1.5, 2.0)
    • ResNeXt (50, 101)
    • Wide ResNet (50, 101)
    • GoogLeNet
  • MobileNet
  • PyTorchic BERT
  • HuggingFace
    • BERT
    • GPT-2

Supported Layers

Please refer to https://pytorch.org/docs/stable/ for how these functions are used. This documentation only lists which layers, functions, and tensor operations are currently implemented in SOL.

Layers

Please see the following list of all supported operators. Be advised that torch.nn.XXX modules, e.g. Conv2d, are implemented through the corresponding torch.nn.functional.XXX functions.
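To make the module-to-function mapping concrete, the following sketch (plain PyTorch, independent of SOL) checks that a torch.nn.Conv2d module produces the same result as calling torch.nn.functional.conv2d with the module's own weight, bias, and hyperparameters. It only illustrates the correspondence in PyTorch itself. How SOL maps these functions internally is not shown.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
x = torch.rand(2, 3, 16, 16)

# Calling the module...
module_out = conv(x)

# ...is equivalent to calling the functional form with the module's parameters.
functional_out = F.conv2d(x, conv.weight, conv.bias,
                          stride=conv.stride, padding=conv.padding,
                          dilation=conv.dilation, groups=conv.groups)

print(torch.allclose(module_out, functional_out))  # True
```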