Skip to content

Commit

Permalink
Update logic to separate tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
cupertank committed Jul 11, 2024
1 parent fc03e2b commit 8f7a653
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 88 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.kinference.primitives

import org.gradle.api.services.BuildService
import org.gradle.api.services.BuildServiceParameters

abstract class PrimitivesCache : BuildService<BuildServiceParameters.None> {
val sourceSetToResolved = mutableMapOf<String, Boolean>()
val resolvedPaths = mutableSetOf<String>()
}
Original file line number Diff line number Diff line change
@@ -1,53 +1,54 @@
package io.kinference.primitives

import org.gradle.api.Plugin
import org.gradle.api.Project
import org.gradle.api.tasks.PathSensitivity
import org.gradle.api.*
import org.jetbrains.kotlin.gradle.dsl.*
import org.jetbrains.kotlin.gradle.plugin.KotlinCompilation
import org.jetbrains.kotlin.gradle.plugin.KotlinPlatformType
import org.jetbrains.kotlin.gradle.tasks.KotlinCompileTool

class PrimitivesGradlePlugin : Plugin<Project> {
override fun apply(project: Project) {
val kotlinExt = project.extensions.findByType(KotlinProjectExtension::class.java) ?: return

val primitivesExt = project.extensions.create(extensionName, PrimitivesExtension::class.java)

val primitivesTask = project.tasks.register(primitivesTaskName, PrimitivesTask::class.java) { primitivesTask ->
primitivesTask.generationPath.set(primitivesExt.generationPath)

val primitivesCache = project.gradle.sharedServices.registerIfAbsent("${project.path}_${primitivesCacheName}", PrimitivesCache::class.java) {
it.maxParallelUsages.set(1)
}

project.afterEvaluate {
val buildDependenciesTasks = project.tasks.named("buildNeeded").get().dependsOn.filterNot { it == "build" }.filterNotNull()
for (task in buildDependenciesTasks) {
primitivesTask.get().dependsOn(task)
}
val generalPrimitivesTask = project.tasks.register(primitivesTaskName) {
it.group = "generate"
}

kotlinExt.sourceSets.all { sourceSet ->
val sourceSetName = sourceSet.name
val fullPath = primitivesExt.generationPath.dir(sourceSetName)

sourceSet.kotlin.srcDir(fullPath)

//Support for Incremental compilation
primitivesTask.get().inputs
.files(sourceSet.kotlin.asFileTree)
.withPathSensitivity(PathSensitivity.ABSOLUTE)
.normalizeLineEndings()
.skipWhenEmpty()
sourceSet.kotlin.srcDir(primitivesExt.generationPath.dir(sourceSet.name))
primitivesCache.get().sourceSetToResolved[sourceSet.name] = false
}

fun configureCompilation(compilation: KotlinCompilation<*>) {
val targetTask = compilation.compileTaskProvider
if (compilation.platformType !in setOf(KotlinPlatformType.common, KotlinPlatformType.jvm, KotlinPlatformType.js)) return

targetTask.configure {
it.dependsOn(primitivesTask)
val compileTask = compilation.compileTaskProvider.get() as KotlinCompileTool
val taskName = compileTask.name.replace("compile", "generate")

val primitivesTask = compilation.project.tasks.register(taskName, PrimitivesTask::class.java) { primitiveTask ->
primitiveTask.usesService(primitivesCache)
primitiveTask.primitivesCache.set(primitivesCache)

primitiveTask.generationPath.set(primitivesExt.generationPath)
primitiveTask.inputFiles.from(compileTask.sources)
primitiveTask.libraries.from(compileTask.libraries)
primitiveTask.compilation.set(compilation)
}

compileTask.dependsOn(primitivesTask)
generalPrimitivesTask.get().dependsOn(primitivesTask)
}

if (kotlinExt is KotlinMultiplatformExtension) {
kotlinExt.targets.all { kotlinTarget ->
kotlinTarget.compilations.all { compilation ->
kotlinExt.targets.all { target ->
target.compilations.all { compilation ->
configureCompilation(compilation)
}
}
Expand All @@ -61,7 +62,8 @@ class PrimitivesGradlePlugin : Plugin<Project> {
}

companion object {
const val primitivesTaskName = "generatePrimitives"
const val primitivesCacheName = "primitivesCache"
const val primitivesTaskName = "generateAllPrimitives"
const val extensionName = "primitives"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,94 +5,98 @@ import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.generator.PrimitiveGenerator
import io.kinference.primitives.utils.psi.isAnnotatedWith
import org.gradle.api.DefaultTask
import org.gradle.api.GradleException
import org.gradle.api.file.ConfigurableFileCollection
import org.gradle.api.file.DirectoryProperty
import org.gradle.api.tasks.OutputDirectory
import org.gradle.api.tasks.TaskAction
import org.gradle.api.provider.Property
import org.gradle.api.tasks.*
import org.gradle.work.NormalizeLineEndings
import org.jetbrains.kotlin.cli.common.config.addKotlinSourceRoot
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.cli.jvm.config.addJvmClasspathRoots
import org.jetbrains.kotlin.gradle.dsl.*
import org.jetbrains.kotlin.gradle.plugin.KotlinPlatformType
import org.jetbrains.kotlin.gradle.plugin.mpp.*
import org.jetbrains.kotlin.gradle.targets.js.ir.KotlinJsIrCompilation
import org.jetbrains.kotlin.gradle.tasks.KotlinCompileTool

import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
import org.jetbrains.kotlin.gradle.dsl.KotlinProjectExtension
import org.jetbrains.kotlin.gradle.plugin.*
import java.io.File

abstract class PrimitivesTask : DefaultTask() {
@get:OutputDirectory
@get:Internal
abstract val generationPath: DirectoryProperty

@get:Classpath
abstract val libraries: ConfigurableFileCollection

@get:InputFiles
@get:SkipWhenEmpty
@get:IgnoreEmptyDirectories
@get:NormalizeLineEndings
@get:PathSensitive(PathSensitivity.RELATIVE)
abstract val inputFiles: ConfigurableFileCollection

@get:Internal
abstract val compilation: Property<KotlinCompilation<*>>

@get:Internal
abstract val primitivesCache: Property<PrimitivesCache>

init {
group = "generate"
description = "Generates primitives from sources"
}

@TaskAction
fun generate() {
val kotlinExt = project.extensions.findByType(KotlinProjectExtension::class.java) ?: return
val (targets, isMpp) = when (kotlinExt) {
is KotlinMultiplatformExtension -> kotlinExt.targets.toList() to true
is KotlinSingleTargetExtension<*> -> listOf(kotlinExt.target) to false
else -> throw GradleException("Incorrect KGP extension type, type: ${kotlinExt::class}")
}
val compilations = targets
.flatMap { target -> target.compilations.toList() }
.filterNot { compilation ->
//It's a meta-compilation that doesn't have libraries
(compilation.target.platformType == KotlinPlatformType.common && compilation.name == "main")
}
if (compilation.get().allKotlinSourceSets.all { primitivesCache.get().sourceSetToResolved[it.name]!! }) return

val pathToSourceSet: Map<String, String> = HashMap<String, String>().apply {
for (sourceSet in kotlinExt.sourceSets) {
val sourceSetName = sourceSet.name
val isMpp = kotlinExt is KotlinMultiplatformExtension

for (file in sourceSet.kotlin.files) {
this[file.path] = sourceSetName
}
}
fun findSourceSetName(file: File): String {
return kotlinExt.sourceSets.find { file in it.kotlin }!!.name
}

val resolvedSourceSets = kotlinExt.sourceSets.associate { it.name to false }.toMutableMap()
fun findSourceSetName(path: String): String = findSourceSetName(File(path))

for (compilation in compilations) {
// Check if all sourceSets of compilation are resolved
if (compilation.allKotlinSourceSets.all { resolvedSourceSets[it.name]!! }) continue
val sourcesWithMppInfo = inputFiles
.filterNot { it.absolutePath.startsWith(generationPath.get().asFile.absolutePath) }
.map { source ->
val isMpp = findSourceSetName(source) == "commonMain"

val analyzeFun = when (compilation) {
is KotlinJvmCompilation -> Analyze::analyzeJvmSources
is KotlinJsIrCompilation -> Analyze::analyzeJsSources
is KotlinCommonCompilation -> Analyze::analyzeCommonSources
else -> throw GradleException("Unsupported compilation target: ${compilation.target}")
FileWithMpp(source, isMpp)
}

val task = compilation.compileTaskProvider.get() as KotlinCompileTool
val analyzeFun = when(compilation.get().platformType) {
KotlinPlatformType.jvm -> Analyze::analyzeJvmSources
KotlinPlatformType.js -> Analyze::analyzeJsSources
KotlinPlatformType.common -> Analyze::analyzeCommonSources
else -> error("Unsupported platform type ${compilation.get().platformType}")
}

val libs = task.libraries.files.filterNotNull()
val sources = task.sources.files.filterNot { it.path.startsWith(generationPath.get().asFile.path) }.filterNotNull()
val compilerConfig = Analyze.createCompilerConfig(isMpp)

if (sources.isEmpty()) {
compilation.allKotlinSourceSets.forEach { resolvedSourceSets[it.name] = true }
continue
}
for (source in sourcesWithMppInfo) {
compilerConfig.addKotlinSourceRoot(source.file.path, source.isMpp)
}

val compilerConfig = Analyze.createCompilerConfig(isMpp)
compilerConfig.addJvmClasspathRoots(libraries.files.filterNotNull())

for (source in sources) {
val isCommon = pathToSourceSet[source.path] == "commonMain"
compilerConfig.addKotlinSourceRoot(source.path, isCommon)
}
compilerConfig.addJvmClasspathRoots(libs)
val (result, kotlinEnv) = analyzeFun(compilerConfig)
val ktSources = kotlinEnv.getSourceFiles()

val (result, kotlinEnv) = analyzeFun(compilerConfig)
val ktSources = kotlinEnv.getSourceFiles()
val annotated = ktSources.filter { it.isAnnotatedWith<GeneratePrimitives>(result.bindingContext) }
val notGeneratedYet = annotated.filterNot { it.virtualFilePath in primitivesCache.get().resolvedPaths }

val annotated = ktSources.filter { it.isAnnotatedWith<GeneratePrimitives>(result.bindingContext) }
val notGeneratedYet = annotated.filterNot { resolvedSourceSets[pathToSourceSet[it.virtualFilePath]!!]!! }
for (ktFile in notGeneratedYet) {
val sourceSet = findSourceSetName(ktFile.virtualFilePath)
val outputDir = generationPath.dir(sourceSet).get().asFile

for (ktFile in notGeneratedYet) {
val sourceSet = pathToSourceSet[ktFile.virtualFilePath]!!
val outputDir = generationPath.dir(sourceSet).get().asFile
PrimitiveGenerator(ktFile, result.bindingContext, outputDir, MessageCollector.NONE).generate()

PrimitiveGenerator(ktFile, result.bindingContext, outputDir, MessageCollector.NONE).generate()
}

compilation.allKotlinSourceSets.forEach { resolvedSourceSets[it.name] = true }
primitivesCache.get().resolvedPaths.add(ktFile.virtualFilePath)
}

compilation.get().allKotlinSourceSets.forEach { primitivesCache.get().sourceSetToResolved[it.name] = true }
}


data class FileWithMpp(val file: File, val isMpp: Boolean)
}

0 comments on commit 8f7a653

Please sign in to comment.