Skip to content

Commit

Permalink
Remote caching support
Browse files Browse the repository at this point in the history
This implements Bazel-compat remote caching support via a plugin.
  • Loading branch information
eed3si9n committed Apr 5, 2024
1 parent 35520e6 commit e5b3e09
Show file tree
Hide file tree
Showing 33 changed files with 1,140 additions and 242 deletions.
31 changes: 30 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,26 @@ lazy val utilRelation = (project in file("internal") / "util-relation")
// Persisted caching based on sjson-new
lazy val utilCache = project
.in(file("util-cache"))
.enablePlugins(
ContrabandPlugin,
// we generate JsonCodec only for actionresult.conta
// JsonCodecPlugin,
)
.dependsOn(utilLogging)
.settings(
testedBaseSettings,
name := "Util Cache",
libraryDependencies ++=
Seq(sjsonNewScalaJson.value, sjsonNewMurmurhash.value, scalaReflect.value),
Seq(
sjsonNewCore.value,
sjsonNewScalaJson.value,
sjsonNewMurmurhash.value,
scalaReflect.value
),
Compile / managedSourceDirectories +=
baseDirectory.value / "src" / "main" / "contraband-scala",
Compile / generateContrabands / sourceManaged := baseDirectory.value / "src" / "main" / "contraband-scala",
Compile / generateContrabands / contrabandFormatsForType := ContrabandConfig.getFormats,
utilMimaSettings,
Test / fork := true,
)
Expand Down Expand Up @@ -645,6 +660,19 @@ lazy val dependencyTreeProj = (project in file("dependency-tree"))
mimaPreviousArtifacts := Set.empty,
)

lazy val remoteCacheProj = (project in file("sbt-remote-cache"))
.dependsOn(sbtProj)
.settings(
sbtPlugin := true,
baseSettings,
name := "sbt-remote-cache",
pluginCrossBuild / sbtVersion := version.value,
publishMavenStyle := true,
// mimaSettings,
mimaPreviousArtifacts := Set.empty,
libraryDependencies += remoteapis,
)

// Implementation and support code for defining actions.
lazy val actionsProj = (project in file("main-actions"))
.dependsOn(
Expand Down Expand Up @@ -1266,6 +1294,7 @@ def allProjects =
utilTracking,
collectionProj,
coreMacrosProj,
remoteCacheProj,
) ++ lowerUtilProjects

// These need to be cross published to 2.12 and 2.13 for Zinc
Expand Down
2 changes: 2 additions & 0 deletions main-settings/src/main/scala/sbt/Def.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ object Def extends Init[Scope] with TaskMacroExtra with InitializeImplicits:
private[sbt] var _cacheStore: ActionCacheStore = InMemoryActionCacheStore()
def cacheStore: ActionCacheStore = _cacheStore
private[sbt] var _outputDirectory: Option[Path] = None
private[sbt] val cacheEventLog: CacheEventLog = CacheEventLog()
def cacheConfiguration: BuildWideCacheConfiguration =
BuildWideCacheConfiguration(
_cacheStore,
_outputDirectory.getOrElse(sys.error("outputDirectory has not been set")),
cacheEventLog,
)

inline def cachedTask[A1: JsonFormat](inline a1: A1): Def.Initialize[Task[A1]] =
Expand Down
10 changes: 10 additions & 0 deletions main/src/main/scala/sbt/Defaults.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4647,6 +4647,16 @@ trait BuildExtra extends BuildCommon with DefExtra {
scalaBinaryVersion.value
)

/**
* Adds Maven resolver plugin.
*/
def addRemoteCachePlugin: Setting[Seq[ModuleID]] =
libraryDependencies += sbtPluginExtra(
ModuleID("org.scala-sbt", "sbt-remote-cache", sbtVersion.value),
sbtBinaryVersion.value,
scalaBinaryVersion.value
)

/**
* Adds `dependency` as an sbt plugin for the specific sbt version `sbtVersion` and Scala version `scalaVersion`.
* Typically, use the default values for these versions instead of specifying them explicitly.
Expand Down
12 changes: 11 additions & 1 deletion main/src/main/scala/sbt/Keys.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ package sbt

import java.nio.file.{ Path => NioPath }
import java.io.File
import java.net.URL
import java.net.{ URL, URI }
import lmcoursier.definitions.{ CacheLogger, ModuleMatchers, Reconciliation }
import lmcoursier.{ CoursierConfiguration, FallbackDependency }
import org.apache.ivy.core.module.descriptor.ModuleDescriptor
Expand Down Expand Up @@ -116,6 +116,16 @@ object Keys {
val fullServerHandlers = SettingKey(BasicKeys.fullServerHandlers)
val serverHandlers = settingKey[Seq[ServerHandler]]("User-defined server handlers.")
val cacheStores = settingKey[Seq[ActionCacheStore]]("Cache backends")
@cacheLevel(include = Array.empty)
val remoteCache = settingKey[Option[URI]]("URI of the remote cache")
@cacheLevel(include = Array.empty)
val remoteCacheTlsCertificate = settingKey[Option[File]]("Path to a TLS certificate (*.crt) that is trusted to sign server certificates")
@cacheLevel(include = Array.empty)
val remoteCacheTlsClientCertificate = settingKey[Option[File]]("Path to a TLS client certificate *.crt used with remoteCacheTlsClientKey ")
@cacheLevel(include = Array.empty)
val remoteCacheTlsClientKey = settingKey[Option[File]]("Path to a TLS client key *.pem used with remoteCacheTlsClientCertificate")
@cacheLevel(include = Array.empty)
val remoteCacheHeaders = settingKey[Seq[String]]("List of key=value headers to be sent to the remote cache.")
val rootOutputDirectory = SettingKey(BasicKeys.rootOutputDirectory)

// val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting)
Expand Down
16 changes: 9 additions & 7 deletions main/src/main/scala/sbt/RemoteCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,10 @@ object RemoteCache {
private[sbt] def getCachedAnalysis(ref: HashedVirtualFileRef): CompileAnalysis =
analysisStore.getOrElseUpdate(
ref, {
val vfs = cacheStore.getBlobs(ref :: Nil)
if vfs.nonEmpty then
val outputDirectory = Def.cacheConfiguration.outputDirectory
cacheStore.syncBlobs(vfs, outputDirectory).headOption match
case Some(file) => FileAnalysisStore.binary(file.toFile()).get.get.getAnalysis
case None => Analysis.empty
else Analysis.empty
val outputDirectory = Def.cacheConfiguration.outputDirectory
cacheStore.syncBlobs(Seq(ref), outputDirectory).headOption match
case Some(file) => FileAnalysisStore.binary(file.toFile()).get.get.getAnalysis
case None => Analysis.empty
}
)

Expand Down Expand Up @@ -157,6 +154,11 @@ object RemoteCache {
DiskActionCacheStore(localCacheDirectory.value.toPath())
)
},
remoteCache := SysProp.remoteCache,
remoteCacheTlsCertificate := SysProp.remoteCacheTlsCertificate,
remoteCacheTlsClientCertificate := SysProp.remoteCacheTlsClientCertificate,
remoteCacheTlsClientKey := SysProp.remoteCacheTlsClientKey,
remoteCacheHeaders := SysProp.remoteCacheHeaders,
)

lazy val projectSettings: Seq[Def.Setting[_]] = (Seq(
Expand Down
80 changes: 39 additions & 41 deletions main/src/main/scala/sbt/internal/Aggregation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ object Aggregation {
success: Boolean
)

final case class Complete[T](
final case class Complete[A](
start: Long,
stop: Long,
results: sbt.Result[Seq[KeyValue[T]]],
results: sbt.Result[Seq[KeyValue[A]]],
cacheSummary: String,
state: State
)

Expand Down Expand Up @@ -68,44 +69,43 @@ object Aggregation {
)(implicit display: Show[ScopedKey[_]]): Parser[() => State] =
Command.applyEffect(seqParser(ps))(ts => runTasks(s, ts, DummyTaskMap(Nil), show))

private def showRun[T](complete: Complete[T], show: ShowConfig)(implicit
display: Show[ScopedKey[_]]
): Unit = {
import complete._
private def showRun[A](complete: Complete[A], show: ShowConfig)(implicit
display: Show[ScopedKey[?]]
): Unit =
import complete.*
val log = state.log
val extracted = Project.extract(state)
val success = results match
case Result.Value(_) => true
case Result.Inc(_) => false
results.toEither.foreach { r =>
if (show.taskValues) printSettings(r, show.print)
if show.taskValues then printSettings(r, show.print) else ()
}
if (show.success && !state.get(suppressShow).getOrElse(false))
printSuccess(start, stop, extracted, success, log)
}
if show.success && !state.get(suppressShow).getOrElse(false) then
printSuccess(start, stop, extracted, success, cacheSummary, log)
else ()

def timedRun[T](
def timedRun[A](
s: State,
ts: Values[Task[T]],
extra: DummyTaskMap
): Complete[T] = {
ts: Values[Task[A]],
extra: DummyTaskMap,
): Complete[A] =
import EvaluateTask._
import std.TaskExtra._

val extracted = Project extract s
val extracted = Project.extract(s)
import extracted.structure
val toRun = ts.map { case KeyValue(k, t) => t.map(v => KeyValue(k, v)) }.join
val roots = ts.map { case KeyValue(k, _) => k }
val config = extractedTaskConfig(extracted, structure, s)

val start = System.currentTimeMillis
val (newS, result) = withStreams(structure, s) { str =>
val cacheEventLog = Def.cacheConfiguration.cacheEventLog
cacheEventLog.clear()
val (newS, result) = withStreams(structure, s): str =>
val transform = nodeView(s, str, roots, extra)
runTask(toRun, s, str, structure.index.triggers, config)(using transform)
}
val stop = System.currentTimeMillis
Complete(start, stop, result, newS)
}
val cacheSummary = cacheEventLog.summary
Complete(start, stop, result, cacheSummary, newS)

def runTasks[A1](
s: State,
Expand All @@ -124,20 +124,22 @@ object Aggregation {
stop: Long,
extracted: Extracted,
success: Boolean,
log: Logger
): Unit = {
import extracted._
cacheSummary: String,
log: Logger,
): Unit =
import extracted.*
def get(key: SettingKey[Boolean]): Boolean =
(currentRef / key).get(structure.data) getOrElse true

if (get(showSuccess)) {
if (get(showTiming)) {
val msg = timingString(start, stop, structure.data, currentRef)
if (success) log.success(msg) else if (Terminal.get.isSuccessEnabled) log.error(msg)
} else if (success)
log.success("")
}
}
if get(showSuccess) then
if get(showTiming) then
val msg = timingString(start, stop, structure.data, currentRef) + (
if cacheSummary == "" then ""
else ", " + cacheSummary
)
if success then log.success(msg)
else if Terminal.get.isSuccessEnabled then log.error(msg)
else if success then log.success("")
else ()

private def timingString(
startTime: Long,
Expand All @@ -149,23 +151,19 @@ object Aggregation {
timing(format, startTime, endTime)
}

def timing(format: java.text.DateFormat, startTime: Long, endTime: Long): String = {
val nowString = format.format(new java.util.Date(endTime))
def timing(format: java.text.DateFormat, startTime: Long, endTime: Long): String =
val total = (endTime - startTime + 500) / 1000
val totalString = s"$total s" +
(if (total <= 60) ""
(if total <= 60 then ""
else {
val maybeHours = total / 3600 match {
val maybeHours = total / 3600 match
case 0 => ""
case h => f"$h%02d:"
}
val mins = f"${total % 3600 / 60}%02d"
val secs = f"${total % 60}%02d"
s" ($maybeHours$mins:$secs)"
})

s"Total time: $totalString, completed $nowString"
}
s"elapsed time: $totalString"

def defaultFormat: DateFormat = {
import java.text.DateFormat
Expand Down
16 changes: 16 additions & 0 deletions main/src/main/scala/sbt/internal/SysProp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package sbt
package internal

import java.io.File
import java.net.URI
import java.nio.file.{ Path, Paths }
import java.util.Locale

Expand Down Expand Up @@ -98,6 +99,21 @@ object SysProp {
def legacyTestReport: Boolean = getOrFalse("sbt.testing.legacyreport")
def semanticdb: Boolean = getOrFalse("sbt.semanticdb")
def forceServerStart: Boolean = getOrFalse("sbt.server.forcestart")
def remoteCache: Option[URI] = sys.props
.get("sbt.remote_cache")
.map(URI(_))
def remoteCacheTlsCertificate: Option[File] = sys.props
.get("sbt.remote_cache.tls_certificate")
.map(File(_))
def remoteCacheTlsClientCertificate: Option[File] = sys.props
.get("sbt.remote_cache.tls_client_certificate")
.map(File(_))
def remoteCacheTlsClientKey: Option[File] = sys.props
.get("sbt.remote_cache.tls_client_key")
.map(File(_))
def remoteCacheHeaders: List[String] = sys.props
.get("sbt.remote_cache.header")
.toList

def watchMode: String =
sys.props.get("sbt.watch.mode").getOrElse("auto")
Expand Down
6 changes: 6 additions & 0 deletions project/ContrabandConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ object ContrabandConfig {
case "scalajson.ast.unsafe.JValue" | "sjsonnew.shaded.scalajson.ast.unsafe.JValue" => { _ =>
"sbt.internal.util.codec.JValueFormats" :: Nil
}
case "xsbti.HashedVirtualFileRef" => { _ =>
"sbt.internal.util.codec.HashedVirtualFileRefFormats" :: Nil
}
case "java.nio.ByteBuffer" => { _ =>
"sbt.internal.util.codec.ByteBufferFormats" :: Nil
}
}

/** Returns the list of formats required to encode the given `TpeRef`. */
Expand Down
4 changes: 3 additions & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ object Dependencies {
private val ioVersion = nightlyVersion.getOrElse("1.8.0")
private val lmVersion =
sys.props.get("sbt.build.lm.version").orElse(nightlyVersion).getOrElse("2.0.0-alpha13")
val zincVersion = nightlyVersion.getOrElse("2.0.0-alpha12")
val zincVersion = nightlyVersion.getOrElse("2.0.0-alpha13")

private val sbtIO = "org.scala-sbt" %% "io" % ioVersion

Expand Down Expand Up @@ -106,6 +106,8 @@ object Dependencies {
val junit = "junit" % "junit" % "4.13.1"
val scalaVerify = "com.eed3si9n.verify" %% "verify" % "1.0.0"
val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1"
val remoteapis =
"com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-52317e00d8d4c37fa778c628485d220fb68a8d08"

val scalaCompiler = "org.scala-lang" %% "scala3-compiler" % scala3

Expand Down
Loading

0 comments on commit e5b3e09

Please sign in to comment.