Skip to content

Commit

Permalink
Add static shape for scalar tensors (#833)
Browse files Browse the repository at this point in the history
* Assume zero rank tensors are scalar

* Run RefineTypes pass on JIT Graph

* Rollback assumption that zero rank tensors are scalar

* Set numSizes to -1 for non-ranked tensors

* Rename RefineTypes to RefineTupleTypes
  • Loading branch information
henrytwo authored and antoniojkim committed Jul 7, 2022
1 parent 6be45ca commit 019699c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir-c/TorchTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);

/// Gets a !torch.tensor type.
///
/// - `numSizes` having a value of -1 denotes an unranked tensor.
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype
Expand All @@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);

/// Gets a !torch.vtensor type.
///
/// - `numSizes` having a value of -1 denotes an unranked tensor.
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype
Expand Down
6 changes: 4 additions & 2 deletions lib/CAPI/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
// if numSizes == -1, then it is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::NonValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
Expand Down Expand Up @@ -231,7 +232,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
// if numSizes == -1, then it is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::ValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <iostream>

#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>

#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
Expand Down Expand Up @@ -108,6 +109,10 @@ void TorchMlirLoweringContext::AddParameter(
ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION();

// Since we mutated the types of some nodes to insert shape information, we
// must perform this pass to ensure tuples have up-to-date output types.
torch::jit::RefineTupleTypes(graph_);

// Insert return values into graph.
for (torch::jit::Value* output : root_tuple_) {
graph_->block()->registerOutput(output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
if (!sizes.rank()) {
// Unranked.
return getMlirTensorType(context,
/*numSizes=*/0,
/*numSizes=*/-1,
/*optionalSizes=*/nullptr,
/*optionalDtype=*/
elementType);
Expand Down

0 comments on commit 019699c

Please sign in to comment.