Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Hierarchical Clustering & some docstring fixes #9

Merged
merged 12 commits into from
Sep 6, 2022
128 changes: 127 additions & 1 deletion src/MLJClusteringInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using Distances

# ===================================================================
## EXPORTS
export KMeans, KMedoids, DBSCAN
export KMeans, KMedoids, DBSCAN, HierarchicalClustering

# ===================================================================
## CONSTANTS
Expand All @@ -31,6 +31,7 @@ const PKG = "MLJClusteringInterface"
@mlj_model mutable struct KMeans <: MMI.Unsupervised
k::Int = 3::(_ ≥ 2)
metric::SemiMetric = SqEuclidean()
init = :kmpp
end

function MMI.fit(model::KMeans, verbosity::Int, X)
Expand Down Expand Up @@ -169,6 +170,47 @@ end
MMI.reporting_operations(::Type{<:DBSCAN}) = (:predict,)


# # HierarchicalClustering
@mlj_model mutable struct HierarchicalClustering <: MMI.Static
linkage::Symbol = :single :: (_ ∈ (:single, :average, :complete, :ward, :ward_presquared))
metric::SemiMetric = SqEuclidean()
branchorder::Symbol = :r :: (_ ∈ (:r, :barjoseph, :optimal))
h::Union{Nothing,Float64} = nothing
k::Int = 3
end
"""
struct DendrogramCutter{T}
dendrogram::T
end

Callable object to cut a dendrogram.
"""
struct DendrogramCutter{T}
dendrogram::T
end
"""
(cutter::DendrogramCutter)(; h = nothing, k = 3)

Cuts the dendrogram at height `h` or, if `height == nothing`, such that `k` clusters are obtained.
"""
function (cutter::DendrogramCutter)(; h = nothing, k = 3)
MMI.categorical(Cl.cutree(cutter.dendrogram, k = k, h = h))
end
function Base.show(io::IO, ::DendrogramCutter)
println(io, "Dendrogram Cutter.")
jbrea marked this conversation as resolved.
Show resolved Hide resolved
end

function MMI.predict(model::HierarchicalClustering, ::Nothing, X)
Xarray = MMI.matrix(X)
d = pairwise(model.metric, Xarray, dims = 1) # n x n
dendrogram = Cl.hclust(d, linkage = model.linkage, branchorder = model.branchorder)
cutter = DendrogramCutter(dendrogram)
yhat = cutter(h = model.h, k = model.k)
return yhat, (; cutter, dendrogram)
end

MMI.reporting_operations(::Type{<:HierarchicalClustering}) = (:predict,)

# # METADATA

metadata_pkg.(
Expand Down Expand Up @@ -205,6 +247,13 @@ metadata_model(
path = "$(PKG).DBSCAN"
)

metadata_model(
HierarchicalClustering,
human_name = "Hierarchical clusterer",
input = MMI.Table(Continuous),
output = MMI.Table(Continuous),
path = "$(PKG).HierarchicalClustering"
)
ablaom marked this conversation as resolved.
Show resolved Hide resolved

"""
$(MMI.doc_header(KMeans))
Expand Down Expand Up @@ -477,4 +526,81 @@ scatter(points, color=colors)
"""
DBSCAN

"""
$(MMI.doc_header(HierarchicalClustering))

[Hierarchical Clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) is a
clustering algorithm that organizes the data in a dendrogram based on distances between
groups of points and computes cluster assignments by cutting the dendrogram at a given
height. More information is available at the [Clustering.jl
documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to
get cluster assignments. The dendrogram and the dendrogram cutter are accessed from the
machine report (see below).

This is a static implementation, i.e., it does not generalize to new data instances, and
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
[`KMedoids`](@ref).

In MLJ or MLJBase, create a machine with

mach = machine(model)

# Hyper-parameters

- `linkage = :single`: linkage method (:single, :average, :complete, :ward, :ward_presquared)

- `metrid = SqEuclidean`: metric (see `Distances.jl` for available metrics)

- `branchorder = :r`: branchorder (:r, :barjoseph, :optimal)

- `h = nothing`: height at which the dendrogram is cut

- `k = 3`: number of clusters; this is ignored, if `h != nothing`.


# Operations

- `predict(mach, X)`: return cluster label assignments, as an unordered
`CategoricalVector`. Here `X` is any table of input features (eg, a `DataFrame`) whose
columns are of scitype `Continuous`; check column scitypes with `schema(X)`. Note that
points of type `noise` will always get a label of `0`.
jbrea marked this conversation as resolved.
Show resolved Hide resolved


# Report

After calling `predict(mach)`, the fields of `report(mach)` are:

- `dendrogram`: the dendrogram that was computed when calling `predict`.

- `cutter`: a dendrogram cutter that can be called with a height `h` or a number of clusters `k`, to obtain a new assignment of the data points to clusters.
jbrea marked this conversation as resolved.
Show resolved Hide resolved

# Examples

```
using MLJ

X, labels = make_moons(400, noise=0.09, rng=1) # synthetic data with 2 clusters; X
y = map(labels) do label
label == 0 ? "cookie" : "monster"
end;
y = coerce(y, Multiclass);
jbrea marked this conversation as resolved.
Show resolved Hide resolved

HierarchicalClustering = @load HierarchicalClustering pkg=Clustering
model = HierarchicalClustering(linkage = :complete)
mach = machine(model)

# compute and output cluster assignments for observations in `X`:
yhat = predict(mach, X)

# plot dendrogram:
using StatsPlots
plot(report(mach).dendrogram)

# make new predictions by cutting the dendrogram at another height
report(mach).cutter(h = 2.5)
```

"""
HierarchicalClustering

ablaom marked this conversation as resolved.
Show resolved Hide resolved
end # module