-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Annabelle Gillet
committed
Apr 27, 2021
1 parent
95b1dca
commit 97b61a3
Showing
22 changed files
with
112,979 additions
and
1,514 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ build | |
.settings | ||
.worksheet | ||
|
||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package tdm.core | ||
|
||
import org.apache.spark.sql.Row | ||
import shapeless.{HList, HMap} | ||
import tdm._ | ||
|
||
/**
 * Locally-collected view of a sparse tensor: the non-zero elements of a
 * tensor materialized on the driver as an array of Spark [[Row]]s. Each Row
 * holds the tensor value (column [[Tensor.TENSOR_VALUES_COLUMN_NAME]]) plus
 * one column per dimension.
 *
 * @param typeList the tensor's dimensions
 * @param dimensions map associating each dimension with its internal data
 * @param data the collected non-zero elements
 */
class CollectedTensor[T, DL <: HList] private[core]
        (val typeList: List[TensorDimension[_]], val dimensions: HMap[DimensionMap], val data: Array[Row])
        (implicit tensorTypeAuthorized: Numeric[T]) {

    /**
     * Number of non-zero elements of this tensor.
     */
    val size: Int = data.length

    /**
     * Produce the range of all index of the non-zero elements of this tensor.
     *
     * @return a [[Range]] value from 0 to one less than the number of non-zero elements of this tensor.
     */
    def indices: Range = data.indices

    /**
     * Get the value of the tensor at the index i.
     *
     * @param i index of the element, in [0, size)
     * @return the tensor value of the i-th collected element
     */
    def apply(i: Int): T = valueOf(data(i))

    /**
     * Get the value of the given dimension at the index i.
     *
     * @param dimension the dimension whose coordinate is requested; must belong to this tensor's schema
     * @param i index of the element, in [0, size)
     * @return the coordinate of the i-th element on the given dimension
     */
    def apply[CT, D <: TensorDimension[_]](dimension: D, i: Int)
            (implicit eq: dimension.DimensionType =:= CT,
             contains: ContainsConstraint[DL, D]): CT = {
        val row = data(i)
        row.getAs[CT](row.fieldIndex(dimension.name))
    }

    /**
     * Sort the values by ascending order.
     *
     * @return the sorted [[CollectedTensor]]
     */
    def orderByValues(): CollectedTensor[T, DL] = {
        new CollectedTensor[T, DL](typeList, dimensions,
            data.sortWith((r1, r2) => tensorTypeAuthorized.lteq(valueOf(r1), valueOf(r2))))
    }

    /**
     * Sort the values by descending order.
     *
     * @return the sorted [[CollectedTensor]]
     */
    def orderByValuesDesc(): CollectedTensor[T, DL] = {
        new CollectedTensor[T, DL](typeList, dimensions,
            data.sortWith((r1, r2) => tensorTypeAuthorized.gteq(valueOf(r1), valueOf(r2))))
    }

    /**
     * Extract the tensor value from a collected row.
     *
     * Replaces the former two-var row cache: caching an O(1) array lookup
     * bought nothing, introduced mutable state (not thread-safe), and the
     * old `apply(i)` read the field index through `cachedRow.get` instead
     * of the row it had just fetched.
     */
    private def valueOf(row: Row): T =
        row.getAs[T](row.fieldIndex(Tensor.TENSOR_VALUES_COLUMN_NAME))
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package tdm.core | ||
|
||
import org.apache.spark.sql.{DataFrame, SparkSession} | ||
import shapeless.{::, HList, HMap, HNil} | ||
import tdm.DimensionMap | ||
|
||
/**
 * Result of a CP/Kruskal decomposition: per-component weights (lambdas),
 * one factor matrix per dimension name, and an optional CORCONDIA
 * diagnostic score.
 */
class KruskalTensor[DL <: HList] private[core] (val typeList: List[TensorDimension[_]],
        val lambdas: Map[Int, Double],
        val factorMatrices: Map[String, DataFrame],
        val corcondia: Option[Double])
        (implicit spark: SparkSession) {

    /**
     * Extract a specific dimension result from this [[KruskalTensor]] as a 2-order tensor,
     * with dimensions [[D]] and [[Rank]].
     *
     * @param dimension: the dimension to extract
     * @return a 2-order tensor with schema [[D]] :: [[Rank]]
     */
    def extract[D <: TensorDimension[_]](dimension: D): Tensor[Double, D :: Rank.type :: HNil] = {
        // Dimensions of the resulting 2-order tensor: the requested one plus the rank.
        val extractedDimensions = HMap.empty[DimensionMap] +
            (dimension -> dimension.produceDimension()) +
            (Rank -> Rank.produceDimension())

        // Wrap the corresponding factor matrix in a non-empty tensor.
        val extracted = new Tensor[Double, D :: Rank.type :: HNil](typeList, extractedDimensions)
        extracted.values = factorMatrices(dimension.name)
        extracted.empty = false
        extracted
    }

    def reconstruct(): Tensor[Double, DL] = ???
}
|
||
/** Integer-valued dimension indexing the components of a [[KruskalTensor]] factor matrix. */
object Rank extends TensorDimension[Int]
Oops, something went wrong.