Skip to content

Commit

Permalink
Restructuring (#19)
Browse files Browse the repository at this point in the history
* Move python torchscripting script to utils folder.

* Add README for utils (WIP).

* Create examples directory with basic README listing examples to add.

* Move ts_inference programs to examples directiry with a copy of CMakeLists and modify library CMake to no longer build and install them.

* Move case specific pt2ts to examples, and replace with generic tool.

* Restructure and re-write ResNet example. WIP (need multi-input merge from main.)

* Update resnet example to work.

* Complete Python-Fortran example.

* Remove c and cpp from example 1 to other file.

* Add Makefile build option to ResNet example.

* Tidy ResNet README

* Update resnet example to take model as command line argument.

* Update README files as appropriate for these changes.
  • Loading branch information
jatkinson1000 authored Jun 7, 2023
1 parent 448dc72 commit dbab6a6
Show file tree
Hide file tree
Showing 23 changed files with 866 additions and 194 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ To build and install the library:
make install
```
This will place the following directories at the install location:
* `bin/` - contains example executables
* `include/` - contains header and mod files
* `lib64/` - contains cmake and `.so` files

Expand All @@ -104,7 +103,7 @@ In order to use fortran-pytorch users will typically need to follow these steps:
The trained PyTorch model needs to be exported to [TorchScript](https://pytorch.org/docs/stable/jit.html).
This can be done from within your code using the [`jit.script`](https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script) or [`jit.trace`](https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace) functionalities from within python.

If you are not familiar with these we provide a tool `pt2ts.py` as part of this distribution which contains an easily adaptable script to save your PyTorch model as Torch Script.
If you are not familiar with these we provide a tool [`pt2ts.py`](utils/pt2ts.py) as part of this distribution which contains an easily adaptable script to save your PyTorch model as TorchScript.


### 2. Using the model from Fortran
Expand Down Expand Up @@ -209,7 +208,9 @@ export LD_LIBRARY_PATH = $LD_LIBRARY_PATH:<path/to/installation>/lib64


## Examples
To follow.

Examples of how to use this library are provided in the [examples directory](examples/).
They demonstrate different functionalities and are provided with instructions to modify, build, and run as neccessary.

## License

Expand Down
19 changes: 19 additions & 0 deletions examples/1_ResNet18/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
#policy CMP0076 - target_sources source files are relative to file where target_sources is run
cmake_policy (SET CMP0076 NEW)

set(PROJECT_NAME ResNetExample)

project(${PROJECT_NAME} LANGUAGES Fortran)

# Build in Debug mode if not specified
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Debug CACHE STRING "" FORCE)
endif()

find_package(FTorch)
message(STATUS "Building with Fortran PyTorch coupling")

# Fortran example
add_executable(resnet_infer_fortran resnet_infer_fortran.f90)
target_link_libraries(resnet_infer_fortran PRIVATE FTorch::ftorch)
25 changes: 25 additions & 0 deletions examples/1_ResNet18/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# compiler
# Note - this should match the compiler that the library was built with
FC = gfortran

# compile flags
FCFLAGS = -O3 -I</path/to/installation>/include/ftorch

# link flags
LDFLAGS = -L</path/to/installation>/lib64/ -lftorch

PROGRAM = resnet_infer_fortran
SRC = resnet_infer_fortran.f90
OBJECTS = $(SRC:.f90=.o)

all: $(PROGRAM)

$(PROGRAM): $(OBJECTS)
$(FC) $(FCFLAGS) -o $@ $^ $(LDFLAGS)

%.o: %.f90
$(FC) $(FCFLAGS) $(LDFLAGS) -c $<

clean:
rm -f *.o *.mod

83 changes: 83 additions & 0 deletions examples/1_ResNet18/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Example 1 - ResNet-18

This example provides a simple but complete demonstration of how to use the library.

## Description

A python file is provided that downloads the pretrained
[ResNet-18](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html)
model from [TorchVision](https://pytorch.org/vision/stable/index.html).

A modified version of the `pt2ts.py` tool saves this ResNet-18 to TorchScript.

A series of files `resnet_infer_<LANG>` then bind from other languages to run the
TorchScript ResNet-18 model in inference mode.

## Dependencies

To run this example requires:

- cmake
- fortran compiler
- FTorch (installed as described in main package)
- python3

## Running

To run this example install fortran-pytorch-lib as described in the main documentation.
Then from this directory create a virtual environment an install the neccessary python
modules:
```
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

You can check that everything is working by running `resnet18.py`:
```
python3 resnet18.py
```
it should produce the result `tensor([[623, 499, 596, 111, 813]])`.

To save the pretrained ResNet-18 model to TorchScript run the modified version of the
`pt2ts.py` tool :
```
python3 pt2ts.py
```

At this point we no longer require python, so can deactivate the virtual environment:
```
deactivate
```

To call the saved ResNet-18 model from fortran we need to compile the `resnet_infer`
files.
This can be done using the included `CMakeLists.txt` as follows:
```
mkdir build
cd build
cmake .. -DFTorchDIR=<path/to/your/installation/of/library> -DCMAKE_BUILD_TYPE=Release
make
```

To run the compiled code calling the saved ResNet-18 TorchScript from Fortran run the
executable with an argument of the saved model file:
```
./resnet_infer_fortran ../saved_resnet18_model_cpu.pt
```

Alternatively we can use `make`, instead of cmake, with the included Makefile.
However, to do this you will need to modify `Makefile` to link to and include your
installation of FTorch as described in the main documentation. Also check that the compiler is the same as the one you built the Library with.
You will also likely need to add the location of the `.so` files to your `LD_LIBRARY_PATH`:
```
make
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:</path/to/library/installation>/lib64
./resnet_infer_fortran saved_resnet18_model_cpu.pt
```

## Further options

To explore the functionalities of this model:

- Try saving the model through tracing rather than scripting by modifying `pt2ts.py`
160 changes: 160 additions & 0 deletions examples/1_ResNet18/pt2ts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Load a pytorch model and convert it to TorchScript."""
from typing import Optional
import torch

# FPTLIB-TODO
# Add a module import with your model here:
# This example assumes the model architecture is in an adjacent module `my_ml_model.py`
import resnet18


def script_to_torchscript(
model: torch.nn.Module, filename: Optional[str] = "scripted_model.pt"
) -> None:
"""
Save pyTorch model to TorchScript using scripting.
Parameters
----------
model : torch.NN.Module
a pyTorch model
filename : str
name of file to save to
"""
print("Saving model using scripting...", end="")
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
scripted_model = torch.jit.script(model)
# print(scripted_model.code)
scripted_model.save(filename)
print("done.")


def trace_to_torchscript(
model: torch.nn.Module,
dummy_input: torch.Tensor,
filename: Optional[str] = "traced_model.pt",
) -> None:
"""
Save pyTorch model to TorchScript using tracing.
Parameters
----------
model : torch.NN.Module
a pyTorch model
dummy_input : torch.Tensor
appropriate size Tensor to act as input to model
filename : str
name of file to save to
"""
print("Saving model using tracing...", end="")
# FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
traced_model = torch.jit.trace(model, dummy_input)
# traced_model.save(filename)
frozen_model = torch.jit.freeze(traced_model)
## print(frozen_model.graph)
## print(frozen_model.code)
frozen_model.save(filename)
print("done.")


def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Module:
"""
Load a TorchScript from file.
Parameters
----------
filename : str
name of file containing TorchScript model
"""
model = torch.jit.load(filename)

return model


if __name__ == "__main__":
# =====================================================
# Load model and prepare for saving
# =====================================================

# FPTLIB-TODO
# Load a pre-trained PyTorch model
# Insert code here to load your model as `trained_model`.
# This example assumes my_ml_model has a method `initialize` to load
# architecture, weights, and place in inference mode
trained_model = resnet18.initialize()

# Switch off specific layers/parts of the model that behave
# differently during training and inference.
# This may have been done by the user already, so just make sure here.
trained_model.eval()

# =====================================================
# Prepare dummy input and check model runs
# =====================================================

# FPTLIB-TODO
# Generate a dummy input Tensor `dummy_input` to the model of appropriate size.
# This example assumes two inputs of size (512x40) and (512x1)
trained_model_dummy_input_1 = torch.ones(1, 3, 224, 224)

# FPTLIB-TODO
# Uncomment the following lines to save for inference on GPU (rather than CPU):
# device = torch.device('cuda')
# trained_model = trained_model.to(device)
# trained_model.eval()
# trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)
# trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device)

# FPTLIB-TODO
# Run model for dummy inputs
# If something isn't working This will generate an error
trained_model_dummy_output = trained_model(
trained_model_dummy_input_1,
)

# =====================================================
# Save model
# =====================================================

# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to:
saved_ts_filename = "saved_resnet18_model_cpu.pt"

# FPTLIB-TODO
# Save the pytorch model using either scripting (recommended where possible) or tracing
# -----------
# Scripting
# -----------
script_to_torchscript(trained_model, filename=saved_ts_filename)

# -----------
# Tracing
# -----------
# trace_to_torchscript(trained_model, trained_model_dummy_input, filename=saved_ts_filename)

print(f"Saved model to TorchScript in '{saved_ts_filename}'.")

# =====================================================
# Check model saved OK
# =====================================================

# Load torchscript and run model as a test
# FPTLIB-TODO
# Scale inputs as above and, if required, move inputs and mode to GPU
trained_model_dummy_input_1 = 2.0 * trained_model_dummy_input_1
trained_model_testing_output = trained_model(
trained_model_dummy_input_1,
)
ts_model = load_torchscript(filename=saved_ts_filename)
ts_model_output = ts_model(
trained_model_dummy_input_1,
)

if torch.all(ts_model_output.eq(trained_model_testing_output)):
print("Saved TorchScript model working as expected in a basic test.")
print("Users should perform further validation as appropriate.")
else:
raise RuntimeError(
"Saved Torchscript model is not performing as expected.\n"
"Consider using scripting if you used tracing, or investigate further."
)
2 changes: 2 additions & 0 deletions examples/1_ResNet18/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
torchvision
50 changes: 50 additions & 0 deletions examples/1_ResNet18/resnet18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Load and run pretrained ResNet-18 from TorchVision."""

import torch
import torch.nn.functional as F
import torchvision


# Initialize everything
def initialize():
"""
Download pre-trained ResNet-18 model and prepare for inference.
Returns
-------
model : torch.nn.Module
"""

# Load a pre-trained PyTorch model
print("Loading pre-trained ResNet-18 model...", end="")
model = torchvision.models.resnet18(pretrained=True)
print("done.")

# Switch-off some specific layers/parts of the model that behave
# differently during training and inference
model.eval()

return model


def run_model(model):
"""
Run the pre-trained ResNet-18 with dummy input of ones.
Parameters
----------
model : torch.nn.Module
"""

print("Running ResNet-18 model for ones...", end="")
dummy_input = torch.ones(1, 3, 224, 224)
output = model(dummy_input)
top5 = F.softmax(output, dim=1).topk(5).indices
print("done.")

print(f"Top 5 results:\n {top5}")


if __name__ == "__main__":
rn_model = initialize()
run_model(rn_model)
Loading

0 comments on commit dbab6a6

Please sign in to comment.