Skip to content

Commit

Permalink
Cleanup host vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
cgmb committed Jan 25, 2023
1 parent 4700f87 commit 8bfe3e5
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 426 deletions.
272 changes: 68 additions & 204 deletions clients/rocblascommon/host_batch_vector.hpp
Original file line number Diff line number Diff line change
@@ -1,265 +1,129 @@
/* ************************************************************************
* Copyright (c) 2018-2021 Advanced Micro Devices, Inc.
* Copyright (c) 2018-2023 Advanced Micro Devices, Inc.
* ************************************************************************ */

#pragma once

#include <cassert>
#include <cmath>
#include <memory>
#include <ostream>
#include <string.h>
#include <vector>

#include "rocblas_init.hpp"
#include <hip/hip_runtime_api.h>
#include <rocblas/rocblas.h>

//
// Forward declaration of the device batch vector.
//
template <typename T, size_t PAD, typename U>
class device_batch_vector;

//!
//! @brief Implementation of the batch vector on host.
//!
template <typename T>
class host_batch_vector
{
public:
using value_type = T;

public:
//!
//! @brief Delete copy constructor.
//!
host_batch_vector(const host_batch_vector<T>& that) = delete;

//!
//! @brief Delete copy assignment.
//!
host_batch_vector& operator=(const host_batch_vector<T>& that) = delete;

//!
//! @brief Constructor.
//! @param n The length of the vector.
//! @param inc The increment.
//! @param batch_count The batch count.
//!
explicit host_batch_vector(rocblas_int n, rocblas_int inc, rocblas_int batch_count)
: m_n(n)
, m_inc(inc)
, m_batch_count(batch_count)
host_batch_vector(rocblas_int n, rocblas_int inc, rocblas_int batch_count)
: data_(std::make_unique<ArrT[]>(batch_count))
, n_(n)
, inc_(inc)
, batch_count_(batch_count)
{
if(false == this->try_initialize_memory())
assert(n > 0);
assert(batch_count > 0);

const size_t size = vsize();
for(rocblas_int i = 0; i < batch_count_; ++i)
{
this->free_memory();
data_[i] = std::make_unique<T[]>(size);
}
}

//!
//! @brief Constructor.
//! @param n The length of the vector.
//! @param inc The increment.
//! @param stride (UNUSED) The stride.
//! @param batch_count The batch count.
//!
explicit host_batch_vector(rocblas_int n,
rocblas_int inc,
rocblas_stride stride,
rocblas_int batch_count)
host_batch_vector(rocblas_int n, rocblas_int inc, rocblas_stride stride, rocblas_int batch_count)
: host_batch_vector(n, inc, batch_count)
{
assert(stride == 1);
}

//!
//! @brief Destructor.
//!
~host_batch_vector()
// The number of elements in each vector.
rocblas_int n() const noexcept
{
this->free_memory();
return n_;
}

//!
//! @brief Returns the length of the vector.
//!
rocblas_int n() const
// The increment between elements in each vector.
rocblas_int inc() const noexcept
{
return this->m_n;
return inc_;
}

//!
//! @brief Returns the increment of the vector.
//!
rocblas_int inc() const
// The size of each vector. This is a derived property of the number of elements in the vector
// and the spacing between them.
size_t vsize() const
{
return this->m_inc;
return size_t(n_) * std::abs(inc_);
}

//!
//! @brief Returns the batch count.
//!
rocblas_int batch_count() const
// The number of vectors in the batch.
rocblas_int batch_count() const noexcept
{
return this->m_batch_count;
return batch_count_;
}

//!
//! @brief Returns the stride value.
//!
rocblas_stride stride() const
{
return 0;
}

//!
//! @brief Random access to the vectors.
//! @param batch_index the batch index.
//! @return The mutable pointer.
//!
// Returns a vector from the batch.
T* operator[](rocblas_int batch_index)
{
return this->m_data[batch_index];
assert(batch_index >= 0);
assert(batch_index < batch_count_);
return data_[batch_index].get();
}

//!
//! @brief Constant random access to the vectors.
//! @param batch_index the batch index.
//! @return The non-mutable pointer.
//!
// Returns a vector from the batch.
const T* operator[](rocblas_int batch_index) const
{
return this->m_data[batch_index];
}

// clang-format off
//!
//! @brief Cast to a double pointer.
//!
operator T**()
{
return this->m_data;
}
// clang-format on

//!
//! @brief Constant cast to a double pointer.
//!
operator const T* const *()
{
return this->m_data;
}

//!
//! @brief Copy from a host batched vector.
//! @param that the vector the data is copied from.
//! @return true if the copy is done successfully, false otherwise.
//!
bool copy_from(const host_batch_vector<T>& that)
{
if((this->batch_count() == that.batch_count()) && (this->n() == that.n())
&& (this->inc() == that.inc()))
{
size_t num_bytes = this->n() * std::abs(this->inc()) * sizeof(T);
for(rocblas_int batch_index = 0; batch_index < this->m_batch_count; ++batch_index)
{
memcpy((*this)[batch_index], that[batch_index], num_bytes);
}
return true;
}
else
{
return false;
}
assert(batch_index >= 0);
assert(batch_index < batch_count_);
return data_[batch_index].get();
}

//!
//! @brief Transfer from a device batched vector.
//! @param that the vector the data is copied from.
//! @return the hip error.
//!
// Copy from a device_batch_vector into host memory.
hipError_t transfer_from(const device_batch_vector<T>& that)
{
hipError_t hip_err;
size_t num_bytes = size_t(this->m_n) * std::abs(this->m_inc) * sizeof(T);
for(rocblas_int batch_index = 0; batch_index < this->m_batch_count; ++batch_index)
{
if(hipSuccess
!= (hip_err = hipMemcpy((*this)[batch_index], that[batch_index], num_bytes,
hipMemcpyDeviceToHost)))
{
return hip_err;
}
}
return hipSuccess;
}
assert(n_ == that.n());
assert(inc_ == that.inc());
assert(batch_count_ == that.batch_count());

//!
//! @brief Check if memory exists.
//! @return hipSuccess if memory exists, hipErrorOutOfMemory otherwise.
//!
hipError_t memcheck() const
{
return (nullptr != this->m_data) ? hipSuccess : hipErrorOutOfMemory;
hipError_t err = hipSuccess;
host_batch_vector<T>& self = *this;
size_t num_bytes = vsize() * sizeof(T);
for(size_t b = 0; err == hipSuccess && b < batch_count_; ++b)
err = hipMemcpy(self[b], that[b], num_bytes, hipMemcpyDeviceToHost);
return err;
}

private:
rocblas_int m_n{};
rocblas_int m_inc{};
rocblas_int m_batch_count{};
T** m_data{};

bool try_initialize_memory()
{
bool success = (nullptr != (this->m_data = (T**)calloc(this->m_batch_count, sizeof(T*))));
if(success)
{
size_t nmemb = size_t(this->m_n) * std::abs(this->m_inc);
for(rocblas_int batch_index = 0; batch_index < this->m_batch_count; ++batch_index)
{
success = (nullptr != (this->m_data[batch_index] = (T*)calloc(nmemb, sizeof(T))));
if(false == success)
{
break;
}
}
}
return success;
}
using ArrT = std::unique_ptr<T[]>;

void free_memory()
{
if(nullptr != this->m_data)
{
for(rocblas_int batch_index = 0; batch_index < this->m_batch_count; ++batch_index)
{
if(nullptr != this->m_data[batch_index])
{
free(this->m_data[batch_index]);
this->m_data[batch_index] = nullptr;
}
}

free(this->m_data);
this->m_data = nullptr;
}
}
private:
std::unique_ptr<ArrT[]> data_;
rocblas_int n_;
rocblas_int inc_;
rocblas_int batch_count_;
};

//!
//! @brief Overload output operator.
//! @param os The ostream.
//! @param that That host batch vector.
//!
template <typename T>
std::ostream& operator<<(std::ostream& os, const host_batch_vector<T>& that)
std::ostream& operator<<(std::ostream& os, const host_batch_vector<T>& hbv)
{
auto n = that.n();
auto inc = std::abs(that.inc());
auto batch_count = that.batch_count();
rocblas_int n = hbv.n();
rocblas_int inc = std::abs(hbv.inc());
rocblas_int batch_count = hbv.batch_count();

for(rocblas_int batch_index = 0; batch_index < batch_count; ++batch_index)
for(rocblas_int b = 0; b < batch_count; ++b)
{
auto batch_data = that[batch_index];
os << "[" << batch_index << "] = { " << batch_data[0];
for(rocblas_int i = 1; i < n; ++i)
T* hv = hbv[b];
os << "[" << b << "] = { ";
for(rocblas_int i = 0; i < n; ++i)
{
os << ", " << batch_data[i * inc];
os << hv[i * inc];
if(i + 1 < n)
os << ", ";
}
os << " }" << std::endl;
}
Expand Down
Loading

0 comments on commit 8bfe3e5

Please sign in to comment.