-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_model.py
32 lines (23 loc) · 915 Bytes
/
create_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from pathlib import Path
import torch
import torch.nn as nn
from torchvision.transforms import GaussianBlur
class SimpleModel(nn.Module):
"""A simple example torch model containing only a gaussian blur"""
def __init__(self):
super().__init__()
self.transform = GaussianBlur(kernel_size=3, sigma=(1, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
y = self.transform(x)
return y
def create_model(model_path: Path):
"""Create and save an example jit model"""
model = SimpleModel()
example_input = torch.rand(1, 255, 255)
jit_model = torch.jit.trace(model, example_inputs=example_input)
print(f'Saving model to: {model_path.absolute()}')
torch.jit.save(jit_model, model_path)
if __name__ == "__main__":
model_path = Path(__file__).parent / "resources/model.pth"
create_model(model_path)