Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Predict 1.9-4.2x faster #1341

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
cmake_minimum_required(VERSION 2.8.9)
project(fasttext)

# NOTE(review): CMAKE_CXX_STANDARD was introduced in CMake 3.1 and the value 17
# is only accepted from CMake 3.8 on, while the minimum declared above is 2.8.9.
# Older CMake versions would rely solely on the -std flag in CMAKE_CXX_FLAGS
# below -- confirm whether the minimum version should be raised.
set(CMAKE_CXX_STANDARD 17)

# The version number.
set (fasttext_VERSION_MAJOR 0)
set (fasttext_VERSION_MINOR 1)

# Make the fasttext directory available on the include path.
include_directories(fasttext)

# Compiler flags. Two assignments are visible here (c++11, then c++17); when
# both are present the later set() replaces the earlier value.
set(CMAKE_CXX_FLAGS " -pthread -std=c++11 -funroll-loops -O3 -march=native")
set(CMAKE_CXX_FLAGS " -pthread -std=c++17 -funroll-loops -O3 -march=native")

set(HEADER_FILES
src/args.h
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#

# Compiler driver.
CXX = c++
# Compile flags. Two assignments are visible here (c++11, then c++17); when
# both are present the later assignment overrides the earlier one.
CXXFLAGS = -pthread -std=c++11 -march=native
CXXFLAGS = -pthread -std=c++17 -march=native
# Object files, one per translation unit listed.
OBJS = args.o autotune.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
# Header search path: the current directory.
INCLUDES = -I.

Expand Down
7 changes: 3 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,14 @@ def has_flag(compiler, flags):


def cpp_flag(compiler):
"""Return the -std=c++[11/14] compiler flag.
The c++14 is preferred over c++11 (when it is available).
"""Return the -std=c++17 compiler flag.
"""
standards = ['-std=c++11']
standards = ['-std=c++17']
for standard in standards:
if has_flag(compiler, [standard]):
return standard
raise RuntimeError(
'Unsupported compiler -- at least C++11 support '
'Unsupported compiler -- at least C++17 support '
'is needed!'
)

Expand Down
98 changes: 98 additions & 0 deletions src/aligned.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#pragma once
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <iterator>
#include <new>
#ifdef _MSC_VER
// Ensure _HAS_EXCEPTIONS is defined
#include <vcruntime.h>
#include <malloc.h>
#endif

#if !((defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS))
#include <cstdlib>
#endif

// Aligned simple vector.

namespace intgemm {

/**
 * Move-only, fixed-size buffer of T whose storage starts on an aligned address.
 *
 * The storage comes straight from the platform's aligned allocator
 * (_aligned_malloc on MSVC, posix_memalign elsewhere). Elements are NOT
 * initialized and no constructors or destructors of T are ever run, so this
 * is only suitable for trivial element types such as float.
 */
template <class T> class AlignedVector {
  public:
    /// Creates an empty vector that owns no memory.
    AlignedVector() noexcept : mem_(nullptr), size_(0) {}

    /**
     * Allocates room for `size` uninitialized elements.
     *
     * @param size      number of elements.
     * @param alignment byte alignment of the first element; posix_memalign
     *                  requires a power of two that is a multiple of
     *                  sizeof(void*).
     *
     * On allocation failure throws std::bad_alloc, or calls std::abort() when
     * the translation unit is built without exceptions.
     */
    explicit AlignedVector(std::size_t size, std::size_t alignment = 64 /* CPU cares about this */)
      : mem_(nullptr), size_(size) {
      // Refuse element counts whose byte size would wrap around std::size_t;
      // otherwise a too-small buffer would be handed out silently.
      if (size > static_cast<std::size_t>(-1) / sizeof(T)) {
        allocationFailed();
      }
#ifdef _MSC_VER
      mem_ = static_cast<T*>(_aligned_malloc(size * sizeof(T), alignment));
      if (!mem_) {
        allocationFailed();
      }
#else
      // Go through a real void* instead of reinterpreting &mem_ as void**.
      void *raw = nullptr;
      if (posix_memalign(&raw, alignment, size * sizeof(T))) {
        allocationFailed();
      }
      mem_ = static_cast<T*>(raw);
#endif
    }

    /**
     * Allocates std::distance(first, last) elements with the default
     * alignment and copies the range into them.
     *
     * The defaulted template argument removes this overload for non-iterator
     * arguments, so a call such as AlignedVector<float>(4, 32) reaches the
     * (size, alignment) constructor instead of being captured here.
     * The range is traversed twice (distance, then copy).
     */
    template <class InputIt,
              class = typename std::iterator_traits<InputIt>::iterator_category>
    AlignedVector(InputIt first, InputIt last)
      : AlignedVector(static_cast<std::size_t>(std::distance(first, last))) {
      std::copy(first, last, begin());
    }

    /// Takes over the buffer of `from`, leaving it empty.
    AlignedVector(AlignedVector &&from) noexcept : mem_(from.mem_), size_(from.size_) {
      from.mem_ = nullptr;
      from.size_ = 0;
    }

    /// Frees the current buffer and takes over the one of `from`, leaving it empty.
    AlignedVector &operator=(AlignedVector &&from) noexcept {
      if (this == &from) return *this;
      release();
      mem_ = from.mem_;
      size_ = from.size_;
      from.mem_ = nullptr;
      from.size_ = 0;
      return *this;
    }

    // Copying is disabled: the buffer has exactly one owner.
    AlignedVector(const AlignedVector&) = delete;
    AlignedVector& operator=(const AlignedVector&) = delete;

    ~AlignedVector() { release(); }

    /// Number of elements.
    std::size_t size() const { return size_; }

    // Unchecked element access.
    T &operator[](std::size_t offset) { return mem_[offset]; }
    const T &operator[](std::size_t offset) const { return mem_[offset]; }

    T *begin() { return mem_; }
    const T *begin() const { return mem_; }
    T *end() { return mem_ + size_; }
    const T *end() const { return mem_ + size_; }

    T *data() { return mem_; }
    const T *data() const { return mem_; }

    /// Reinterprets the start of the buffer as a pointer to ReturnType.
    template <typename ReturnType>
    ReturnType *as() { return reinterpret_cast<ReturnType*>(mem_); }

  private:
    T *mem_;
    std::size_t size_;

    /// Reports an allocation failure: throws if exceptions are enabled, aborts otherwise.
    [[noreturn]] static void allocationFailed() {
#if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS)
      throw std::bad_alloc();
#else
      std::abort();
#endif
    }

    /// Returns the buffer to the allocator it came from; a null mem_ is fine.
    void release() noexcept {
#ifdef _MSC_VER
      _aligned_free(mem_);
#else
      std::free(mem_);
#endif
    }
};

} // namespace intgemm
92 changes: 91 additions & 1 deletion src/densematrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#include "utils.h"
#include "vector.h"

#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
#include <immintrin.h>
#endif

namespace fasttext {

DenseMatrix::DenseMatrix() : DenseMatrix(0, 0) {}
Expand Down Expand Up @@ -146,6 +150,92 @@ void DenseMatrix::addRowToVector(Vector& x, int32_t i, real a) const {
}
}

/* Thin wrappers over the float vector intrinsics. Exactly one family is
 * selected at compile time, preferring AVX512F, then AVX, then SSE, so the
 * code that follows can be written once against Register/Add/Set1/Multiply. */
#if defined(__AVX512F__)
typedef __m512 Register;
// Lane-wise sum of two registers.
inline Register Add(Register lhs, Register rhs) {
  return _mm512_add_ps(lhs, rhs);
}
// Register with every lane set to `value`.
inline Register Set1(float value) {
  return _mm512_set1_ps(value);
}
// Lane-wise product of two registers.
inline Register Multiply(Register lhs, Register rhs) {
  return _mm512_mul_ps(lhs, rhs);
}
#elif defined(__AVX__)
typedef __m256 Register;
// Lane-wise sum of two registers.
inline Register Add(Register lhs, Register rhs) {
  return _mm256_add_ps(lhs, rhs);
}
// Register with every lane set to `value`.
inline Register Set1(float value) {
  return _mm256_set1_ps(value);
}
// Lane-wise product of two registers.
inline Register Multiply(Register lhs, Register rhs) {
  return _mm256_mul_ps(lhs, rhs);
}
#elif defined(__SSE__)
typedef __m128 Register;
// Lane-wise sum of two registers.
inline Register Add(Register lhs, Register rhs) {
  return _mm_add_ps(lhs, rhs);
}
// Register with every lane set to `value`.
inline Register Set1(float value) {
  return _mm_set1_ps(value);
}
// Lane-wise product of two registers.
inline Register Multiply(Register lhs, Register rhs) {
  return _mm_mul_ps(lhs, rhs);
}
#endif

/* x86 fast path for averaging selected rows of a matrix into x.
 * Cols is a compile-time constant so the running sums live in a fixed-size
 * array of Registers, which the compiler can keep in registers if possible. */
#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
template <unsigned Cols> void averageRowsFast(Vector& x, const std::vector<int32_t>& rows, const DenseMatrix &matrix) {
  constexpr unsigned FloatsPerRegister = sizeof(Register) / 4;
  // A row has to split evenly into whole registers.
  static_assert(Cols % FloatsPerRegister == 0);
  constexpr unsigned RegisterCount = Cols / FloatsPerRegister;
  // Both buffers are expected to be register-aligned (should be provided by aligned.h).
  assert(reinterpret_cast<uintptr_t>(x.data()) % sizeof(Register) == 0);
  assert(reinterpret_cast<uintptr_t>(matrix.data()) % sizeof(Register) == 0);

  // No rows: reproduce the generic path (zero, then scale by 1/0), which yields NaN.
  if (rows.empty()) {
    x.zero();
    x.mul(1.0 / rows.size());
    return;
  }

  // Seed the sums with the first requested row.
  Register sums[RegisterCount];
  const Register *src = reinterpret_cast<const Register*>(matrix.data() + matrix.cols() * rows.front());
  for (unsigned r = 0; r < RegisterCount; ++r) {
    sums[r] = src[r];
  }

  // Fold in every remaining row.
  for (std::size_t k = 1; k < rows.size(); ++k) {
    src = reinterpret_cast<const Register*>(matrix.data() + matrix.cols() * rows[k]);
    for (unsigned r = 0; r < RegisterCount; ++r) {
      sums[r] = Add(sums[r], src[r]);
    }
  }

  // Scale by 1 / rows.size() while storing into x.
  const Register scale = Set1(1.0 / rows.size());
  Register *out = reinterpret_cast<Register*>(x.data());
  for (unsigned r = 0; r < RegisterCount; ++r) {
    out[r] = Multiply(sums[r], scale);
  }
}
#endif

// Writes into x the average of the matrix rows listed in `rows`.
// On x86 builds a handful of column counts are routed to averageRowsFast;
// any other width takes the generic zero / accumulate / scale path.
void DenseMatrix::averageRowsToVector(Vector& x, const std::vector<int32_t>& rows) const {
#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE__)
  const auto width = cols();
  if (width == 512) {
    // Maximum number that can fit all in registers on AVX512F.
    averageRowsFast<512>(x, rows, *this);
    return;
  }
  if (width == 256) {
    averageRowsFast<256>(x, rows, *this);
    return;
  }
  if (width == 64) {
    averageRowsFast<64>(x, rows, *this);
    return;
  }
  if (width == 32) {
    averageRowsFast<32>(x, rows, *this);
    return;
  }
  if (width == 16) {
    averageRowsFast<16>(x, rows, *this);
    return;
  }
#endif
  // Generic path: sum the rows one by one, then divide by the row count.
  x.zero();
  for (const int32_t row : rows) {
    addRowToVector(x, row);
  }
  x.mul(1.0 / rows.size());
}

void DenseMatrix::save(std::ostream& out) const {
out.write((char*)&m_, sizeof(int64_t));
out.write((char*)&n_, sizeof(int64_t));
Expand All @@ -155,7 +245,7 @@ void DenseMatrix::save(std::ostream& out) const {
// Reads a matrix from `in`: m_ and n_ as raw int64_t values, followed by
// m_ * n_ raw `real` values read directly into a freshly allocated data_.
// NOTE(review): the stream state is not checked after any of the reads, and
// m_ / n_ are used for the allocation size without validation.
void DenseMatrix::load(std::istream& in) {
in.read((char*)&m_, sizeof(int64_t));
in.read((char*)&n_, sizeof(int64_t));
// Two allocation lines are visible here; presumably the std::vector one is
// the line being replaced by the AlignedVector one -- confirm against the diff.
data_ = std::vector<real>(m_ * n_);
data_ = intgemm::AlignedVector<real>(m_ * n_);
in.read((char*)data_.data(), m_ * n_ * sizeof(real));
}

Expand Down
4 changes: 3 additions & 1 deletion src/densematrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <stdexcept>
#include <vector>

#include "aligned.h"
#include "matrix.h"
#include "real.h"

Expand All @@ -24,7 +25,7 @@ class Vector;

class DenseMatrix : public Matrix {
protected:
std::vector<real> data_;
intgemm::AlignedVector<real> data_;
void uniformThread(real, int, int32_t);

public:
Expand Down Expand Up @@ -71,6 +72,7 @@ class DenseMatrix : public Matrix {
void addVectorToRow(const Vector&, int64_t, real) override;
void addRowToVector(Vector& x, int32_t i) const override;
void addRowToVector(Vector& x, int32_t i, real a) const override;
void averageRowsToVector(Vector& x, const std::vector<int32_t>& rows) const override;
void save(std::ostream&) const override;
void load(std::istream&) override;
void dump(std::ostream&) const override;
Expand Down
66 changes: 58 additions & 8 deletions src/dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ Dictionary::Dictionary(std::shared_ptr<Args> args, std::istream& in)
load(in);
}

// Convenience overload: hashes `w` and forwards to find(w, hash(w)).
// Two signature lines are visible here (std::string& and std::string_view).
int32_t Dictionary::find(const std::string& w) const {
int32_t Dictionary::find(const std::string_view w) const {
return find(w, hash(w));
}

int32_t Dictionary::find(const std::string& w, uint32_t h) const {
int32_t Dictionary::find(const std::string_view w, uint32_t h) const {
int32_t word2intsize = word2int_.size();
int32_t id = h % word2intsize;
while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
Expand Down Expand Up @@ -126,12 +126,12 @@ bool Dictionary::discard(int32_t id, real rand) const {
return rand > pdiscard_[id];
}

// Returns the word2int_ entry at the slot that find(w, h) reports, where `h`
// is a hash of `w` supplied by the caller. Callers in this file treat a
// negative result as "not in the vocabulary".
// Two signature lines are visible here (std::string& and std::string_view).
int32_t Dictionary::getId(const std::string& w, uint32_t h) const {
int32_t Dictionary::getId(const std::string_view w, uint32_t h) const {
int32_t id = find(w, h);
return word2int_[id];
}

// Returns the word2int_ entry at the slot that find(w) reports.
// Note that the local `h` holds the slot index returned by find(), not a hash.
// Two signature lines are visible here (std::string& and std::string_view).
int32_t Dictionary::getId(const std::string& w) const {
int32_t Dictionary::getId(const std::string_view w) const {
int32_t h = find(w);
return word2int_[h];
}
Expand All @@ -142,7 +142,7 @@ entry_type Dictionary::getType(int32_t id) const {
return words_[id].type;
}

// Classifies a token by its text: entry_type::label when args_->label occurs
// at position 0 of `w`, entry_type::word otherwise.
// Two signature lines are visible here (std::string& and std::string_view).
entry_type Dictionary::getType(const std::string& w) const {
entry_type Dictionary::getType(const std::string_view w) const {
return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word;
}

Expand All @@ -160,7 +160,7 @@ std::string Dictionary::getWord(int32_t id) const {
// Since all fasttext models that were already released were trained
// using signed char, we fixed the hash function to make models
// compatible whatever compiler is used.
uint32_t Dictionary::hash(const std::string& str) const {
uint32_t Dictionary::hash(const std::string_view str) const {
uint32_t h = 2166136261;
for (size_t i = 0; i < str.size(); i++) {
h = h ^ uint32_t(int8_t(str[i]));
Expand Down Expand Up @@ -324,11 +324,16 @@ void Dictionary::addWordNgrams(

void Dictionary::addSubwords(
std::vector<int32_t>& line,
const std::string& token,
const std::string_view token,
int32_t wid) const {
if (wid < 0) { // out of vocab
if (token != EOS) {
computeSubwords(BOW + token + EOW, line);
std::string concat;
concat.reserve(BOW.size() + token.size() + EOW.size());
concat += BOW;
concat.append(token.data(), token.size());
concat += EOW;
computeSubwords(concat, line);
}
} else {
if (args_->maxn <= 0) { // in vocab w/o subwords
Expand Down Expand Up @@ -406,6 +411,51 @@ int32_t Dictionary::getLine(
return ntokens;
}

namespace {
// Extracts the next delimiter-separated token from the front of `in`.
//
// Leading delimiters are skipped, `word` is set to a view of the token (it
// aliases the buffer behind `in`), and `in` is advanced to just past the
// token. Newlines are ordinary delimiters here; no EOS token is produced.
//
// Returns false, with `in` emptied and `word` untouched, when only delimiters
// (or nothing) remain.
bool readWordNoNewline(std::string_view& in, std::string_view& word) {
  // The delimiter set includes '\0'. It is built from an array with an
  // explicit length because constructing a string_view from a plain C string
  // stops at the first NUL and would silently drop that delimiter.
  static constexpr char kDelimiters[] = {' ', '\n', '\r', '\t', '\v', '\f', '\0'};
  const std::string_view spaces(kDelimiters, sizeof(kDelimiters));

  const std::string_view::size_type begin = in.find_first_not_of(spaces);
  if (begin == std::string_view::npos) {
    in.remove_prefix(in.size());
    return false;
  }
  in.remove_prefix(begin);
  // find_first_of returns npos when the token runs to the end; substr clamps it.
  word = in.substr(0, in.find_first_of(spaces));
  in.remove_prefix(word.size());
  return true;
}
} // namespace

int32_t Dictionary::getStringNoNewline(
std::string_view in,
std::vector<int32_t>& words,
std::vector<int32_t>& labels) const {
std::vector<int32_t> word_hashes;
std::string_view token;
int32_t ntokens = 0;

words.clear();
labels.clear();
while (readWordNoNewline(in, token)) {
uint32_t h = hash(token);
int32_t wid = getId(token, h);
entry_type type = wid < 0 ? getType(token) : getType(wid);

ntokens++;
if (type == entry_type::word) {
addSubwords(words, token, wid);
word_hashes.push_back(h);
} else if (type == entry_type::label && wid >= 0) {
labels.push_back(wid - nwords_);
}
if (token == EOS) {
break;
}
}
addWordNgrams(words, word_hashes, args_->wordNgrams);
return ntokens;
}

void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
if (pruneidx_size_ == 0 || id < 0) {
return;
Expand Down
Loading