add Hierarchical Clustering & some docstring fixes #9

Merged
merged 12 commits into from
Sep 6, 2022
127 changes: 110 additions & 17 deletions src/MLJClusteringInterface.jl
@@ -38,7 +38,7 @@ function MMI.fit(model::KMeans, verbosity::Int, X)
# NOTE: using transpose here to get a LinearAlgebra.Transpose object
# which Kmeans can handle.
Xarray = transpose(MMI.matrix(X))
result = Cl.kmeans(Xarray, model.k; distance=model.metric, init = model.init)
result = Cl.kmeans(Xarray, model.k; distance=model.metric)
cluster_labels = MMI.categorical(1:model.k)
fitresult = (result.centers, cluster_labels) # centers (p x k)
cache = nothing
@@ -177,27 +177,36 @@ MMI.reporting_operations(::Type{<:DBSCAN}) = (:predict,)
branchorder::Symbol = :r :: (_ ∈ (:r, :barjoseph, :optimal))
h::Union{Nothing,Float64} = nothing
k::Int = 3
_cache = nothing
end
"""
struct DendrogramCutter{T}
dendrogram::T
end

function compute_dendrogram(X, metric, linkage, branchorder)
# @info "computing dendrogram"
Xarray = MMI.matrix(X)
d = pairwise(metric, Xarray, dims = 1) # n x n
Cl.hclust(d, linkage = linkage, branchorder = branchorder)
Callable object to cut a dendrogram.
"""
struct DendrogramCutter{T}
dendrogram::T
end
"""
(cutter::DendrogramCutter)(; h = nothing, k = 3)

Cuts the dendrogram at height `h` or, if `h == nothing`, such that `k` clusters are obtained.
"""
function (cutter::DendrogramCutter)(; h = nothing, k = 3)
MMI.categorical(Cl.cutree(cutter.dendrogram, k = k, h = h))
end
function Base.show(io::IO, ::DendrogramCutter)
println(io, "Dendrogram Cutter.")
end

function MMI.predict(model::HierarchicalClustering, ::Nothing, X)
Xhash = hash(X)
if model._cache === nothing || model._cache.Xhash != Xhash
dendrogram = compute_dendrogram(X,
model.metric,
model.linkage,
model.branchorder)
model._cache = (dendrogram = dendrogram, Xhash = Xhash)
end
yhat = MMI.categorical(Cl.cutree(model._cache.dendrogram, k = model.k, h = model.h))
return yhat, model._cache
Xarray = MMI.matrix(X)
d = pairwise(model.metric, Xarray, dims = 1) # n x n
dendrogram = Cl.hclust(d, linkage = model.linkage, branchorder = model.branchorder)
cutter = DendrogramCutter(dendrogram)
yhat = cutter(h = model.h, k = model.k)
return yhat, (; cutter, dendrogram)
end

MMI.reporting_operations(::Type{<:HierarchicalClustering}) = (:predict,)
@@ -238,6 +247,13 @@ metadata_model(
path = "$(PKG).DBSCAN"
)

metadata_model(
HierarchicalClustering,
human_name = "Hierarchical clusterer",
input = MMI.Table(Continuous),
output = MMI.Table(Continuous),
path = "$(PKG).HierarchicalClustering"
)

"""
$(MMI.doc_header(KMeans))
@@ -510,4 +526,81 @@ scatter(points, color=colors)
"""
DBSCAN

"""
$(MMI.doc_header(HierarchicalClustering))

[Hierarchical Clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) is a
clustering algorithm that organizes the data in a dendrogram based on distances between
groups of points and computes cluster assignments by cutting the dendrogram at a given
height. More information is available at the [Clustering.jl
documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to
get cluster assignments. The dendrogram and the dendrogram cutter are accessed from the
machine report (see below).

This is a static implementation, i.e., it does not generalize to new data instances, and
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
[`KMedoids`](@ref).

In MLJ or MLJBase, create a machine with

mach = machine(model)

# Hyper-parameters

- `linkage = :single`: linkage method (:single, :average, :complete, :ward, :ward_presquared)

- `metric = SqEuclidean`: metric (see `Distances.jl` for available metrics)

- `branchorder = :r`: branchorder (:r, :barjoseph, :optimal)

- `h = nothing`: height at which the dendrogram is cut

- `k = 3`: number of clusters; this is ignored, if `h != nothing`.
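
As a rough sketch of what `h` and `k` control, the snippet below calls Clustering.jl and
Distances.jl directly, mirroring the steps this interface performs in `predict`. The toy
matrix `Xtoy` and the cut height `1.0` are made up for illustration and are not part of
the package.

```julia
using Clustering, Distances

# Hypothetical toy data: 5 observations (rows) with 2 features (columns).
Xtoy = [0.0 0.0; 0.1 0.0; 5.0 5.0; 5.1 5.0; 10.0 0.0]

# n x n matrix of pairwise distances between rows.
d = Distances.pairwise(SqEuclidean(), Xtoy, dims = 1)

# Build the dendrogram once; it can then be cut in different ways.
tree = hclust(d, linkage = :single, branchorder = :r)

by_k = cutree(tree, k = 3)    # ask for a fixed number of clusters
by_h = cutree(tree, h = 1.0)  # cut at a fixed height instead
```

Both calls return a vector of integer cluster labels, one per row of `Xtoy`.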


# Operations

- `predict(mach, X)`: return cluster label assignments, as an unordered
`CategoricalVector`. Here `X` is any table of input features (e.g., a `DataFrame`) whose
columns are of scitype `Continuous`; check column scitypes with `schema(X)`.


# Report

After calling `predict(mach, X)`, the fields of `report(mach)` are:

- `dendrogram`: the dendrogram that was computed when calling `predict`.

- `cutter`: a dendrogram cutter that can be called with a height `h` or a number of clusters `k`, to obtain a new assignment of the data points to clusters.

# Examples

```
using MLJ

X, labels = make_moons(400, noise=0.09, rng=1) # synthetic data with 2 clusters
y = map(labels) do label
label == 0 ? "cookie" : "monster"
end;
y = coerce(y, Multiclass);

HierarchicalClustering = @load HierarchicalClustering pkg=Clustering
model = HierarchicalClustering(linkage = :complete)
mach = machine(model)

# compute and output cluster assignments for observations in `X`:
yhat = predict(mach, X)

# plot dendrogram:
using StatsPlots
plot(report(mach).dendrogram)

# make new predictions by cutting the dendrogram at another height
report(mach).cutter(h = 2.5)
```
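
The cutter can also be asked for a number of clusters rather than a height. The following
is a minimal sketch under the same assumptions as the example above (MLJ and the
Clustering interface are installed); the choice `k = 2` is arbitrary.

```julia
using MLJ

X, _ = make_moons(400, noise=0.09, rng=1)

HierarchicalClustering = @load HierarchicalClustering pkg=Clustering
mach = machine(HierarchicalClustering(linkage = :complete))

# `predict` computes the dendrogram and stores it, with the cutter, in the report:
yhat = predict(mach, X)

# re-cut the stored dendrogram into two clusters without recomputing distances:
yhat2 = report(mach).cutter(k = 2)
n_clusters = length(unique(yhat2))
```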

"""
HierarchicalClustering

end # module