Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compile_method flag and add other framework artifact types #40

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
41 changes: 10 additions & 31 deletions CHANGELOG.md
rbavery marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/stac-extensions/mlm/tree/main)
## [v1.4.0](https://github.com/stac-extensions/mlm/tree/v1.4.0)

### Added
- Add better descriptions about required and recommended *MLM Asset Roles* and their implications
(fixes [#54](https://github.com/stac-extensions/mlm/issues/54)).
- Add explicit check of `value_scaling` sub-fields `minimum`, `maximum`, `mean`, `stddev`, etc. for
corresponding `type` values `min-max` and `z-score` that depend on it.
- Allow different `value_scaling` operations per band/channel/dimension as needed by the model.
- Allow a `processing:expression` for a band/channel/dimension-specific `value_scaling` operation,
granting more flexibility in the definition of input preparation in contrast to having it applied
for the entire input (but still possible).
- mlm:compile_method with options 'aot' for Ahead of Time Compilation, 'jit' for Just-In Time Compilation

### Changed
- Explicitly disallow `mlm:name`, `mlm:input`, `mlm:output` and `mlm:hyperparameters` at the Asset level.
These fields describe the model as a whole and should therefore be defined in Item properties.
- Moved `norm_type` to `value_scaling` object to better reflect the expected operation, which could be another
operation than what is typically known as "normalization" or "standardization" techniques in machine learning.
- Moved `statistics` to `value_scaling` object to better reflect their mutual `type` and additional
properties dependencies.
- moved mlm:artifact_type field value descriptions that are framework specific to best-practices section.
- expanded suggested mlm:artifact_type values to include Tensorflow/Keras

### Deprecated
- n/a

### Removed
- Removed `norm_type` enum values that were ambiguous regarding their expected result.
Instead, a `processing:expression` should be employed to explicitly define the calculation they represent.
- Removed `norm_clip` property. It is now represented under `value_scaling` objects with a
corresponding `type` definition.
- Removed `norm_by_channel` from `mlm:input` objects. If rescaling (previously normalization in the documentation)
is a single value, broadcasting to the relevant bands should be performed implicitly.
Otherwise, the amount of `value_scaling` objects should match the number of bands or channels involved in the input.
- n/a

### Fixed
- Fix missing `mlm:artifact_type` property check for a Model Asset definition
(fixes <https://github.com/stac-extensions/mlm/issues/42>).
The `mlm:artifact_type` is now mutually and exclusively required by the corresponding Asset with `mlm:model` role.
- Fix check of disallowed unknown/undefined `mlm:`-prefixed fields
(fixes [#41](https://github.com/stac-extensions/mlm/issues/41)).
- n/a

## [v1.3.0](https://github.com/stac-extensions/mlm/tree/v1.3.0)

Expand Down Expand Up @@ -73,7 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
when a `mlm:input` references names in `bands` are now properly validated.
- Fix the examples using `raster:bands` incorrectly defined in STAC Item properties.
The correct use is for them to be defined under the STAC Asset using the `mlm:model` role.
- Fix the [EuroSAT ResNet pydantic example](stac_model/examples.py) that incorrectly referenced some `bands`
- Fix the [EuroSAT ResNet pydantic example](./stac_model/examples.py) that incorrectly referenced some `bands`
in its `mlm:input` definition without providing any definition of those bands. The `eo:bands` properties have
been added to the corresponding `model` Asset using
the [`pystac.extensions.eo`](https://github.com/stac-utils/pystac/blob/main/pystac/extensions/eo.py) utilities.
Expand Down Expand Up @@ -134,7 +113,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- more [Task Enum](README.md#task-enum) tasks
- [Model Output Object](README.md#model-output-object)
- batch_size and hardware summary
- [`mlm:accelerator`, `mlm:accelerator_constrained`, `mlm:accelerator_summary`](README.md#accelerator-type-enum)
- [`mlm:accelerator`, `mlm:accelerator_constrained`, `mlm:accelerator_summary`](./README.md#accelerator-type-enum)
to specify hardware requirements for the model
- Use common metadata
[Asset Object](https://github.com/radiantearth/stac-spec/blob/master/collection-spec/collection-spec.md#asset-object)
Expand All @@ -149,7 +128,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
STAC Item properties (top-level, not nested) to allow better search support by STAC API.
- reorganized `dlm:architecture` nested fields to exist at the top level of properties as `mlm:name`, `mlm:summary`
and so on to provide STAC API search capabilities.
- replaced `normalization:mean`, etc. with [statistics](README.md#bands-and-statistics) from STAC 1.1 common metadata
- replaced `normalization:mean`, etc. with [statistics](./README.md#bands-and-statistics) from STAC 1.1 common metadata
- added `pydantic` models for internal schema objects in `stac_model` package and published to PYPI
- specified [rel_type](README.md#relation-types) to be `derived_from` and
specify how model item or collection json should be named
Expand All @@ -165,7 +144,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- any `dlm`-prefixed field or property

### Removed
- Data Object, replaced with [Model Input Object](README.md#model-input-object) that uses the `name` field from
- Data Object, replaced with [Model Input Object](./README.md#model-input-object) that uses the `name` field from
the [common metadata band object][stac-bands] which also records `data_type` and `nodata` type

### Fixed
Expand Down
33 changes: 5 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ However, fields that relate to supervised ML are optional and users can use the
<!-- lint enable -->

See [Best Practices](./best-practices.md) for guidance on what other STAC extensions you should use in conjunction
with this extension.
with this extension as well as suggested values for specific ML framework.

The Machine Learning Model Extension purposely omits and delegates some definitions to other STAC extensions to favor
reusability and avoid metadata duplication whenever possible. A properly defined MLM STAC Item/Collection should almost
Expand Down Expand Up @@ -667,7 +667,8 @@ In order to provide more context, the following roles are also recommended were
| href | string | URI to the model artifact. |
| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). |
| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. |
| mlm:artifact_type | [Artifact Type](#artifact-type) | Specifies the kind of model artifact. Typically related to a particular ML framework. |
| mlm:artifact_type | [Artifact Type](./best-practices.md#framework-specific-artifact-types) | Specifies the kind of model artifact. Typically related to a particular ML framework. This is **REQUIRED** if the `mlm:model` role is specified. |
rbavery marked this conversation as resolved.
Show resolved Hide resolved
| mlm:compile_method | string | Describes the method used to compile the ML model at either save time or runtime prior to inference. These options are mutually exclusive `["aot", "jit", null]`. |
rbavery marked this conversation as resolved.
Show resolved Hide resolved

Recommended Asset `roles` include `mlm:weights` or `mlm:checkpoint` for model weights that need to be loaded by a
model definition and `mlm:compiled` for models that can be loaded directly without an intermediate model definition.
Expand Down Expand Up @@ -700,35 +701,11 @@ is used for the artifact described by the media-type. However, users need to rem
official. In order to validate the specific framework and artifact type employed by the model, the MLM properties
`mlm:framework` (see [MLM Fields](#item-properties-and-collection-fields)) and
`mlm:artifact_type` (see [Model Asset](#model-asset)) should be employed instead to perform this validation if needed.
See the [the best practices document](./best-practices.md#framework-specific-artifact-types) on suggested
rbavery marked this conversation as resolved.
Show resolved Hide resolved
fields for framework specific artifact types.

[iana-media-type]: https://www.iana.org/assignments/media-types/media-types.xhtml

#### Artifact Type

This value can be used to provide additional details about the specific model artifact being described.
For example, PyTorch offers [various strategies][pytorch-frameworks] for providing model definitions,
such as Pickle (`.pt`), [TorchScript][pytorch-jit-script],
or [PyTorch Ahead-of-Time Compilation][pytorch-aot-inductor] (`.pt2`) approach.
Since they all refer to the same ML framework, the [Model Artifact Media-Type](#model-artifact-media-type)
can be insufficient in this case to detect which strategy should be used to employ the model definition.

Following are some proposed *Artifact Type* values for corresponding approaches, but other names are
permitted as well. Note that the names are selected using the framework-specific definitions to help
the users understand the source explicitly, although this is not strictly required either.

| Artifact Type | Description |
|--------------------|--------------------------------------------------------------------------------------|
| `torch.save` | A model artifact obtained by [Serialized Pickle Object][pytorch-save] (i.e.: `.pt`). |
| `torch.jit.script` | A model artifact obtained by [`TorchScript`][pytorch-jit-script]. |
| `torch.export` | A model artifact obtained by [`torch.export`][pytorch-export] (i.e.: `.pt2`). |
| `torch.compile` | A model artifact obtained by [`torch.compile`][pytorch-compile]. |

[pytorch-compile]: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
[pytorch-export]: https://pytorch.org/docs/main/export.html
[pytorch-frameworks]: https://pytorch.org/docs/main/export.html#existing-frameworks
[pytorch-aot-inductor]: https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
[pytorch-jit-script]: https://pytorch.org/docs/stable/jit.html
[pytorch-save]: https://pytorch.org/tutorials/beginner/saving_loading_models.html

### Source Code Asset

Expand Down
37 changes: 37 additions & 0 deletions best-practices.md
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,40 @@ training process to find the "best model". This field could also be used to indi
educational purposes only.

[stac-ext-version]: https://github.com/stac-extensions/version

## Framework Specific Artifact Types
rbavery marked this conversation as resolved.
Show resolved Hide resolved

The `mlm:artifact_type` field can be used to clarify how the model was saved which
can help users understand how to load it or in what runtime contexts it should be used. For example, PyTorch offers
rbavery marked this conversation as resolved.
Show resolved Hide resolved
[various strategies][pytorch-frameworks] for providing model definitions, such as Pickle (`.pt`),
[TorchScript][pytorch-jit-script], or [PyTorch Ahead-of-Time Compilation][pytorch-aot-inductor]
(`.pt2`) approach. Since they all refer to the same ML framework, the
rbavery marked this conversation as resolved.
Show resolved Hide resolved
[Model Artifact Media-Type](./README.md#model-artifact-media-type) can be insufficient in this case to detect which
strategy should be used to employ the model definition.

The following are some proposed *Artifact Type* values for the Model Asset's
[`mlm:artifact_type` field](./README.md#model-asset). Other names are
permitted, as these values are not validated by the schema. Note that the names are selected using the
framework-specific definitions to help the users understand how the model artifact was created, although these exact
names are not strictly required either.

| Artifact Type | Description |
|--------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `torch.save` | A [serialized python pickle object][pytorch-save] (i.e.: `.pt`) which can represent a model or state_dict. |
| `torch.jit.save` | A [`TorchScript`][pytorch-jit-script] model artifact obtained with one or more of the graph export options Torchscript Tracing and Torchscript Scripting. |
| `torch.export.save` | A model artifact storing an [ExportedProgram][exported-program] obtained by [`torch.export.export`][pytorch-export] (i.e.: `.pt2`). |
| `tf.keras.Model.save` | Saves a [.keras model file][keras-model], a unified zip archive format containing the architecture, weights, optimizer, losses, and metrics. |
| `tf.keras.Model.save_weights` | A [.weights.h5][keras-save-weights] file containing only model weights for use by Tensorflow or Keras. |
| `tf.keras.Model.export(format='tf_saved_model')` | TF Saved Model is the [recommended format][tf-keras-recommended] by the Tensorflow team for whole model saving/loading for inference. See this example to [save and load models][keras-example] and the docs for [different save methods][keras-methods] in TF and Keras. Also available from `keras.Model.export(format='tf_saved_model')` |
rbavery marked this conversation as resolved.
Show resolved Hide resolved

[exported-program]: https://pytorch.org/docs/main/export.html#serialization
[pytorch-aot-inductor]: https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
[pytorch-export]: https://pytorch.org/docs/main/export.html
[pytorch-frameworks]: https://pytorch.org/docs/main/export.html#existing-frameworks
[pytorch-jit-script]: https://pytorch.org/docs/stable/jit.html
[pytorch-save]: https://pytorch.org/tutorials/beginner/saving_loading_models.html
[keras-save-weights]: https://keras.io/api/models/model_saving_apis/weights_saving_and_loading/#save_weights-method
[keras-example]: https://keras.io/guides/serialization_and_saving/
[tf-keras-recommended]: https://www.tensorflow.org/guide/saved_model#creating_a_savedmodel_from_keras
[keras-methods]: https://keras.io/2.16/api/models/model_saving_apis/
[keras-model]: https://keras.io/api/models/model_saving_apis/model_saving_and_loading/
rbavery marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 19 additions & 3 deletions json-schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@
},
"mlm:artifact_type": {
"$ref": "#/$defs/mlm:artifact_type"
},
"mlm:compile_method": {
"$ref": "#/$defs/mlm:compile_method"
}
},
"$comment": "Allow properties not defined by MLM prefix to work with other extensions and attributes, but disallow undefined MLM fields.",
Expand All @@ -324,7 +327,8 @@
"required": [
"mlm:input",
"mlm:output",
"mlm:artifact_type"
"mlm:artifact_type",
"mlm:compile_method"
]
}
},
Expand Down Expand Up @@ -354,7 +358,8 @@
"anyOf": [
{
"required": [
"mlm:artifact_type"
"mlm:artifact_type",
"mlm:compile_method"
]
rbavery marked this conversation as resolved.
Show resolved Hide resolved
}
]
Expand Down Expand Up @@ -460,7 +465,18 @@
"examples": [
"torch.save",
"torch.jit.save",
"torch.export.save"
"torch.export.save",
"tf.keras.Model.save",
"tf.keras.Model.save_weights",
"tf.saved_model.export(format='tf_saved_model')"
]
},
"mlm:compile_method": {
"type": "string",
"minLength": 1,
rbavery marked this conversation as resolved.
Show resolved Hide resolved
"examples": [
"aot",
"jit"
]
},
"mlm:tasks": {
Expand Down
Loading