
feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Status: Open. Wants to merge 2 commits into base: main
Conversation

@zewenli98 (Collaborator) commented Sep 19, 2024

Description

  1. Supported weight-stripped engine
  2. Added REFIT_IDENTICAL flag

Fixes #3146

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this Sep 19, 2024
@github-actions bot added labels on Sep 19, 2024: component: tests, component: conversion, component: api [Python], component: runtime, component: dynamo
Comment on lines 79 to 82
name: str = "",
settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: torch.fx.GraphModule = None,
@zewenli98 (Collaborator, Author) commented Sep 19, 2024:
@narendasan I tried to do refitting for the C++ runtime like for the Python runtime, but it didn't work. Any suggestions? Should I do it in C++ or Python?

Collaborator:

Doesn't refit already work on both APIs?

Collaborator:

Also, why do we need the graph module in this module?

@zewenli98 (Collaborator, Author):

  1. In this PR I moved the refitting part into TRTModule, so it only works for the Python runtime.

  2. The graph module is used for refitting.

@@ -619,27 +609,32 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

serialized_engine = self.builder.build_serialized_network(
# if strip_engine_weights is true, the serialized engine needs to be refitted before use
maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
Collaborator:

Why is this a "maybe unrefitted" engine?

@zewenli98 (Collaborator, Author):

Please see the design in the comment below. If compilation_settings.strip_engine_weights is true, the engine needs to be refitted; otherwise it doesn't. That's why it's "maybe".

), "weight-stripped engines must be refittable, please set make_refittable=True"

# Refit the weights
refitter = trt.Refitter(self.engine, TRT_LOGGER)
Collaborator:

Can you use this function?

def _refit_single_trt_engine_with_gm(

@zewenli98 (Collaborator, Author):

The function requires input_list, which is not provided by the caller.

@@ -121,6 +124,52 @@ def setup_engine(self) -> None:
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
self.context = self.engine.create_execution_context()

if self.settings.strip_engine_weights:
@narendasan (Collaborator) commented Sep 19, 2024:

We likely shouldn't be doing the refit in these modules.

I think for weight stripping there are 3 workflows.

  1. A user just wants a weight-stripped engine. They should use convert_exported_program_to_trt_engine with the strip_weights setting. The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL (though it might not be entirely clear, so we might want to think about that setting).
  2. We want to utilize weight stripping to have a lighter-weight cache. Here this choice is opaque to the user. The user's choice of make_refittable controls whether we use kREFIT or kREFIT_IDENTICAL. But once the engine is loaded or we pull it from the cache, we immediately refit (prior to passing the engine to the TRTModule). Same as we do today.
  3. The user wants a stripped-weights compiled program (I'm not sure why, or if this is a real use case). Here, this is basically the same as lazy engine loading. We would require that users run through refit_engine_weights before executing.
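The three workflows above can be summarized in a small sketch. The entry points named are the APIs discussed in this thread; the dict itself is purely illustrative, not part of the PR:

```python
# Illustrative summary of the three weight-stripping workflows discussed above.
WORKFLOWS = {
    "standalone_engine": {
        "entry": "convert_exported_program_to_trt_engine",
        "strip_weights": True,
        "refit": "caller's responsibility",
    },
    "engine_cache": {
        "entry": "torch_trt.dynamo.compile",
        "strip_weights": True,   # opaque to the user; lighter cache entries
        "refit": "immediately after cache load, before TRTModule",
    },
    "stripped_compiled_program": {
        "entry": "torch_trt.dynamo.compile",
        "strip_weights": True,
        "refit": "user must call refit_engine_weights before executing",
    },
}
```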

@zewenli98 (Collaborator, Author):

Got it. The original idea/design is described in the comment below. I'll move the refitting part back to TRTInterpreter.run().

The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL

Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

@narendasan (Collaborator) left a comment:

@zewenli98 do you have a design for this feature?

@zewenli98 (Collaborator, Author):

@narendasan OK, at first the overall design was:

In TRTInterpreter.run():

if compilation_settings.strip_engine_weights is True:
    if engine_cache not hit:
        1. build a weight-stripped engine
        2. save the weight-stripped engine if engine_cache is set
        3. return the weight-stripped engine (not yet refit)
    else:
        load and return the weight-stripped engine (not yet refit)
else:
    if engine_cache not hit:
        1. build a weight-included engine
        2. save the weight-included engine if engine_cache is set
        3. return the weight-included engine (don't need to refit)
    else:
        load and return the weight-included engine (not yet refit)

Then, in TRTModule, refit if necessary before inference.
The reason I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize, (2) deserialize in TRTModule again.
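The branching in the pseudocode above reduces to a small decision function. This is a plain-Python sketch of the design (EngineDecision and plan_engine are illustrative names, not the PR's actual API):

```python
from dataclasses import dataclass

@dataclass
class EngineDecision:
    weight_stripped: bool  # was the engine built/loaded without weights?
    from_cache: bool       # did we reuse a cached engine?
    needs_refit: bool      # must TRTModule refit before inference?

def plan_engine(strip_engine_weights: bool, cache_hit: bool) -> EngineDecision:
    """Mirror of the TRTInterpreter.run() branching sketched above."""
    if strip_engine_weights:
        # Built or loaded without weights: always refit later in TRTModule.
        return EngineDecision(True, cache_hit, True)
    # Weight-included path: a freshly built engine already carries its
    # weights; a cached one is reloaded and refit (weights may have changed).
    return EngineDecision(False, cache_hit, cache_hit)
```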

@narendasan narendasan closed this Sep 20, 2024
@narendasan narendasan reopened this Sep 20, 2024
@zewenli98 (Collaborator, Author):

@narendasan The design was updated.

From the users' perspective, they are able to set make_refittable and refit_identical_engine_weights: make_refittable enables general refitting, while refit_identical_engine_weights restricts refitting to identical weights.

if self.compilation_settings.make_refittable:
    if version.parse(trt.__version__) >= version.parse("10.0"):
        if self.compilation_settings.refit_identical_engine_weights:
            builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
        else:
            builder_config.set_flag(trt.BuilderFlag.REFIT)
    else:
        builder_config.set_flag(trt.BuilderFlag.REFIT)
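The flag choice in the snippet above reduces to a small pure function. A sketch with a plain version tuple instead of packaging.version (the helper name is illustrative):

```python
def choose_refit_flag(trt_version: tuple, refit_identical: bool) -> str:
    """Return the name of the BuilderFlag the snippet above would set.

    REFIT_IDENTICAL only exists from TensorRT 10 on; older versions
    fall back to the general REFIT flag regardless of the setting.
    """
    if trt_version >= (10, 0) and refit_identical:
        return "REFIT_IDENTICAL"
    return "REFIT"
```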

Besides, users can specify strip_engine_weights. If strip_engine_weights is True, TRTInterpreter.run() will return a weight-stripped engine. Otherwise, it returns a general engine (with weights).

For the 3 workflows mentioned above,

  1. By controlling the args above, users can call convert_exported_program_to_trt_engine with strip_engine_weights=True to get a weight-stripped engine.

  2. For engine caching, the implementation of weight-stripped engines is opaque to users: the engine caching mechanism will (1) save a weight-stripped engine no matter what settings users specify (make_refittable is required to be true) and then (2) load and refit the weight-stripped engine when reusing cached engines.
    If strip_engine_weights is True, the engine will not be refitted; instead, the weight-stripped engine is returned as-is.

  3. If users specify strip_engine_weights=True, calling torch.compile() or torch_trt.dynamo.compile() will return a weight-stripped compiled program. If the compiled program is run with inputs, all results will be zeros. Then, calling refit_module_weights will bring the weights back, e.g.:

from torch_tensorrt.dynamo._refit import refit_module_weights
refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
refitted_output = refitted_trt_gm(*inputs)

Please see more details in the tests.

@narendasan (Collaborator):

The reason I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize, (2) deserialize in TRTModule again.

I think that we need to separate the runtime and the compiler, so I'm willing to spend the time serializing and deserializing.

I think we should frame this PR around moving TRTInterpreter to default to building weight-stripped engines. There will be 3 kinds of engines now.

  1. weight strip + refittable (strip_weights + kREFIT) - should move towards this being the default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL)
  3. non_refittable

The first 2 need separate cache entries, so we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL.

We should look to prefer case 1 in the long term, as it allows us to reuse the most work; case 2 would be the next preference. Case 2 should produce faster engines than case 1, so there remains a need to support kREFIT_IDENTICAL.

Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

The case for type 3 engines is now only valid if building a non-refittable engine is faster than building a refit_identical engine and then refitting the weights. If it is not by a significant enough margin, I propose we remove that workflow and just have refit or refit_identical engines.

So assuming that we can remove type 3 engines, make_refittable really means "allows the weights to be changed" (we can change the name here if needed), since now both engine kinds are refittable; they just have different weight constraints.
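The separate-cache-entry requirement above can be sketched as a key function: only the identical-weights variant folds the weight values into the hash, since a general-refit stripped engine is weight-agnostic. Function and parameter names are illustrative, not the PR's cache implementation:

```python
import hashlib

def engine_cache_key(graph_hash: str, refit_identical: bool,
                     weights: dict) -> str:
    """Build a cache key; weight bytes only matter for REFIT_IDENTICAL.

    A general-REFIT stripped engine can be shared by two models that
    differ only in weights; a REFIT_IDENTICAL engine is specialized to
    its build-time weights, so those weights must be part of the key.
    """
    h = hashlib.sha256(graph_hash.encode())
    h.update(b"identical" if refit_identical else b"general")
    if refit_identical:
        for name in sorted(weights):  # stable order for a stable hash
            h.update(name.encode())
            h.update(weights[name])
    return h.hexdigest()
```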

@narendasan (Collaborator):

Some of the open questions are:

  • How do we determine whether the weights have been refit prior to running the engine? Can TRT tell us without an error?
  • How can we tell if a user is trying to refit an engine with weights different from those of an engine built with REFIT_IDENTICAL?
  • Is building a stripped-weights refit_identical engine and then refitting slower than just building?

@zewenli98 (Collaborator, Author):

  2. weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL)
  …
  So we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL

Are you referring to kREFIT_IDENTICAL or kREFIT_INDIVIDUAL? The updated design only considers kREFIT_IDENTICAL; kREFIT_INDIVIDUAL is for fine-grained control, which is not yet considered.

@zewenli98 (Collaborator, Author):

  • how we determine if the weights have been refit prior to running the engine. Can TRT tell us without an error?

My current design is: if users specify strip_engine_weights=True in compile, the weights will not be refitted; they will get a weight-stripped engine.
However, if they get an engine from somewhere, they can call get_missing_weights() to see whether any weights have not been refitted.

  • How can we tell if a user is trying to refit an engine with different weights to an engine built with REFIT_IDENTICAL?

I also thought about that earlier. The TRT doc says: "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures."
My understanding is that we cannot tell whether the build-time and refit weights are identical from the perspective of the engine itself, because a weight-stripped engine does not compare the build-time and refit-time weights or emit any warning. So users need to be clear about what they are refitting.
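Since the engine itself cannot flag a mismatch, one possible user-side guard (not part of this PR; names are illustrative) is to store a digest of the build-time weights next to the stripped engine and compare it before refitting:

```python
import hashlib

def weights_digest(weights: dict) -> str:
    """Digest of named weight buffers, stable under dict ordering."""
    h = hashlib.sha256()
    for name in sorted(weights):
        h.update(name.encode())
        h.update(weights[name])
    return h.hexdigest()

def check_refit_weights(build_digest: str, refit_weights: dict) -> None:
    """Raise before hitting REFIT_IDENTICAL's undefined behavior."""
    if weights_digest(refit_weights) != build_digest:
        raise ValueError(
            "refit weights differ from build-time weights; "
            "a REFIT_IDENTICAL engine requires identical weights"
        )
```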

  • If building strip weights refit identical + refit is slower than just building?

I will investigate. But if users want a weight-included engine, they can specify strip_engine_weights=False.

Successfully merging this pull request may close these issues.

✨[Feature] Weight specific engine caching