A Flutter plugin for PyTorch model inference. The plugin is still under development and currently supports Android only. An iOS version is planned.
To use this plugin, add pytorch_mobile as a dependency in your pubspec.yaml file.
Create an assets folder containing your PyTorch model and, if needed, its labels. Modify pubspec.yaml accordingly.
assets:
- assets/models/model.pt
- assets/labels.csv
Run flutter pub get
import 'package:pytorch_mobile/pytorch_mobile.dart';
Load either a custom model:
Model customModel = await PyTorchMobile
.loadModel('assets/models/custom_model.pt');
Or an image model:
Model imageModel = await PyTorchMobile
.loadModel('assets/models/resnet18.pt');
// Custom model: pass the flat input values, the tensor shape, and the element type.
List prediction = await customModel
.getPrediction([1, 2, 3, 4], [1, 2, 2], DType.float32);
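Since getPrediction receives the input as a flat list together with a separate shape, it can help to derive both from a nested list. The following is a minimal sketch in plain Dart; `FlatTensor` and `flatten` are hypothetical helpers written for illustration and are not part of pytorch_mobile.

```dart
// Hypothetical helper, not part of pytorch_mobile: holds a flat list of
// values together with the shape they were flattened from.
class FlatTensor {
  final List<double> data;
  final List<int> shape;
  FlatTensor(this.data, this.shape);
}

// Flattens a nested list of numbers and records its shape.
FlatTensor flatten(List<dynamic> nested) {
  final shape = <int>[];
  // Walk down the first element at each level to read the shape.
  dynamic level = nested;
  while (level is List) {
    shape.add(level.length);
    if (level.isEmpty) break;
    level = level.first;
  }

  final data = <double>[];
  // Depth-first traversal that collects every number in row-major order.
  void visit(dynamic node) {
    if (node is List) {
      for (final child in node) {
        visit(child);
      }
    } else {
      data.add((node as num).toDouble());
    }
  }

  visit(nested);

  // The element count must equal the product of the shape dimensions.
  final expected = shape.fold<int>(1, (a, b) => a * b);
  if (data.length != expected) {
    throw ArgumentError(
        'Ragged input: got ${data.length} values, expected $expected');
  }
  return FlatTensor(data, shape);
}

void main() {
  final t = flatten([
    [
      [1, 2],
      [3, 4],
    ],
  ]);
  print(t.shape); // [1, 2, 2]
  print(t.data.length); // 4
}
```

Under these assumptions, `t.data` and `t.shape` could be passed as the first two arguments of `customModel.getPrediction(...)`.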
// Image model: pass the image, the target width and height, and the labels file.
String prediction = await imageModel
.getImagePrediction(image, 224, 224, "assets/labels/labels.csv");
// Optionally pass a custom per-channel mean and std.
final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await imageModel
.getImagePrediction(image, 224, 224, "assets/labels/labels.csv", mean: mean, std: std);
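To make the role of mean and std concrete, here is a small sketch of the usual per-channel normalization, (value / 255 - mean) / std. That the plugin applies exactly this formula is an assumption, not something stated above, and `normalizePixel` is a hypothetical function used only for illustration.

```dart
// Illustrative only: the common per-channel normalization
// (value / 255 - mean) / std. Assumed here, not taken from the plugin.
List<double> normalizePixel(
    List<int> rgb, List<double> mean, List<double> std) {
  if (rgb.length != 3 || mean.length != 3 || std.length != 3) {
    throw ArgumentError('rgb, mean and std must each have 3 entries');
  }
  return [
    for (var c = 0; c < 3; c++) (rgb[c] / 255.0 - mean[c]) / std[c],
  ];
}

void main() {
  final mean = [0.5, 0.5, 0.5];
  final std = [0.5, 0.5, 0.5];
  // Black maps to -1.0 and white to 1.0 on every channel.
  print(normalizePixel([0, 0, 0], mean, std));
  print(normalizePixel([255, 255, 255], mean, std));
}
```

With a mean and std of 0.5 on every channel, 8-bit pixel values end up in the range from -1 to 1 under this formula. The values should match whatever normalization the model was trained with.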
List<List>? prediction = await _d2model
.detectron2(image, 320, 320, "assets/labels/d2go.csv", minScore: 0.4);
// prediction[0] => [left, top, right, bottom, score, label]
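Each detection row is a positional list, so wrapping it in a small typed class may make downstream code easier to read. This is a sketch only: `Detection` and `parseDetections` are hypothetical names, the sample row is made up, and the code assumes the first five entries are numeric and the label can be converted with toString().

```dart
// Hypothetical wrapper, not part of pytorch_mobile.
class Detection {
  final double left, top, right, bottom, score;
  final String label;

  Detection(
      this.left, this.top, this.right, this.bottom, this.score, this.label);

  // Assumes the row layout [left, top, right, bottom, score, label].
  factory Detection.fromRow(List<dynamic> row) {
    if (row.length != 6) {
      throw ArgumentError('Expected 6 entries, got ${row.length}');
    }
    double n(int i) => (row[i] as num).toDouble();
    return Detection(n(0), n(1), n(2), n(3), n(4), row[5].toString());
  }

  double get width => right - left;
  double get height => bottom - top;
}

// Converts the nullable list of rows into typed detections;
// a null result becomes an empty list.
List<Detection> parseDetections(List<List<dynamic>>? rows) => [
      for (final row in rows ?? const <List<dynamic>>[])
        Detection.fromRow(row),
    ];

void main() {
  // Made-up sample row for illustration.
  final detections = parseDetections([
    [10.0, 20.0, 110.0, 220.0, 0.87, 'person'],
  ]);
  final d = detections.first;
  print('${d.label} ${d.score} ${d.width}x${d.height}');
}
```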