Skip to content

Commit

Permalink
Default Device Ordinal API (#1079)
Browse files Browse the repository at this point in the history
* Add default device ordinal API

* Fix reference backend
  • Loading branch information
antoniojkim committed Jul 19, 2022
1 parent b730ffb commit 817dea0
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
8 changes: 8 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,13 @@ BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const {
return BackendDevice(GetDefaultDeviceType(), device.index());
}

int64_t TorchMlirBackendImpl::GetDefaultDeviceOrdinal() const {
return default_device_ordinal;
}

void TorchMlirBackendImpl::SetDefaultDeviceOrdinal(int64_t ordinal) {
default_device_ordinal = ordinal;
}

} // namespace lazy
} // namespace torch
7 changes: 7 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
// identity mappings.
virtual BackendDevice GetBackendDevice(c10::Device device) const override;

virtual int64_t GetDefaultDeviceOrdinal() const override;

virtual void SetDefaultDeviceOrdinal(int64_t ordinal) override;

/**
* Debug/Metrics
* */
Expand All @@ -164,6 +168,9 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
// virtual std::string GetComputationBackendText(
// const ComputationPtr computation
// ) const = 0;

protected:
int64_t default_device_ordinal = 0;
};

} // namespace lazy
Expand Down
14 changes: 9 additions & 5 deletions python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//

#include <c10/core/DeviceType.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
Expand All @@ -26,17 +27,19 @@ namespace torch {
namespace lazy {

struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
ReferenceLazyBackendDeviceType(std::string device_type)
ReferenceLazyBackendDeviceType(c10::DeviceType device_type)
: device_type_(device_type) {}
ReferenceLazyBackendDeviceType(int8_t device_type)
: device_type_(static_cast<c10::DeviceType>(device_type)) {}

std::string toString() const override { return device_type_; }
std::string toString() const override { return c10::DeviceTypeName(device_type_); }

std::string device_type_;
c10::DeviceType device_type_;
};

class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
public:
ReferenceLazyBackendImpl() : default_device_type_("Magic") {}
ReferenceLazyBackendImpl() : default_device_type_(c10::DeviceType::Lazy) {}

/**
* Configuration
Expand Down Expand Up @@ -128,7 +131,8 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
return std::make_shared<BackendDeviceType>(default_device_type_);
}

void SetDefaultDeviceType(std::string device_type) {

void SetDefaultDeviceType(int8_t device_type) override {
default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
}

Expand Down

0 comments on commit 817dea0

Please sign in to comment.