Skip to content

Commit

Permalink
[XLA:HloParser] Add a flag set_to_default_entry_computation_layout
Browse files Browse the repository at this point in the history
…with default value true.

If it is false, do not overwrite the raw input with the default layout in entry_compuation_layout. If entry_compuation_layout is defined explicitly, we should exactly follow the raw definition. For instance, if the raw input has only shape and does not have layout, we should not set the default layout.

If entry_compuation_layout is not defined explicitly, we still infer it from the parameter and root instructions of entry computation. The layout of these instructions can be either explicitly defined or the default one.

PiperOrigin-RevId: 673159258
  • Loading branch information
ZixuanJiang authored and copybara-github committed Sep 11, 2024
1 parent be6b8f4 commit 81ba4c2
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 38 deletions.
64 changes: 37 additions & 27 deletions xla/service/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,11 @@ class HloParserImpl : public HloParser {
using LocTy = HloLexer::LocTy;
using BoolList = absl::InlinedVector<bool, 1>;

explicit HloParserImpl(absl::string_view str) : lexer_(str) {}
explicit HloParserImpl(absl::string_view str,
bool set_to_default_entry_computation_layout = true)
: lexer_(str),
set_to_default_entry_computation_layout_(
set_to_default_entry_computation_layout) {}

// Runs the parser and constructs the resulting HLO in the given (empty)
// HloModule. Returns the error status in case an error occurred.
Expand Down Expand Up @@ -541,7 +545,7 @@ class HloParserImpl : public HloParser {
bool ParseJsonDict(std::string* result);
bool ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
std::vector<bool>* dynamic_dimensions);
bool ParseShape(Shape* result);
bool ParseShape(Shape* result, bool set_to_default_layout = true);
bool ParseLayout(Layout* layout);
bool ParseLayoutIntAttribute(int64_t* attr_value,
absl::string_view attr_description);
Expand Down Expand Up @@ -664,6 +668,8 @@ class HloParserImpl : public HloParser {

// Used to generate names for anonymous instructions.
NameUniquer name_uniquer_{/*separator=*/"."};

const bool set_to_default_entry_computation_layout_;
};

bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64_t>* out) {
Expand Down Expand Up @@ -907,7 +913,7 @@ bool HloParserImpl::ParseComputationLayout(
}
while (lexer_.GetKind() != TokKind::kRparen) {
Shape param;
if (!ParseShape(&param)) {
if (!ParseShape(&param, set_to_default_entry_computation_layout_)) {
return false;
}
computation_layout->add_parameter_layout(ShapeLayout(param));
Expand All @@ -927,7 +933,7 @@ bool HloParserImpl::ParseComputationLayout(
return false;
}
Shape result;
if (!ParseShape(&result)) {
if (!ParseShape(&result, set_to_default_entry_computation_layout_)) {
return false;
}
*computation_layout->mutable_result_layout() = ShapeLayout(result);
Expand Down Expand Up @@ -1117,9 +1123,6 @@ bool HloParserImpl::ParseHloModule(HloModule* module,

if (parse_module_without_header) {
name = absl::StrCat("module_", module->entry_computation()->name());
entry_computation_layout =
ComputationLayout(module->entry_computation()->ComputeProgramShape(),
/*ignore_layouts*/ false);
}

module->set_name(name);
Expand All @@ -1145,6 +1148,21 @@ bool HloParserImpl::ParseHloModule(HloModule* module,
if (entry_computation_layout.has_value()) {
*config.mutable_entry_computation_layout() = *entry_computation_layout;
default_config = false;
} else {
// If entry_computation_layout is not specified explicitly, we infer the
// layout from parameter and root instructions.
HloComputation* entry_computation = module->entry_computation();
for (int64_t p = 0; p < entry_computation->num_parameters(); p++) {
const Shape& param_shape =
entry_computation->parameter_instruction(p)->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = entry_computation->root_instruction()->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
if (frontend_attributes) {
module->set_frontend_attributes(frontend_attributes.value());
Expand Down Expand Up @@ -1209,19 +1227,7 @@ bool HloParserImpl::ParseComputations(HloModule* module) {
module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
module->AddEntryComputation(std::move(computations_[i]));
}
return true;
}
Expand Down Expand Up @@ -6088,7 +6094,7 @@ bool HloParserImpl::ParseLayout(Layout* layout) {
// tuple_elements
// ::= /*empty*/
// ::= shape (',' shape)*
bool HloParserImpl::ParseShape(Shape* result) {
bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) {
if (EatIfPresent(TokKind::kLparen)) { // Tuple
std::vector<Shape> shapes;
if (lexer_.GetKind() == TokKind::kRparen) {
Expand All @@ -6097,7 +6103,7 @@ bool HloParserImpl::ParseShape(Shape* result) {
// shape (',' shape)*
do {
shapes.emplace_back();
if (!ParseShape(&shapes.back())) {
if (!ParseShape(&shapes.back(), set_to_default_layout)) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
Expand All @@ -6123,7 +6129,9 @@ bool HloParserImpl::ParseShape(Shape* result) {
result->add_dimensions(dimension_sizes[i]);
result->set_dynamic_dimension(i, dynamic_dimensions[i]);
}
LayoutUtil::SetToDefaultLayout(result);
if (set_to_default_layout || ShapeUtil::IsScalar(*result)) {
LayoutUtil::SetToDefaultLayout(result);
}
// We need to lookahead to see if a following open brace is the start of a
// layout. The specific problematic case is:
//
Expand Down Expand Up @@ -6976,16 +6984,18 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
} // namespace

absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str, const HloModuleConfig& config) {
absl::string_view str, const HloModuleConfig& config,
bool set_to_default_entry_computation_layout) {
auto module = std::make_unique<HloModule>(/*name=*/"_", config);
HloParserImpl parser(str);
HloParserImpl parser(str, set_to_default_entry_computation_layout);
TF_RETURN_IF_ERROR(parser.Run(module.get()));
return std::move(module);
}

absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str) {
return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
absl::string_view str, bool set_to_default_entry_computation_layout) {
return ParseAndReturnUnverifiedModule(
str, HloModuleConfig(), set_to_default_entry_computation_layout);
}

absl::StatusOr<HloSharding> ParseSharding(absl::string_view str) {
Expand Down
5 changes: 3 additions & 2 deletions xla/service/hlo_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ namespace xla {
// Note: Tests derived from HloTestBase should use
// ParseAndReturnVerifiedModule() instead!
absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str, const HloModuleConfig& config);
absl::string_view str, const HloModuleConfig& config,
bool set_to_default_entry_computation_layout = true);

// Given a string in the HloModule::ToString() format, parses the string and
// creates a HloModule with default config.
// Note: Tests derived from HloTestBase should use
// ParseAndReturnVerifiedModule() instead!
absl::StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str);
absl::string_view str, bool set_to_default_entry_computation_layout = true);

// Parses sharding from str. str is supposed to contain the body of the
// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g.,
Expand Down
30 changes: 28 additions & 2 deletions xla/service/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3385,8 +3385,10 @@ ENTRY %CustomCall () -> f32[1] {
"with that of its root instruction foo, f32[1,2,3]");
}

TEST_F(HloParserTest, EntryComputationWithLayout) {
const std::string original = R"(HloModule layout:
TEST_F(HloParserTest, EntryComputationLayoutNotDefined) {
const std::string original = R"(
HloModule layout_not_defined
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
Expand Down Expand Up @@ -3414,6 +3416,30 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
<< LayoutUtil::HumanString(result_layout);
}

TEST_F(HloParserTest, EntryComputationLayoutDefined) {
const std::string original = R"(
HloModule layout_defined, entry_computation_layout={(f32[8,16,256]) -> f32[8,16]}
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
input = f32[8,16,256]{0,1,2} parameter(0)
constant = f32[] constant(0)
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";

absl::StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(
original, /*set_to_default_entry_computation_layout=*/false);
TF_ASSERT_OK(module.status());
// Do not set the default layout.
EXPECT_FALSE(module.value()->entry_computation_layout().AnyLayoutSet());
}

TEST_F(HloParserTest, NoEntry) {
const std::string original = R"(HloModule no_entry:
c1 {
Expand Down
14 changes: 9 additions & 5 deletions xla/tools/hlo_module_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
const std::string& data, std::string_view format,
const hlo_module_loader_details::Config& ovr_config,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
BufferAssignmentProto* buffer_assignment_proto) {
BufferAssignmentProto* buffer_assignment_proto,
bool set_to_default_entry_computation_layout) {
DebugOptions debug_options = GetDebugOptionsFromFlags();
std::unique_ptr<HloModule> module;
if (format == "hlo" || format == "txt") {
Expand All @@ -81,8 +82,9 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
if (config_modifier_hook) {
config_modifier_hook(&config);
}
TF_ASSIGN_OR_RETURN(module,
ParseAndReturnUnverifiedModule(hlo_string, config));
TF_ASSIGN_OR_RETURN(module, ParseAndReturnUnverifiedModule(
hlo_string, config,
set_to_default_entry_computation_layout));
} else {
HloSnapshot proto;
if (format == "pb") {
Expand Down Expand Up @@ -130,14 +132,16 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
const std::string& path, std::string format,
const hlo_module_loader_details::Config& ovr_config,
const std::function<void(HloModuleConfig*)>& config_modifier_hook,
BufferAssignmentProto* buffer_assignment_proto) {
BufferAssignmentProto* buffer_assignment_proto,
bool set_to_default_entry_computation_layout) {
std::string data;
if (format.empty()) {
format = std::string(tsl::io::Extension(path));
}
TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data));
return LoadModuleFromData(data, format, ovr_config, config_modifier_hook,
buffer_assignment_proto);
buffer_assignment_proto,
set_to_default_entry_computation_layout);
}

absl::StatusOr<std::unique_ptr<RunHloModuleIterationLiterals>>
Expand Down
6 changes: 4 additions & 2 deletions xla/tools/hlo_module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
const hlo_module_loader_details::Config& ovr_config =
hlo_module_loader_details::Config(),
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
BufferAssignmentProto* buffer_assignment_proto = nullptr);
BufferAssignmentProto* buffer_assignment_proto = nullptr,
bool set_to_default_entry_computation_layout = true);

// Loads an HLO module from file.
// The file can be one of the followings:
Expand All @@ -82,7 +83,8 @@ absl::StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
const hlo_module_loader_details::Config& ovr_config =
hlo_module_loader_details::Config(),
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {},
BufferAssignmentProto* buffer_assignment_proto = nullptr);
BufferAssignmentProto* buffer_assignment_proto = nullptr,
bool set_to_default_entry_computation_layout = true);

// Loads an HLO snapshot from a string, only for its inputs
// The data format must be one of the following:
Expand Down

0 comments on commit 81ba4c2

Please sign in to comment.