Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Impl IndexPreTransform for c_api #1816

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions c_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set(FAISS_C_SRC
IndexIVF_c.cpp
IndexLSH_c.cpp
IndexPreTransform_c.cpp
VectorTransform_c.cpp
IndexShards_c.cpp
Index_c.cpp
MetaIndexes_c.cpp
Expand Down
40 changes: 40 additions & 0 deletions c_api/IndexPreTransform_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,52 @@

#include "IndexPreTransform_c.h"
#include <faiss/IndexPreTransform.h>
#include <faiss/VectorTransform.h>
#include "macros_impl.h"

using faiss::Index;
using faiss::IndexPreTransform;
using faiss::VectorTransform;

extern "C" {

DEFINE_DESTRUCTOR(IndexPreTransform)
DEFINE_INDEX_DOWNCAST(IndexPreTransform)

DEFINE_GETTER_PERMISSIVE(IndexPreTransform, FaissIndex*, index)

DEFINE_GETTER(IndexPreTransform, int, own_fields)
DEFINE_SETTER(IndexPreTransform, int, own_fields)

int faiss_IndexPreTransform_new(FaissIndexPreTransform** p_index) {
try {
*p_index = reinterpret_cast<FaissIndexPreTransform*>(
new IndexPreTransform());
}
CATCH_AND_HANDLE
}

int faiss_IndexPreTransform_new_with(
FaissIndexPreTransform** p_index,
FaissIndex* index) {
try {
auto ind = reinterpret_cast<Index*>(index);
*p_index = reinterpret_cast<FaissIndexPreTransform*>(
new IndexPreTransform(ind));
}
CATCH_AND_HANDLE
}

int faiss_IndexPreTransform_new_with_transform(
FaissIndexPreTransform** p_index,
FaissVectorTransform* ltrans,
FaissIndex* index) {
try {
auto lt = reinterpret_cast<VectorTransform*>(ltrans);
auto ind = reinterpret_cast<Index*>(index);
*p_index = reinterpret_cast<FaissIndexPreTransform*>(
new IndexPreTransform(lt, ind));
}
CATCH_AND_HANDLE
}
}
17 changes: 16 additions & 1 deletion c_api/IndexPreTransform_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,32 @@
#define FAISS_INDEX_PRETRANSFORM_C_H

#include "Index_c.h"
#include "VectorTransform_c.h"
#include "faiss_c.h"

#ifdef __cplusplus
extern "C" {
#endif

FAISS_DECLARE_CLASS(IndexPreTransform)
/** Index that applies a LinearTransform transform on vectors before
* handing them over to a sub-index */
FAISS_DECLARE_CLASS_INHERITED(IndexPreTransform, Index)
FAISS_DECLARE_DESTRUCTOR(IndexPreTransform)
FAISS_DECLARE_INDEX_DOWNCAST(IndexPreTransform)

FAISS_DECLARE_GETTER(IndexPreTransform, FaissIndex*, index)
FAISS_DECLARE_GETTER_SETTER(IndexPreTransform, int, own_fields)

int faiss_IndexPreTransform_new(FaissIndexPreTransform** p_index);

int faiss_IndexPreTransform_new_with(
FaissIndexPreTransform** p_index,
FaissIndex* index);

int faiss_IndexPreTransform_new_with_transform(
FaissIndexPreTransform** p_index,
FaissVectorTransform* ltrans,
FaissIndex* index);

#ifdef __cplusplus
}
Expand Down
2 changes: 1 addition & 1 deletion c_api/Index_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ FAISS_DECLARE_GETTER(Index, FaissMetricType, metric_type)
*
* @param index opaque pointer to index object
* @param n nb of training vectors
* @param x training vecors, size n * d
* @param x training vectors, size n * d
*/
int faiss_Index_train(FaissIndex* index, idx_t n, const float* x);

Expand Down
228 changes: 228 additions & 0 deletions c_api/VectorTransform_c.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

// Copyright 2004-present Facebook. All Rights Reserved.
// -*- c++ -*-

#include "VectorTransform_c.h"
#include <faiss/VectorTransform.h>
#include "macros_impl.h"

extern "C" {

DEFINE_DESTRUCTOR(VectorTransform)

DEFINE_GETTER(VectorTransform, int, is_trained)

DEFINE_GETTER(VectorTransform, int, d_in)

DEFINE_GETTER(VectorTransform, int, d_out)

int faiss_VectorTransform_train(
FaissVectorTransform* vt,
idx_t n,
const float* x) {
try {
reinterpret_cast<faiss::VectorTransform*>(vt)->train(n, x);
}
CATCH_AND_HANDLE
}

float* faiss_VectorTransform_apply(
const FaissVectorTransform* vt,
idx_t n,
const float* x) {
return reinterpret_cast<const faiss::VectorTransform*>(vt)->apply(n, x);
}

void faiss_VectorTransform_apply_noalloc(
const FaissVectorTransform* vt,
idx_t n,
const float* x,
float* xt) {
return reinterpret_cast<const faiss::VectorTransform*>(vt)->apply_noalloc(
n, x, xt);
}

void faiss_VectorTransform_reverse_transform(
const FaissVectorTransform* vt,
idx_t n,
const float* xt,
float* x) {
return reinterpret_cast<const faiss::VectorTransform*>(vt)
->reverse_transform(n, xt, x);
}

/*********************************************
* LinearTransform
*********************************************/

DEFINE_DESTRUCTOR(LinearTransform)

DEFINE_GETTER(LinearTransform, int, have_bias)

DEFINE_GETTER(LinearTransform, int, is_orthonormal)

void faiss_LinearTransform_transform_transpose(
const FaissLinearTransform* vt,
idx_t n,
const float* y,
float* x) {
return reinterpret_cast<const faiss::LinearTransform*>(vt)
->transform_transpose(n, y, x);
}

void faiss_LinearTransform_set_is_orthonormal(FaissLinearTransform* vt) {
return reinterpret_cast<faiss::LinearTransform*>(vt)->set_is_orthonormal();
}

/*********************************************
* RandomRotationMatrix
*********************************************/

DEFINE_DESTRUCTOR(RandomRotationMatrix)

int faiss_RandomRotationMatrix_new_with(
FaissRandomRotationMatrix** p_vt,
int d_in,
int d_out) {
try {
*p_vt = reinterpret_cast<FaissRandomRotationMatrix*>(
new faiss::RandomRotationMatrix(d_in, d_out));
}
CATCH_AND_HANDLE
}

/*********************************************
* PCAMatrix
*********************************************/

DEFINE_DESTRUCTOR(PCAMatrix)

int faiss_PCAMatrix_new_with(
FaissPCAMatrix** p_vt,
int d_in,
int d_out,
float eigen_power,
int random_rotation) {
try {
bool random_rotation_ = static_cast<bool>(random_rotation);
*p_vt = reinterpret_cast<FaissPCAMatrix*>(new faiss::PCAMatrix(
d_in, d_out, eigen_power, random_rotation_));
}
CATCH_AND_HANDLE
}

DEFINE_GETTER(PCAMatrix, float, eigen_power)

DEFINE_GETTER(PCAMatrix, int, random_rotation)

/*********************************************
* ITQMatrix
*********************************************/

DEFINE_DESTRUCTOR(ITQMatrix)

int faiss_ITQMatrix_new_with(FaissITQMatrix** p_vt, int d) {
try {
*p_vt = reinterpret_cast<FaissITQMatrix*>(new faiss::ITQMatrix(d));
}
CATCH_AND_HANDLE
}

DEFINE_DESTRUCTOR(ITQTransform)

int faiss_ITQTransform_new_with(
FaissITQTransform** p_vt,
int d_in,
int d_out,
int do_pca) {
try {
bool do_pca_ = static_cast<bool>(do_pca);
*p_vt = reinterpret_cast<FaissITQTransform*>(
new faiss::ITQTransform(d_in, d_out, do_pca_));
}
CATCH_AND_HANDLE
}

DEFINE_GETTER(ITQTransform, int, do_pca)

/*********************************************
* OPQMatrix
*********************************************/

DEFINE_DESTRUCTOR(OPQMatrix)

int faiss_OPQMatrix_new_with(FaissOPQMatrix** p_vt, int d, int M, int d2) {
try {
*p_vt = reinterpret_cast<FaissOPQMatrix*>(
new faiss::OPQMatrix(d, M, d2));
}
CATCH_AND_HANDLE
}

DEFINE_GETTER(OPQMatrix, int, verbose)
DEFINE_SETTER(OPQMatrix, int, verbose)

DEFINE_GETTER(OPQMatrix, int, niter)
DEFINE_SETTER(OPQMatrix, int, niter)

DEFINE_GETTER(OPQMatrix, int, niter_pq)
DEFINE_SETTER(OPQMatrix, int, niter_pq)

/*********************************************
* RemapDimensionsTransform
*********************************************/

DEFINE_DESTRUCTOR(RemapDimensionsTransform)

int faiss_RemapDimensionsTransform_new_with(
FaissRemapDimensionsTransform** p_vt,
int d_in,
int d_out,
int uniform) {
try {
bool uniform_ = static_cast<bool>(uniform);
*p_vt = reinterpret_cast<FaissRemapDimensionsTransform*>(
new faiss::RemapDimensionsTransform(d_in, d_out, uniform_));
}
CATCH_AND_HANDLE
}

/*********************************************
* NormalizationTransform
*********************************************/

DEFINE_DESTRUCTOR(NormalizationTransform)

int faiss_NormalizationTransform_new_with(
FaissNormalizationTransform** p_vt,
int d,
float norm) {
try {
*p_vt = reinterpret_cast<FaissNormalizationTransform*>(
new faiss::NormalizationTransform(d, norm));
}
CATCH_AND_HANDLE
}

DEFINE_GETTER(NormalizationTransform, float, norm)

/*********************************************
* CenteringTransform
*********************************************/

DEFINE_DESTRUCTOR(CenteringTransform)

int faiss_CenteringTransform_new_with(FaissCenteringTransform** p_vt, int d) {
try {
*p_vt = reinterpret_cast<FaissCenteringTransform*>(
new faiss::CenteringTransform(d));
}
CATCH_AND_HANDLE
}
}
Loading