Skip to content

Commit

Permalink
Extract tensorflow module
Browse files Browse the repository at this point in the history
  • Loading branch information
juliabeliaeva committed Aug 2, 2022
1 parent a41a62f commit 24f9462
Show file tree
Hide file tree
Showing 292 changed files with 36 additions and 20 deletions.
10 changes: 0 additions & 10 deletions api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,10 @@ project.setDescription("This module contains the Kotlin API for building, traini
dependencies {
api project(":dataset")
implementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.6.21'
api group: 'org.tensorflow', name: 'tensorflow', version: '1.15.0'
api 'com.github.doyaaaaaken:kotlin-csv-jvm:0.7.3' // for csv parsing
api 'io.github.microutils:kotlin-logging:2.1.21' // for logging
api 'io.jhdf:jhdf:0.5.7' // for hdf5 parsing
api 'com.beust:klaxon:5.5'
testImplementation 'ch.qos.logback:logback-classic:1.2.11'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
}

compileKotlin {
Expand All @@ -27,10 +21,6 @@ kotlin {
explicitApiWarning()
}

test {
useJUnitPlatform()
}

task fatJar(type: Jar) {
duplicatesStrategy = DuplicatesStrategy.INCLUDE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,11 @@ public class TensorShape() {
}
}

internal fun getDimsOfArray(data: kotlin.Array<*>): LongArray {
/**
* Get shape of array of arrays (of arrays...) of Array of elements of any type.
* If the most inner array does not have any elements its size is skipped in the result.
*/
public fun getDimsOfArray(data: kotlin.Array<*>): LongArray {
fun appendPrimitiveArraySize(size: Int, acc: MutableList<Long>): LongArray {
acc += size.toLong()
return acc.toLongArray()
Expand All @@ -240,7 +244,6 @@ internal fun getDimsOfArray(data: kotlin.Array<*>): LongArray {
}

/**
* Get shape of array of arrays (of arrays...) of Array of elements of any type.
* If the most inner array does not have any elements its size is missed in result
* @see getDimsOfArray
*/
internal val Array<*>.tensorShape: TensorShape get() = TensorShape(getDimsOfArray(this))
2 changes: 2 additions & 0 deletions examples/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
dependencies {
implementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.6.21'
api project(":api")
api project(":tensorflow")
api project(":dataset")
api project(":visualization")
api project(":onnx")
Expand All @@ -25,6 +26,7 @@ if (publishedArtifactsVersion != null && !publishedArtifactsVersion.isBlank()) {
configurations.all {
resolutionStrategy.dependencySubstitution {
substitute project(":api") using module("org.jetbrains.kotlinx:kotlin-deeplearning-api:$publishedArtifactsVersion") because "testing published artifacts"
substitute project(":tensorflow") using module("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:$publishedArtifactsVersion") because "testing published artifacts"
substitute project(":dataset") using module("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:$publishedArtifactsVersion") because "testing published artifacts"
substitute project(":visualization") using module("org.jetbrains.kotlinx:kotlin-deeplearning-visualization:$publishedArtifactsVersion") because "testing published artifacts"
substitute project(":onnx") using module("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:$publishedArtifactsVersion") because "testing published artifacts"
Expand Down
5 changes: 2 additions & 3 deletions settings.gradle
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
rootProject.name = 'KotlinDL'
include("api")
include("tensorflow")
include("visualization")
include("examples")
include("dataset")
include("onnx")


include("onnx")
25 changes: 25 additions & 0 deletions tensorflow/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
project.setDescription("This module contains the Kotlin API for building, training, and evaluating TensorFlow models.")

dependencies {
api project(":dataset")
api project(":api")
implementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.6.21'
api group: 'org.tensorflow', name: 'tensorflow', version: '1.15.0'
testImplementation 'ch.qos.logback:logback-classic:1.2.11'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
}

compileKotlin {
kotlinOptions.jvmTarget = "1.8"
}

compileTestKotlin {
kotlinOptions.jvmTarget = "1.8"
}

kotlin {
explicitApiWarning()
}
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,3 @@ public class TFModelHub(cacheDirectory: File) : ModelHub(cacheDirectory) {
return HdfFile(File(fileName))
}
}




File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions visualization/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ project.setDescription("This module contains the Kotlin API for visualization of
dependencies {
implementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.6.21'
api project(":api")
api project(":tensorflow")
}

dependencies {
Expand Down

0 comments on commit 24f9462

Please sign in to comment.