Skip to content

Commit

Permalink
Restored Cosine support.
Browse files Browse the repository at this point in the history
  • Loading branch information
LTLA committed Jul 27, 2024
1 parent bae7c4c commit f7b2125
Show file tree
Hide file tree
Showing 12 changed files with 176 additions and 157 deletions.
34 changes: 5 additions & 29 deletions R/AllGenerics.R
Original file line number Diff line number Diff line change
@@ -1,46 +1,22 @@
#' @export
#' @rdname buildIndex
setGeneric("buildIndex", signature=c("BNPARAM"),
function(X, ..., BNPARAM)
standardGeneric("buildIndex")
)
setGeneric("buildIndex", signature=c("BNPARAM"), function(X, ..., BNPARAM) standardGeneric("buildIndex"))

#' @export
#' @rdname findKNN-methods
setGeneric("findKNN", signature=c("BNINDEX", "BNPARAM"),
function(X, k, ..., BNINDEX, BNPARAM)
standardGeneric("findKNN")
)
setGeneric("findKNN", signature=c("X", "BNPARAM"), function(X, k, ..., BNPARAM) standardGeneric("findKNN"))

#' @export
#' @rdname queryKNN-methods
setGeneric("queryKNN", signature=c("BNINDEX", "BNPARAM"),
function(X, query, k, ..., BNINDEX, BNPARAM)
standardGeneric("queryKNN")
)
setGeneric("queryKNN", signature=c("X", "BNPARAM"), function(X, query, k, ..., BNPARAM) standardGeneric("queryKNN"))

#' @export
#' @rdname findNeighbors-methods
setGeneric("findNeighbors", signature=c("BNINDEX", "BNPARAM"),
function(X, threshold, ..., BNINDEX, BNPARAM)
standardGeneric("findNeighbors")
)
setGeneric("findNeighbors", signature=c("X", "BNPARAM"), function(X, threshold, ..., BNPARAM) standardGeneric("findNeighbors"))

#' @export
#' @rdname queryNeighbors-methods
setGeneric("queryNeighbors", signature=c("BNINDEX", "BNPARAM"),
function(X, query, threshold, ..., BNINDEX, BNPARAM)
standardGeneric("queryNeighbors")
)

#' @export
setGeneric("bnorder", function(x) standardGeneric("bnorder"))

#' @export
setGeneric("bndata", function(x) standardGeneric("bndata"))
setGeneric("queryNeighbors", signature=c("X", "BNPARAM"), function(X, query, threshold, ..., BNINDEX, BNPARAM) standardGeneric("queryNeighbors"))

#' @export
setGeneric("bndistance", function(x) standardGeneric("bndistance"))

# Generic purely for internal use, to help in defining other S4 methods.
setGeneric("spill_args", function(x) standardGeneric("spill_args"))
56 changes: 6 additions & 50 deletions R/AnnoyParam-class.R → R/annoy.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@
#' Aaron Lun
#'
#' @seealso
#' \code{\link{buildAnnoy}}, for the index construction.
#'
#' \code{\link{findAnnoy}} and related functions, for the actual search.
#'
#' \linkS4class{BiocNeighborParam}, for the parent class and its available methods.
#'
#' @examples
Expand All @@ -37,18 +33,14 @@
#' @aliases
#' AnnoyParam-class
#' show,AnnoyParam-method
#' AnnoyParam_ntrees
#' AnnoyParam_directory
#' AnnoyParam_search_mult
#' [[,AnnoyParam-method
#' [[<-,AnnoyParam-method
#' buildIndex,AnnoyParam-method
#'
#' @docType class
#'
#' @export
#' @importFrom methods new
AnnoyParam <- function(ntrees=50, directory=tempdir(), search.mult=ntrees, distance="Euclidean") {
new("AnnoyParam", ntrees=as.integer(ntrees), dir=directory, distance=distance, search.mult=search.mult)
AnnoyParam <- function(ntrees=50, search.mult=ntrees, distance="Euclidean") {
new("AnnoyParam", ntrees=as.integer(ntrees), distance=distance, search.mult=search.mult)
}

#' @importFrom S4Vectors setValidity2
Expand All @@ -60,10 +52,6 @@ setValidity2("AnnoyParam", function(object) {
msg <- c(msg, "'ntrees' should be a positive integer scalar")
}

if (length(object[['dir']])!=1L) {
msg <- c(msg, "'dir' should be a string")
}

search.mult <- object[['search.mult']]
if (length(search.mult)!=1L || is.na(search.mult) || search.mult <= 1) {
msg <- c(msg, "'search.mult' should be a numeric scalar greater than 1")
Expand All @@ -73,46 +61,14 @@ setValidity2("AnnoyParam", function(object) {
return(TRUE)
})

#' @export
AnnoyParam_ntrees <- function(x) {
.Deprecated(new="x[['ntrees']]")
x@ntrees
}

#' @export
AnnoyParam_directory <- function(x) {
.Deprecated(new="x[['directory']]")
x@dir
}

#' @export
AnnoyParam_search_mult <- function(x) {
.Deprecated(new="x[['search.mult']]")
x@search.mult
}

#' @export
setMethod("show", "AnnoyParam", function(object) {
callNextMethod()
cat(sprintf("ntrees: %i\n", object[['ntrees']]))
cat(sprintf("directory: %s\n", object[['dir']]))
cat(sprintf("search multiplier: %i\n", object[['search.mult']]))
})

setMethod("spill_args", "AnnoyParam", function(x) {
list(ntrees=x[['ntrees']], directory=x[['dir']],
search.mult=x[['search.mult']], distance=bndistance(x))
})

#' @export
setMethod("[[", "AnnoyParam", function(x, i, j, ...) {
if (i=="directory") i <- "dir"
callNextMethod()
cat(sprintf("search.mult: %i\n", object[['search.mult']]))
})

#' @export
setReplaceMethod("[[", "AnnoyParam", function(x, i, j, ..., value) {
if (i=="directory") i <- "dir"
callNextMethod()
setMethod("buildIndex", "AnnoyParam", function(X, ..., BNPARAM) {
build_annoy(X, num_trees=BNPARAM@ntrees, distance=BNPARAM@distance)
})

81 changes: 23 additions & 58 deletions R/findKNN-functions.R → R/findKNN.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,66 +96,31 @@
#' head(out4$index)
#' head(out4$distance)
#'
#' @name findKNN-functions
#' @name findKNN
NULL

#' @export
#' @rdname findKNN-functions
#' @importFrom BiocParallel SerialParam
findAnnoy <- function(X, k, get.index=TRUE, get.distance=TRUE, last=k,
BPPARAM=SerialParam(), precomputed=NULL, subset=NULL, raw.index=NA, warn.ties=NA, ...)
{
.template_find_knn(X, k, get.index=get.index, get.distance=get.distance,
last=last, BPPARAM=BPPARAM, precomputed=precomputed, subset=subset,
exact=FALSE, warn.ties=FALSE, raw.index=FALSE,
buildFUN=buildAnnoy, searchFUN=find_annoy, searchArgsFUN=.find_annoy_args, ...)
}
#' @rdname findKNN
setMethod("findKNN", c("externalptr", "missing"), function(X, k, get.index=TRUE, get.distance=TRUE, num.threads=1, subset=NULL, ..., BNPARAM) {
if (is.null(subset)) {
output <- generic_find_knn(X, k=k, num_threads=num.threads, report_index=get.index, report_distance=get.distance)
} else {
output <- generic_find_knn_subset(X, k=k, chosen=subset, num_threads=num.threads, report_index=get.index, report_distance=get.distance)
}
if (!report.index) {
output$index <- NULL
}
if (!report.distance) {
output$distance <- NULL
}
})

#' @export
#' @rdname findKNN-functions
#' @importFrom BiocParallel SerialParam
findHnsw <- function(X, k, get.index=TRUE, get.distance=TRUE, last=k,
BPPARAM=SerialParam(), precomputed=NULL, subset=NULL, raw.index=NA, warn.ties=NA, ...)
{
.template_find_knn(X, k, get.index=get.index, get.distance=get.distance,
last=last, BPPARAM=BPPARAM, precomputed=precomputed, subset=subset,
exact=FALSE, warn.ties=FALSE, raw.index=FALSE,
buildFUN=buildHnsw, searchFUN=find_hnsw, searchArgsFUN=.find_hnsw_args, ...)
}

#' @export
#' @rdname findKNN-functions
#' @importFrom BiocParallel SerialParam
findKmknn <- function(X, k, get.index=TRUE, get.distance=TRUE, last=k,
BPPARAM=SerialParam(), precomputed=NULL, subset=NULL, raw.index=FALSE, warn.ties=TRUE, ...)
{
.template_find_knn(X, k, get.index=get.index, get.distance=get.distance,
last=last, BPPARAM=BPPARAM, precomputed=precomputed, subset=subset,
exact=TRUE, warn.ties=warn.ties, raw.index=raw.index,
buildFUN=buildKmknn, searchFUN=find_kmknn, searchArgsFUN=.find_kmknn_args, ...)
}

#' @export
#' @rdname findKNN-functions
#' @importFrom BiocParallel SerialParam
findVptree <- function(X, k, get.index=TRUE, get.distance=TRUE, last=k,
BPPARAM=SerialParam(), precomputed=NULL, subset=NULL, raw.index=FALSE, warn.ties=TRUE, ...)
{
.template_find_knn(X, k, get.index=get.index, get.distance=get.distance,
last=last, BPPARAM=BPPARAM, precomputed=precomputed, subset=subset,
exact=TRUE, warn.ties=warn.ties, raw.index=raw.index,
buildFUN=buildVptree, searchFUN=find_vptree, searchArgsFUN=.find_vptree_args, ...)
}

#' @export
#' @rdname findKNN-functions
#' @importFrom BiocParallel SerialParam
findExhaustive <- function(X, k, get.index=TRUE, get.distance=TRUE, last=k,
BPPARAM=SerialParam(), precomputed=NULL, subset=NULL, raw.index=FALSE, warn.ties=TRUE, ...)
{
.template_find_knn(X, k, get.index=get.index, get.distance=get.distance,
last=last, BPPARAM=BPPARAM, precomputed=precomputed, subset=subset,
exact=TRUE, warn.ties=warn.ties, raw.index=raw.index,
buildFUN=buildExhaustive, searchFUN=find_exhaustive, searchArgsFUN=.find_exhaustive_args, ...)
}

#' @rdname findKNN
setMethod("findKNN", c("ANY", "BiocNeighborIndex"), function(X, k, num.threads=1, BPPARAM=NULL, ..., BNPARAM) {
ptr <- buildIndex(X, BNPARAM)
if (!is.null(BPPARAM)) {
num.threads <- BiocParallel::bpnworkers(BPPARAM)
}
findKNN(ptr, k=k, ...)
})
11 changes: 11 additions & 0 deletions src/annoy.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "generics.h"
#include "l2norm.h"

// Turn off manual vectorization always, to avoid small inconsistencies in
// distance calculations across otherwise-compliant machines.
Expand All @@ -19,12 +20,22 @@ SEXP build_annoy(Rcpp::NumericMatrix data, int num_trees, double search_mult, st
opt.num_trees = num_trees;
opt.search_mult = search_mult;

BiocNeighborsPrebuiltPointer output(new BiocNeighborsPrebuilt);

if (distance == "Manhattan") {
knncolle_annoy::AnnoyBuilder<Annoy::Manhattan, WrappedMatrix, double> builder(opt);
return generic_build(builder, data);

} else if (distance == "Euclidean") {
knncolle_annoy::AnnoyBuilder<Annoy::Euclidean, WrappedMatrix, double> builder(opt);
return generic_build(builder, data);

} else if (distance == "Cosine") {
knncolle_annoy::AnnoyBuilder<Annoy::Euclidean, WrappedMatrix, double> builder(opt);
auto out = generic_build(builder, l2norm(data));
out->cosine = true;
return out;

} else {
throw std::runtime_error("unknown distance type '" + distance + "'");
return R_NilValue;
Expand Down
9 changes: 9 additions & 0 deletions src/exhaustive.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
#include "generics.h"
#include "l2norm.h"
#include "knncolle/knncolle.hpp"

//[[Rcpp::export(rng=false)]]
SEXP build_exhaustive(Rcpp::NumericMatrix data, std::string distance) {
if (distance == "Manhattan") {
knncolle::BruteforceBuilder<knncolle::ManhattanDistance, WrappedMatrix, double> builder;
return generic_build(builder, data);

} else if (distance == "Euclidean") {
knncolle::BruteforceBuilder<knncolle::EuclideanDistance, WrappedMatrix, double> builder;
return generic_build(builder, data);

} else if (distance == "Cosine") {
knncolle::BruteforceBuilder<knncolle::EuclideanDistance, WrappedMatrix, double> builder;
auto out = generic_build(builder, l2norm(data));
out->cosine = true;
return out;

} else {
throw std::runtime_error("unknown distance type '" + distance + "'");
return R_NilValue;
Expand Down
Loading

0 comments on commit f7b2125

Please sign in to comment.