PyTorch

This example requires the torchvision package: https://github.com/pytorch/vision/ . Please note that SOL does not support the use of model.eval() or model.train(). SOL always assumes model.eval() when running inference and model.train() when running training.
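For plain inference no loss wrapper is needed. The following is a minimal sketch; it assumes sol.optimize accepts the model followed by example inputs, matching the sol.optimize usage in the training example below.

import torch
import sol
import torchvision.models as models

# Minimal inference-only sketch (assumed call pattern, analogous to the
# training example below).
py_model = models.resnet50()
example  = torch.rand(1, 3, 224, 224)
opt      = sol.optimize(py_model, example)

opt.eval()
with torch.no_grad():
	output = opt(example)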

import torch
import sol
import torchvision.models as models

''' Training in PyTorch requires a loss function at the end of the network
that is normally not part of the model structure. To add the loss function to
the SOL model, you can embed the model in a wrapper like this. '''

class TrainingModel(torch.nn.Module):
	def __init__(self, model):
		super().__init__()
		self.m_model = model
		self.m_loss  = torch.nn.L1Loss()
		
	def forward(self, x, target):
		output = self.m_model(x)
		loss   = self.m_loss(output, target)
		return sol.no_grad(output), loss

py_model = models.__dict__["alexnet"]()
input	= torch.rand(32, 3, 224, 224)
target	= torch.rand(32, 1000)

opt = sol.optimize(TrainingModel(py_model), input, target)

''' Run training '''
opt.train()
for batch in ...:
	input, target = ...
	output, loss = opt(input, target)
	loss.backward()
	...
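A fully spelled-out training step could look like the sketch below. The optimizer setup is an assumption on top of the example above; it presumes the optimized model exposes its parameters like a regular torch.nn.Module.

# Assumption: 'opt' behaves like an nn.Module and exposes parameters().
optimizer = torch.optim.SGD(opt.parameters(), lr=0.01)
opt.train()
for input, target in loader:  # 'loader' is a placeholder for your DataLoader
	optimizer.zero_grad()
	output, loss = opt(input, target)
	loss.backward()
	optimizer.step()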

''' Run validation '''
opt.eval()
with torch.no_grad():
	for batch in ...:
		input, target = ...
		output = opt(input, target)[0]
		...
		
''' Deployment must be called on the PyTorch model. Direct deployment of SOL models is not supported yet. '''
py_model.load_state_dict(opt.state_dict(), strict=False)
sol.deploy(py_model, sol.input([1, 3, 224, 224]), sol.deployment.SharedLib, sol.device.x86, libName="libmynetwork", funcName="predict", deployPath=".")

IMPORTANT: SOL does not provide the “num_batches_tracked” value for BatchNormalization; therefore, loading the state dict with load_state_dict(…, strict=True) will fail in these cases!
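With strict=False the call still reports which keys could not be matched, so you can verify that only the num_batches_tracked entries are affected. This is a small sketch using standard PyTorch behavior and the variable names from the example above.

result = py_model.load_state_dict(opt.state_dict(), strict=False)
# Only the BatchNorm num_batches_tracked buffers are expected to be missing.
missing = [k for k in result.missing_keys if not k.endswith("num_batches_tracked")]
assert not missing, f"Unexpected missing keys: {missing}"
assert not result.unexpected_keys, f"Unexpected keys: {result.unexpected_keys}"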

F.A.Q.

Error: The SOL model returns more outputs than the PyTorch model.
Solution: This error occurs, e.g., in TorchVision's Inception V3 or GoogLeNet. These models return one output in inference mode and two outputs in training mode. SOL relies on the TorchScript parser. Unfortunately, the TorchVision models are built in a way that hides this change of output behavior from TorchScript. However, you can implement it yourself as follows:
import torch
from torchvision import models

class Wrap(torch.nn.Module):
	def __init__(self, model):
		super().__init__()
		self.model = model

	def forward(self, x):
		out = self.model(x)
		if torch.jit.is_scripting():
			return (out[0], out[1]) if self.training else (out[0], None)
		return (out[0], out[1]) if self.training else (out, None)

model = Wrap(models.inception_v3())

# use only one output
model.training = False
sol_model = sol.optimize(model, ...)

# use two outputs
model.training = True
sol_model = sol.optimize(model, ...)

SOL currently does not support dynamically switching between these two modes and requires compiling the model separately for each mode.
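If both modes are needed in the same script, one possible pattern (a sketch based on the example above) is to keep two separately optimized models and select between them:

model   = Wrap(models.inception_v3())
example = torch.rand(1, 3, 299, 299)

model.training = True
sol_train = sol.optimize(model, example)  # compiled with two outputs
model.training = False
sol_infer = sol.optimize(model, example)  # compiled with one output

def run(x, training):
	# Dispatch to the SOL model compiled for the requested mode.
	return sol_train(x) if training else sol_infer(x)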

Supported Layers

Please refer to https://pytorch.org/docs/stable/ for how these functions are used. This documentation only lists which layers, functions, and tensor operations are currently implemented within SOL.

Layers

The following list contains all supported operators. Since v0.4.2, SOL uses TorchScript for parsing the neural network.

  • aten::Float
  • aten::IntImplicit
  • aten::ScalarImplicit
  • aten::__getitem__
  • aten::__is__
  • aten::__isnot__
  • aten::__not__
  • aten::_set_item
  • aten::abs
  • aten::acos
  • aten::acosh
  • aten::adaptive_avg_pool1d
  • aten::adaptive_avg_pool2d
  • aten::adaptive_avg_pool3d
  • aten::adaptive_max_pool1d
  • aten::adaptive_max_pool2d
  • aten::adaptive_max_pool3d
  • aten::add
  • aten::add_
  • aten::addbmm
  • aten::addcdiv
  • aten::addcmul
  • aten::addmm
  • aten::append
  • aten::arange
  • aten::asin
  • aten::asinh
  • aten::atan
  • aten::atanh
  • aten::avg_pool1d
  • aten::avg_pool2d
  • aten::avg_pool3d
  • aten::batch_norm
  • aten::broadcast_tensors
  • aten::cat
  • aten::ceil
  • aten::celu
  • aten::chunk
  • aten::constant_pad_nd
  • aten::contiguous
  • aten::conv1d
  • aten::conv2d
  • aten::cos
  • aten::cosh
  • aten::dict
  • aten::dim
  • aten::div
  • aten::dropout
  • aten::dropout_
  • aten::elu
  • aten::embedding
  • aten::eq
  • aten::erf
  • aten::exp
  • aten::expand
  • aten::flatten
  • aten::floor
  • aten::floordiv
  • aten::fmod
  • aten::format
  • aten::ge
  • aten::gelu
  • aten::gru
  • aten::gru_cell
  • aten::gt
  • aten::hardshrink
  • aten::hardtanh
  • aten::hardtanh_
  • aten::isfinite
  • aten::isinf
  • aten::isnan
  • aten::l1_loss
  • aten::layer_norm
  • aten::le
  • aten::leaky_relu
  • aten::len
  • aten::linear
  • aten::list
  • aten::log
  • aten::log10
  • aten::log2
  • aten::log_sigmoid
  • aten::log_softmax
  • aten::logaddexp
  • aten::logaddexp2
  • aten::logical_and
  • aten::logical_or
  • aten::logical_xor
  • aten::lstm
  • aten::lstm_cell
  • aten::lt
  • aten::matmul
  • aten::max
  • aten::max_pool1d
  • aten::max_pool1d_with_indices
  • aten::max_pool2d
  • aten::max_pool2d_with_indices
  • aten::max_pool3d
  • aten::max_pool3d_with_indices
  • aten::maximum
  • aten::mean
  • aten::min
  • aten::minimum
  • aten::mse_loss
  • aten::mul
  • aten::narrow
  • aten::ne
  • aten::neg
  • aten::neg_
  • aten::ones
  • aten::ones_like
  • aten::permute
  • aten::pow
  • aten::prelu
  • aten::reciprocal
  • aten::relu
  • aten::relu6
  • aten::relu_
  • aten::remainder
  • aten::rnn_relu
  • aten::rnn_relu_cell
  • aten::rnn_tanh
  • aten::rnn_tanh_cell
  • aten::rsqrt
  • aten::select
  • aten::selu
  • aten::sigmoid
  • aten::sign
  • aten::sin
  • aten::sinh
  • aten::size
  • aten::slice
  • aten::smooth_l1_loss
  • aten::softmax
  • aten::softmin
  • aten::softplus
  • aten::softshrink
  • aten::split
  • aten::sqrt
  • aten::squeeze
  • aten::stack
  • aten::sub
  • aten::sum
  • aten::tan
  • aten::tanh
  • aten::transpose
  • aten::unsqueeze
  • aten::view
  • aten::warn
  • aten::where
  • aten::zeros
  • aten::zeros_like
  • prim::CallFunction
  • prim::CallMethod
  • prim::Constant
  • prim::GetAttr
  • prim::If
  • prim::ListConstruct
  • prim::ListIndex
  • prim::ListUnpack
  • prim::Loop
  • prim::RaiseException
  • prim::SetAttr
  • prim::TupleConstruct
  • prim::TupleIndex
  • prim::TupleUnpack
  • prim::Uninitialized
  • prim::device
  • prim::dtype
  • prim::isinstance
  • prim::unchecked_cast