Skip to content

Commit

Permalink
Merge pull request #9 from jbrea/master
Browse files Browse the repository at this point in the history
add Hierarchical Clustering & some docstring fixes
  • Loading branch information
ablaom authored Sep 6, 2022
2 parents 6f60a35 + 9f08382 commit a9e19bf
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 3 deletions.
126 changes: 124 additions & 2 deletions src/MLJClusteringInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using Distances

# ===================================================================
## EXPORTS
export KMeans, KMedoids, DBSCAN
export KMeans, KMedoids, DBSCAN, HierarchicalClustering

# ===================================================================
## CONSTANTS
Expand All @@ -31,6 +31,7 @@ const PKG = "MLJClusteringInterface"
@mlj_model mutable struct KMeans <: MMI.Unsupervised
k::Int = 3::(_ ≥ 2)
metric::SemiMetric = SqEuclidean()
init = :kmpp
end

function MMI.fit(model::KMeans, verbosity::Int, X)
Expand Down Expand Up @@ -169,10 +170,51 @@ end
MMI.reporting_operations(::Type{<:DBSCAN}) = (:predict,)


# # HierarchicalClustering
@mlj_model mutable struct HierarchicalClustering <: MMI.Static
linkage::Symbol = :single :: (_ ∈ (:single, :average, :complete, :ward, :ward_presquared))
metric::SemiMetric = SqEuclidean()
branchorder::Symbol = :r :: (_ ∈ (:r, :barjoseph, :optimal))
h::Union{Nothing,Float64} = nothing
k::Int = 3
end
"""
struct DendrogramCutter{T}
dendrogram::T
end
Callable object to cut a dendrogram.
"""
struct DendrogramCutter{T}
dendrogram::T
end
"""
(cutter::DendrogramCutter)(; h = nothing, k = 3)
Cuts the dendrogram at height `h` or, if `height == nothing`, such that `k` clusters are obtained.
"""
function (cutter::DendrogramCutter)(; h = nothing, k = 3)
MMI.categorical(Cl.cutree(cutter.dendrogram, k = k, h = h))
end
function Base.show(io::IO, ::DendrogramCutter)
print(io, "Dendrogram Cutter.")
end

function MMI.predict(model::HierarchicalClustering, ::Nothing, X)
Xarray = MMI.matrix(X)
d = pairwise(model.metric, Xarray, dims = 1) # n x n
dendrogram = Cl.hclust(d, linkage = model.linkage, branchorder = model.branchorder)
cutter = DendrogramCutter(dendrogram)
yhat = cutter(h = model.h, k = model.k)
return yhat, (; cutter, dendrogram)
end

MMI.reporting_operations(::Type{<:HierarchicalClustering}) = (:predict,)

# # METADATA

metadata_pkg.(
(KMeans, KMedoids, DBSCAN),
(KMeans, KMedoids, DBSCAN, HierarchicalClustering),
name="Clustering",
uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5",
url="https://github.com/JuliaStats/Clustering.jl",
Expand Down Expand Up @@ -205,6 +247,12 @@ metadata_model(
path = "$(PKG).DBSCAN"
)

metadata_model(
HierarchicalClustering,
human_name = "hierarchical clusterer",
input = MMI.Table(Continuous),
path = "$(PKG).HierarchicalClustering"
)

"""
$(MMI.doc_header(KMeans))
Expand Down Expand Up @@ -477,4 +525,78 @@ scatter(points, color=colors)
"""
DBSCAN

"""
$(MMI.doc_header(HierarchicalClustering))
[Hierarchical Clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) is a
clustering algorithm that organizes the data in a dendrogram based on distances between
groups of points and computes cluster assignments by cutting the dendrogram at a given
height. More information is available at the [Clustering.jl
documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to
get cluster assignments. The dendrogram and the dendrogram cutter are accessed from the
machine report (see below).
This is a static implementation, i.e., it does not generalize to new data instances, and
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
[`KMedoids`](@ref).
In MLJ or MLJBase, create a machine with
mach = machine(model)
# Hyper-parameters
- `linkage = :single`: linkage method (:single, :average, :complete, :ward, :ward_presquared)
- `metric = SqEuclidean`: metric (see `Distances.jl` for available metrics)
- `branchorder = :r`: branchorder (:r, :barjoseph, :optimal)
- `h = nothing`: height at which the dendrogram is cut
- `k = 3`: number of clusters.
If both `k` and `h` are specified, it is guaranteed that the number of clusters is not less than `k` and their height is not above `h`.
# Operations
- `predict(mach, X)`: return cluster label assignments, as an unordered
`CategoricalVector`. Here `X` is any table of input features (eg, a `DataFrame`) whose
columns are of scitype `Continuous`; check column scitypes with `schema(X)`.
# Report
After calling `predict(mach)`, the fields of `report(mach)` are:
- `dendrogram`: the dendrogram that was computed when calling `predict`.
- `cutter`: a dendrogram cutter that can be called with a height `h` or a number of clusters `k`, to obtain a new assignment of the data points to clusters (see example below).
# Examples
```
using MLJ
X, labels = make_moons(400, noise=0.09, rng=1) # synthetic data with 2 clusters; X
HierarchicalClustering = @load HierarchicalClustering pkg=Clustering
model = HierarchicalClustering(linkage = :complete)
mach = machine(model)
# compute and output cluster assignments for observations in `X`:
yhat = predict(mach, X)
# plot dendrogram:
using StatsPlots
plot(report(mach).dendrogram)
# make new predictions by cutting the dendrogram at another height
report(mach).cutter(h = 2.5)
```
"""
HierarchicalClustering

end # module
19 changes: 18 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,25 @@ end

end

# # HierarchicalClustering

@testset "HierarchicalClustering" begin
h = Inf; k = 1; linkage = :complete; bo = :optimal;
metric = Distances.Euclidean()
mach = machine(HierarchicalClustering(h = h, k = k, metric = metric,
linkage = linkage, branchorder = bo))
yhat = predict(mach, X)
@test length(union(yhat)) == 1 # uses h = Inf
cutter = report(mach).cutter
@test length(union(cutter(k = 4))) == 4 # uses k = 4
dendro = Clustering.hclust(Distances.pairwise(metric, hcat(X...), dims = 1),
linkage = linkage, branchorder = bo)
@test cutter(k = 2) == Clustering.cutree(dendro, k = 2)
@test report(mach).dendrogram.heights == dendro.heights
end

@testset "MLJ interface" begin
models = [KMeans, KMedoids, DBSCAN]
models = [KMeans, KMedoids, DBSCAN, HierarchicalClustering]
failures, summary = MLJTestIntegration.test(
models,
X;
Expand Down

0 comments on commit a9e19bf

Please sign in to comment.