Skip to content

Commit

Permalink
WIP: multiview_pca
Browse files Browse the repository at this point in the history
  • Loading branch information
stnava committed Oct 14, 2024
1 parent 5875593 commit f964743
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 1 deletion.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ export(multiResRandomForestRegression)
export(multichannelPCA)
export(multichannelToVector)
export(multiscaleSVD)
export(multiview_pca)
export(n3BiasFieldCorrection)
export(n3BiasFieldCorrection2)
export(n4BiasFieldCorrection)
Expand Down
105 changes: 104 additions & 1 deletion R/multiscaleSVDxpts.R
Original file line number Diff line number Diff line change
Expand Up @@ -6137,4 +6137,107 @@ simlr_feature_orth <- function( p ) {
oo = 0.0
for ( k in 1:length(p)) oo = oo + invariant_orthogonality_defect( data.matrix( p[[k]] ))
return( oo / length(p))
}
}






#' Multiview PCA
#'
#' Estimate a shared low-dimensional representation across several data
#' views by alternating least squares, with optional sparsification of each
#' view's loading matrix.
#'
#' Each iteration (1) sets \code{Z} to the average over views of the view
#' projected through its current loadings, then (2) refits every loading
#' matrix by least squares of the view on \code{Z}, using a small ridge term
#' (1e-6) on the Gram matrix of \code{Z} so the linear system stays solvable.
#' Iteration stops early once the largest absolute change in \code{Z} falls
#' below 1e-6.
#'
#' @param views A non-empty list of data matrices, one per view. All views
#' must have the same number of rows (samples); column counts may differ.
#' @param n_components Number of principal components to compute.
#' @param sparse Numeric values between zero and one, either a single value
#' (applied to every view) or one value per view. Loadings of view \code{i}
#' are passed through \code{orthogonalizeAndQSparsify} only when
#' \code{sparse[i]} lies strictly between 0 and 1; values of exactly 0 or 1
#' leave the least-squares loadings untouched.
#' @param max_iter Maximum number of iterations for the optimization.
#' @param sparsenessAlg Sparseness algorithm name forwarded unchanged to
#' \code{orthogonalizeAndQSparsify}: basic (the default), spmp or orthorank.
#' @param verbose Logical, whether to print the reconstruction energy at
#' every iteration.
#' @return A list with elements \code{Z}, the common representation
#' (samples by components), and \code{W}, a list holding one loading matrix
#' (features by components) per view.
#' @examples
#' set.seed(123)
#' n_samples <- 100
#' n_features_1 <- 50
#' n_features_2 <- 60
#' n_features_3 <- 70
#' n_components <- 5
#' view1 <- matrix(rnorm(n_samples * n_features_1), nrow = n_samples, ncol = n_features_1)
#' view2 <- matrix(rnorm(n_samples * n_features_2), nrow = n_samples, ncol = n_features_2)
#' view3 <- matrix(rnorm(n_samples * n_features_3), nrow = n_samples, ncol = n_features_3)
#' result <- multiview_pca(list(view1, view2, view3), n_components, sparse = rep(0.5,3),
#' verbose = TRUE)
#' print(result)
#' @export
multiview_pca <- function(views, n_components, sparse = 0.5, max_iter = 100,
                          sparsenessAlg = 'basic', verbose = FALSE) {
  # Validate sparse: must be usable numeric values inside [0, 1]
  if (!is.numeric(sparse) || length(sparse) == 0L || anyNA(sparse) ||
      any(sparse < 0) || any(sparse > 1)) {
    stop("sparse must be between 0 and 1.", call. = FALSE)
  }

  if (!is.list(views) || length(views) == 0L) {
    stop("views must be a non-empty list of matrices.", call. = FALSE)
  }
  if (!is.numeric(n_components) || length(n_components) != 1L ||
      is.na(n_components) || n_components < 1) {
    stop("n_components must be a single positive number.", call. = FALSE)
  }

  # Accept data frames as well as matrices; matrices pass through unchanged
  views <- lapply(views, data.matrix)
  n_views <- length(views)
  n_samples <- nrow(views[[1]])
  if (any(vapply(views, nrow, integer(1)) != n_samples)) {
    stop("All views must have the same number of rows (samples).",
         call. = FALSE)
  }

  # A scalar sparse (including the default) applies to every view; without
  # this, sparse[i] is NA for i > 1 and the per-view test below errors.
  if (length(sparse) == 1L) {
    sparse <- rep(sparse, n_views)
  } else if (length(sparse) < n_views) {
    stop("sparse must have length 1 or one value per view.", call. = FALSE)
  }

  # Initialize Z (latent representation) with random values
  Z <- matrix(rnorm(n_samples * n_components),
              nrow = n_samples, ncol = n_components)

  # Initialize W_list: W_i has ncol(views[[i]]) rows and n_components columns
  W_list <- vector("list", n_views)
  for (i in seq_len(n_views)) {
    n_features <- ncol(views[[i]])
    W_list[[i]] <- matrix(rnorm(n_features * n_components),
                          nrow = n_features, ncol = n_components)
  }

  prev_Z <- Z
  tol <- 1e-6 # Tolerance for convergence
  # Ridge term keeping the Gram matrix of Z invertible
  reg <- diag(1e-6, n_components)

  # Iterative optimization
  for (iter in seq_len(max_iter)) {
    # Fix W, solve for Z: average of the projected views
    Z <- matrix(0, nrow = n_samples, ncol = n_components)
    for (i in seq_len(n_views)) {
      # The sparsifier is opaque here, so re-check that it preserved shape
      if (ncol(views[[i]]) != nrow(W_list[[i]])) {
        stop("Non-conformable dimensions between views and W_list matrices.")
      }
      Z <- Z + views[[i]] %*% W_list[[i]]
    }
    Z <- Z / n_views

    # Fix Z, solve for W. The regularized Gram matrix depends only on Z, so
    # it is formed once per iteration and shared by all views.
    gram <- crossprod(Z) + reg
    for (i in seq_len(n_views)) {
      # Least squares: solve gram %*% t(W_i) = Z' X_i without forming an
      # explicit inverse, then transpose to features x components
      W_list[[i]] <- t(solve(gram, crossprod(Z, views[[i]])))
      if (sparse[i] > 0 && sparse[i] < 1) {
        # NOTE(review): the sparsifier is applied to the whole matrix once per
        # component and the loop index is unused; kept as-is because it is not
        # clear from here that a single application is equivalent -- confirm.
        for (jj in seq_len(ncol(W_list[[i]]))) {
          W_list[[i]] <- orthogonalizeAndQSparsify(W_list[[i]], sparse[i],
            'positive', sparsenessAlg = sparsenessAlg)
        }
      }
    }

    # Reconstruction energy is only reported, so skip the work unless asked
    if (verbose) {
      energy <- 0
      for (i in seq_len(n_views)) {
        reconstruction <- Z %*% t(W_list[[i]])
        energy <- energy + sum((views[[i]] - reconstruction)^2)
      }
      cat(sprintf("Iteration %d: Energy = %.6f\n", iter, energy))
    }

    # Check for convergence
    if (max(abs(prev_Z - Z)) < tol) {
      message("Converged in ", iter, " iterations.")
      break
    }
    prev_Z <- Z
  }

  return(list(Z = Z, W = W_list))
}
48 changes: 48 additions & 0 deletions man/multiview_pca.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f964743

Please sign in to comment.