Commit message:

* Move Python torchscripting script to utils folder.
* Add README for utils (WIP).
* Create examples directory with basic README listing examples to add.
* Move ts_inference programs to examples directory with a copy of CMakeLists and modify library CMake to no longer build and install them.
* Move case-specific pt2ts to examples, and replace with generic tool.
* Restructure and re-write ResNet example. WIP (needs multi-input merge from main).
* Update resnet example to work.
* Complete Python-Fortran example.
* Move C and C++ code from example 1 to a separate file.
* Add Makefile build option to ResNet example.
* Tidy ResNet README.
* Update resnet example to take model as command line argument.
* Update README files as appropriate for these changes.
1 parent 448dc72; commit dbab6a6. Showing 23 changed files with 866 additions and 194 deletions.
CMakeLists.txt
@@ -0,0 +1,19 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
# policy CMP0076 - target_sources source files are relative to the file where target_sources is run
cmake_policy(SET CMP0076 NEW)

set(PROJECT_NAME ResNetExample)

project(${PROJECT_NAME} LANGUAGES Fortran)

# Build in Debug mode if not specified
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Debug CACHE STRING "" FORCE)
endif()

find_package(FTorch)
message(STATUS "Building with Fortran PyTorch coupling")

# Fortran example
add_executable(resnet_infer_fortran resnet_infer_fortran.f90)
target_link_libraries(resnet_infer_fortran PRIVATE FTorch::ftorch)
Makefile
@@ -0,0 +1,25 @@
# compiler
# Note - this should match the compiler that the library was built with
FC = gfortran

# compile flags
FCFLAGS = -O3 -I</path/to/installation>/include/ftorch

# link flags
LDFLAGS = -L</path/to/installation>/lib64/ -lftorch

PROGRAM = resnet_infer_fortran
SRC = resnet_infer_fortran.f90
OBJECTS = $(SRC:.f90=.o)

all: $(PROGRAM)

$(PROGRAM): $(OBJECTS)
	$(FC) $(FCFLAGS) -o $@ $^ $(LDFLAGS)

# Link flags are not needed when compiling with -c, so pass only FCFLAGS here
%.o: %.f90
	$(FC) $(FCFLAGS) -c $<

clean:
	rm -f *.o *.mod
README.md
@@ -0,0 +1,83 @@
# Example 1 - ResNet-18

This example provides a simple but complete demonstration of how to use the library.

## Description

A Python file is provided that downloads the pretrained
[ResNet-18](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html)
model from [TorchVision](https://pytorch.org/vision/stable/index.html).

A modified version of the `pt2ts.py` tool saves this ResNet-18 to TorchScript.

A series of files `resnet_infer_<LANG>` then bind to the saved TorchScript ResNet-18
model from other languages and run it in inference mode; a minimal Python sketch of
this step is shown below.
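
For orientation, here is a minimal Python sketch of the step that each
`resnet_infer_<LANG>` program reproduces in its own language: load the TorchScript
model written by `pt2ts.py` and run it in inference mode. The filename assumes the
default used later in this README; adjust it if you saved under another name.
```
import torch

# Load the TorchScript ResNet-18 saved by pt2ts.py
# (filename is an assumption; adjust if you saved under another name)
model = torch.jit.load("saved_resnet18_model_cpu.pt")
model.eval()

# ResNet-18 expects input of shape (batch, channels, height, width)
dummy_input = torch.ones(1, 3, 224, 224)

with torch.no_grad():
    output = model(dummy_input)

print(output.shape)  # torch.Size([1, 1000]), one score per ImageNet class
```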

## Dependencies

To run this example requires:

- CMake
- a Fortran compiler
- FTorch (installed as described in the main package)
- Python 3
## Running

To run this example install fortran-pytorch-lib as described in the main documentation.
Then, from this directory, create a virtual environment and install the necessary
Python modules:
```
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

You can check that everything is working by running `resnet18.py`:
```
python3 resnet18.py
```
It should produce the result `tensor([[623, 499, 596, 111, 813]])`.

To save the pretrained ResNet-18 model to TorchScript, run the modified version of the
`pt2ts.py` tool:
```
python3 pt2ts.py
```

At this point we no longer require Python, so we can deactivate the virtual environment:
```
deactivate
```

To call the saved ResNet-18 model from Fortran we need to compile the `resnet_infer`
files.
This can be done using the included `CMakeLists.txt` as follows:
```
mkdir build
cd build
cmake .. -DFTorchDIR=<path/to/your/installation/of/library> -DCMAKE_BUILD_TYPE=Release
make
```

To run the compiled code, call the executable with the saved ResNet-18 TorchScript
model file as its argument:
```
./resnet_infer_fortran ../saved_resnet18_model_cpu.pt
```

Alternatively we can use `make`, instead of CMake, with the included Makefile.
To do this you will need to modify the `Makefile` to link to and include your
installation of FTorch as described in the main documentation, and check that the
compiler matches the one the library was built with.
You will also likely need to add the location of the `.so` files to your `LD_LIBRARY_PATH`:
```
make
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:</path/to/library/installation>/lib64
./resnet_infer_fortran saved_resnet18_model_cpu.pt
```

## Further options

To explore the functionality of this example:

- Try saving the model through tracing rather than scripting by modifying `pt2ts.py`
  (a sketch of the required change follows below).
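
The change amounts to swapping which of the two save calls in `pt2ts.py` is active.
A minimal sketch, using the function and variable names defined in this example's
`pt2ts.py`:
```
# In pt2ts.py, comment out the scripting call...
# script_to_torchscript(trained_model, filename=saved_ts_filename)

# ...and activate the tracing call instead, passing the dummy input so the
# graph can be recorded for that input shape:
trace_to_torchscript(
    trained_model, trained_model_dummy_input_1, filename=saved_ts_filename
)
```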
pt2ts.py
@@ -0,0 +1,160 @@
"""Load a pytorch model and convert it to TorchScript.""" | ||
from typing import Optional | ||
import torch | ||
|
||
# FPTLIB-TODO | ||
# Add a module import with your model here: | ||
# This example assumes the model architecture is in an adjacent module `my_ml_model.py` | ||
import resnet18 | ||
|
||
|
||
def script_to_torchscript( | ||
model: torch.nn.Module, filename: Optional[str] = "scripted_model.pt" | ||
) -> None: | ||
""" | ||
Save pyTorch model to TorchScript using scripting. | ||
Parameters | ||
---------- | ||
model : torch.NN.Module | ||
a pyTorch model | ||
filename : str | ||
name of file to save to | ||
""" | ||
print("Saving model using scripting...", end="") | ||
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved | ||
scripted_model = torch.jit.script(model) | ||
# print(scripted_model.code) | ||
scripted_model.save(filename) | ||
print("done.") | ||
|
||
|
||


def trace_to_torchscript(
    model: torch.nn.Module,
    dummy_input: torch.Tensor,
    filename: Optional[str] = "traced_model.pt",
) -> None:
    """
    Save a PyTorch model to TorchScript using tracing.

    Parameters
    ----------
    model : torch.nn.Module
        a PyTorch model
    dummy_input : torch.Tensor
        appropriately sized Tensor to act as input to the model
    filename : str
        name of file to save to
    """
    print("Saving model using tracing...", end="")
    # FIXME: use torch.jit.optimize_for_inference() once PyTorch issue #81085 is resolved
    traced_model = torch.jit.trace(model, dummy_input)
    # traced_model.save(filename)
    frozen_model = torch.jit.freeze(traced_model)
    # print(frozen_model.graph)
    # print(frozen_model.code)
    frozen_model.save(filename)
    print("done.")


def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Module:
    """
    Load a TorchScript model from file.

    Parameters
    ----------
    filename : str
        name of file containing the TorchScript model

    Returns
    -------
    model : torch.nn.Module
    """
    model = torch.jit.load(filename)

    return model


if __name__ == "__main__":
    # =====================================================
    # Load model and prepare for saving
    # =====================================================

    # FPTLIB-TODO
    # Load a pre-trained PyTorch model
    # Insert code here to load your model as `trained_model`.
    # This example assumes the adjacent `resnet18` module provides a method
    # `initialize` to load the architecture and weights
    trained_model = resnet18.initialize()

    # Switch off specific layers/parts of the model that behave
    # differently during training and inference.
    # This may have been done by the user already, so just make sure here.
    trained_model.eval()

    # =====================================================
    # Prepare dummy input and check model runs
    # =====================================================

    # FPTLIB-TODO
    # Generate a dummy input Tensor `dummy_input` to the model of appropriate size.
    # This example uses a single input of size (1, 3, 224, 224), the shape
    # ResNet-18 expects.
    trained_model_dummy_input_1 = torch.ones(1, 3, 224, 224)

    # FPTLIB-TODO
    # Uncomment the following lines to save for inference on GPU (rather than CPU):
    # device = torch.device('cuda')
    # trained_model = trained_model.to(device)
    # trained_model.eval()
    # trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)

    # FPTLIB-TODO
    # Run model for dummy inputs
    # If something isn't working, this will raise an error
    trained_model_dummy_output = trained_model(
        trained_model_dummy_input_1,
    )

    # =====================================================
    # Save model
    # =====================================================

    # FPTLIB-TODO
    # Set the name of the file you want to save the TorchScript model to:
    saved_ts_filename = "saved_resnet18_model_cpu.pt"

    # FPTLIB-TODO
    # Save the PyTorch model using either scripting (recommended where possible)
    # or tracing
    # -----------
    # Scripting
    # -----------
    script_to_torchscript(trained_model, filename=saved_ts_filename)

    # -----------
    # Tracing
    # -----------
    # trace_to_torchscript(
    #     trained_model, trained_model_dummy_input_1, filename=saved_ts_filename
    # )

    print(f"Saved model to TorchScript in '{saved_ts_filename}'.")

    # =====================================================
    # Check model saved OK
    # =====================================================

    # Load TorchScript and run model as a test
    # FPTLIB-TODO
    # Scale inputs as above and, if required, move inputs and model to GPU
    trained_model_dummy_input_1 = 2.0 * trained_model_dummy_input_1
    trained_model_testing_output = trained_model(
        trained_model_dummy_input_1,
    )
    ts_model = load_torchscript(filename=saved_ts_filename)
    ts_model_output = ts_model(
        trained_model_dummy_input_1,
    )

    if torch.all(ts_model_output.eq(trained_model_testing_output)):
        print("Saved TorchScript model working as expected in a basic test.")
        print("Users should perform further validation as appropriate.")
    else:
        raise RuntimeError(
            "Saved TorchScript model is not performing as expected.\n"
            "Consider using scripting if you used tracing, or investigate further."
        )
requirements.txt
@@ -0,0 +1,2 @@
torch
torchvision
resnet18.py
@@ -0,0 +1,50 @@
"""Load and run pretrained ResNet-18 from TorchVision.""" | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
import torchvision | ||
|
||
|
||
# Initialize everything | ||
def initialize(): | ||
""" | ||
Download pre-trained ResNet-18 model and prepare for inference. | ||
Returns | ||
------- | ||
model : torch.nn.Module | ||
""" | ||
|
||
# Load a pre-trained PyTorch model | ||
print("Loading pre-trained ResNet-18 model...", end="") | ||
model = torchvision.models.resnet18(pretrained=True) | ||
print("done.") | ||
|
||
# Switch-off some specific layers/parts of the model that behave | ||
# differently during training and inference | ||
model.eval() | ||
|
||
return model | ||
|
||
|
||
def run_model(model): | ||
""" | ||
Run the pre-trained ResNet-18 with dummy input of ones. | ||
Parameters | ||
---------- | ||
model : torch.nn.Module | ||
""" | ||
|
||
print("Running ResNet-18 model for ones...", end="") | ||
dummy_input = torch.ones(1, 3, 224, 224) | ||
output = model(dummy_input) | ||
top5 = F.softmax(output, dim=1).topk(5).indices | ||
print("done.") | ||
|
||
print(f"Top 5 results:\n {top5}") | ||
|
||
|
||
if __name__ == "__main__": | ||
rn_model = initialize() | ||
run_model(rn_model) |