PyTorch Integration
In this document, we will show how to directly compile PyTorch models to Allo. First, users can define a PyTorch module as usual:
import torch
import torch.nn.functional as F
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y):
        x = x + y
        x = F.relu(x)
        return x

model = Model()
model.eval()
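Before compiling, it can help to confirm the module runs in eager mode. The following is a quick PyTorch-only sanity check (not part of the Allo API); the shapes here mirror the example inputs used below:

# Pure PyTorch sanity check, run before handing the model to Allo.
x = torch.rand(1, 3, 10, 10)
y = torch.rand(1, 3, 10, 10)
out = model(x, y)
print(out.shape)  # torch.Size([1, 3, 10, 10]); add + relu preserve the shape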
Then, users can compile the PyTorch model to Allo using the allo.frontend.from_pytorch API:
import allo
example_inputs = [torch.rand(1, 3, 10, 10), torch.rand(1, 3, 10, 10)]
llvm_mod = allo.frontend.from_pytorch(model, example_inputs=example_inputs)
We can then invoke the generated Allo LLVM module as usual by passing in NumPy inputs:
# Reference output from the original PyTorch model
golden = model(*example_inputs)
# Convert the example inputs to NumPy arrays for the LLVM module
np_inputs = [x.detach().numpy() for x in example_inputs]
res = llvm_mod(*np_inputs)
# Verify the Allo result against the PyTorch reference
torch.testing.assert_close(res, golden.detach().numpy())
print("Passed!")
The process should be very similar to the original Allo workflow.
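The same flow extends to modules with learnable parameters. As a minimal sketch, assuming the frontend supports standard layers such as nn.Linear (operator coverage is an assumption to verify against the Allo documentation), one might write:

import allo
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical example layer; whether a given operator is
        # supported by from_pytorch is an assumption to check.
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.fc(x))

mlp = MLP()
mlp.eval()
inputs = [torch.rand(4, 10)]
mod = allo.frontend.from_pytorch(mlp, example_inputs=inputs)
res = mod(inputs[0].detach().numpy())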
The default target is LLVM. We can also change the backend to other compilers such as Vitis HLS by specifying the target:
mod = allo.frontend.from_pytorch(model, example_inputs=example_inputs, target="vhls")
print(mod.hls_code)
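Assuming mod.hls_code is a plain string (it is printed above), it can also be written to a file for use in a Vitis HLS project. This is ordinary Python I/O, not an Allo-specific API, and the file name is arbitrary:

# Save the generated HLS C++ to a file (file name chosen for illustration).
with open("kernel.cpp", "w") as f:
    f.write(mod.hls_code)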