Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix case class default parameter detection and annotation detection in Scala3 #88

Merged
merged 5 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ trait MainArgsPublishModule
)

def ivyDeps = Agg(
ivy"org.scala-lang.modules::scala-collection-compat::2.8.1"
ivy"org.scala-lang.modules::scala-collection-compat::2.8.1",
ivy"com.lihaoyi::pprint:0.8.1"
)
}

Expand Down
41 changes: 35 additions & 6 deletions mainargs/src-3/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,37 @@ object Macros {
val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head
companionModuleType match
case '[bCompanion] =>
val mainData = createMainData[B, Any](annotatedMethod, mainAnnotationInstance)
val mainData = createMainData[B, Any](
annotatedMethod,
mainAnnotationInstance,
// Somehow the `apply` method parameter annotations don't end up on
// the `apply` method parameters, but end up in the `<init>` method
// parameters, so use those for getting the annotations instead
TypeRepr.of[B].typeSymbol.primaryConstructor.paramSymss
)
'{ new ParserForClass[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) }
}

def createMainData[T: Type, B: Type](using Quotes)(method: quotes.reflect.Symbol, annotation: quotes.reflect.Term): Expr[MainData[T, B]] = {
def createMainData[T: Type, B: Type](using Quotes)
(method: quotes.reflect.Symbol,
mainAnnotation: quotes.reflect.Term): Expr[MainData[T, B]] = {
createMainData[T, B](method, mainAnnotation, method.paramSymss)
}

def createMainData[T: Type, B: Type](using Quotes)
(method: quotes.reflect.Symbol,
mainAnnotation: quotes.reflect.Term,
annotatedParamsLists: List[List[quotes.reflect.Symbol]]): Expr[MainData[T, B]] = {

import quotes.reflect.*
val params = method.paramSymss.headOption.getOrElse(report.throwError("Multiple parameter lists not supported"))
val defaultParams = getDefaultParams(method)
val argSigs = Expr.ofList(params.map { param =>
val argSigsExprs = params.zip(annotatedParamsLists.flatten).map { paramAndAnnotParam =>
val param = paramAndAnnotParam._1
val annotParam = paramAndAnnotParam._2
val paramTree = param.tree.asInstanceOf[ValDef]
val paramTpe = paramTree.tpt.tpe
val arg = param.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
val arg = annotParam.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
val paramType = paramTpe.asType
paramType match
case '[t] =>
Expand All @@ -66,13 +85,14 @@ object Macros {
)
}
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ tokensReader })) }
})
}
val argSigs = Expr.ofList(argSigsExprs)

val invokeRaw: Expr[(B, Seq[Any]) => T] = {
def callOf(args: Expr[Seq[Any]]) = call(method, '{ Seq( ${ args }) }).asExprOf[T]
'{ ((b: B, params: Seq[Any]) => ${ callOf('{ params }) }) }
}
'{ MainData.create[T, B](${ Expr(method.name) }, ${ annotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
'{ MainData.create[T, B](${ Expr(method.name) }, ${ mainAnnotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
}

/** Call a method given by its symbol.
Expand Down Expand Up @@ -134,12 +154,21 @@ object Macros {
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any]]

val Name = (method.name + """\$default\$(\d+)""").r
val InitName = """\$lessinit\$greater\$default\$(\d+)""".r

val idents = method.owner.tree.asInstanceOf[ClassDef].body

idents.foreach{
case deff @ DefDef(Name(idx), _, _, _) =>
val expr = Ref(deff.symbol).asExpr
defaults += (params(idx.toInt - 1) -> expr)

// The `apply` method re-uses the default param factory methods from `<init>`,
// so make sure to check if those exist too
case deff @ DefDef(InitName(idx), _, _, _) if method.name == "apply" =>
val expr = Ref(deff.symbol).asExpr
defaults += (params(idx.toInt - 1) -> expr)

case _ =>
}

Expand Down
185 changes: 86 additions & 99 deletions mainargs/test/src/ClassTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,121 +58,108 @@ object ClassTests extends TestSuite {
Bar(Flag(true), Foo(1, 2), "xxx")
}
test("missingInner") {
// Blocked by https://github.com/lampepfl/dotty/issues/12492
TestUtils.scala2Only {
barParser.constructRaw(Seq("-w", "-x", "1", "-z", "xxx")) ==>
Result.Failure.MismatchedArguments(
Seq(
ArgSig(
None,
Some('y'),
None,
None,
mainargs.TokensReader.IntRead,
positional = false,
hidden = false
)
),
List(),
List(),
None
)
}
barParser.constructRaw(Seq("-w", "-x", "1", "-z", "xxx")) ==>
Result.Failure.MismatchedArguments(
Seq(
ArgSig(
None,
Some('y'),
None,
None,
mainargs.TokensReader.IntRead,
positional = false,
hidden = false
)
),
List(),
List(),
None
)
}
test("missingOuter") {
// Blocked by https://github.com/lampepfl/dotty/issues/12492
TestUtils.scala2Only {
barParser.constructRaw(Seq("-w", "-x", "1", "-y", "2")) ==>
Result.Failure.MismatchedArguments(
Seq(
ArgSig(
Some("zzzz"),
Some('z'),
None,
None,
mainargs.TokensReader.StringRead,
positional = false,
hidden = false
)
),
List(),
List(),
None
)
}
barParser.constructRaw(Seq("-w", "-x", "1", "-y", "2")) ==>
Result.Failure.MismatchedArguments(
Seq(
ArgSig(
Some("zzzz"),
Some('z'),
None,
None,
mainargs.TokensReader.StringRead,
positional = false,
hidden = false
)
),
List(),
List(),
None
)
}

test("missingInnerOuter") {
// Blocked by https://github.com/lampepfl/dotty/issues/12492
TestUtils.scala2Only {
barParser.constructRaw(Seq("-w", "-x", "1")) ==>
Result.Failure.MismatchedArguments(
Seq(
ArgSig(
None,
Some('y'),
None,
None,
mainargs.TokensReader.IntRead,
positional = false,
hidden = false
),
ArgSig(
Some("zzzz"),
Some('z'),
None,
None,
mainargs.TokensReader.StringRead,
positional = false,
hidden = false
)
barParser.constructRaw(Seq("-w", "-x", "1")) ==>
Result.Failure.MismatchedArguments(
Seq(
ArgSig(
None,
Some('y'),
None,
None,
mainargs.TokensReader.IntRead,
positional = false,
hidden = false
),
List(),
List(),
None
)
}
ArgSig(
Some("zzzz"),
Some('z'),
None,
None,
mainargs.TokensReader.StringRead,
positional = false,
hidden = false
)
),
List(),
List(),
None
)
}

test("failedInnerOuter") {
TestUtils.scala2Only {
assertMatch(
barParser.constructRaw(
Seq("-w", "-x", "xxx", "-y", "hohoho", "-z", "xxx")
)
) {
case Result.Failure.InvalidArguments(
Seq(
Result.ParamError.Failed(
ArgSig(None, Some('x'), None, None, _, false, _),
Seq("xxx"),
_
),
Result.ParamError.Failed(
ArgSig(None, Some('y'), None, None, _, false, _),
Seq("hohoho"),
_
)
assertMatch(
barParser.constructRaw(
Seq("-w", "-x", "xxx", "-y", "hohoho", "-z", "xxx")
)
) {
case Result.Failure.InvalidArguments(
Seq(
Result.ParamError.Failed(
ArgSig(None, Some('x'), None, None, _, false, _),
Seq("xxx"),
_
),
Result.ParamError.Failed(
ArgSig(None, Some('y'), None, None, _, false, _),
Seq("hohoho"),
_
)
) =>
}
)
) =>

}
}
}

test("doubleNested") {
TestUtils.scala2Only {
quxParser.constructOrThrow(
Seq("-w", "-x", "1", "-y", "2", "-z", "xxx", "--moo", "cow")
) ==>
Qux("cow", Bar(Flag(true), Foo(1, 2), "xxx"))
}
quxParser.constructOrThrow(
Seq("-w", "-x", "1", "-y", "2", "-z", "xxx", "--moo", "cow")
) ==>
Qux("cow", Bar(Flag(true), Foo(1, 2), "xxx"))
}
test("success") {
TestUtils.scala2Only {
ParserForMethods(Main).runOrThrow(
Seq("-x", "1", "-y", "2", "-z", "hello")
) ==> "false 1 2 hello false"
}
ParserForMethods(Main).runOrThrow(
Seq("-x", "1", "-y", "2", "-z", "hello")
) ==> "false 1 2 hello false"
}
}
}
60 changes: 60 additions & 0 deletions mainargs/test/src/ClassWithDefaultTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package mainargs
import utest._

// Make sure
object ClassWithDefaultTests extends TestSuite {
@main
case class Foo(x: Int, y: Int = 1)

implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]

object Main {
@main
def run(foo: Foo, bool: Boolean = false) = s"${foo.x} ${foo.y} $bool"
}

val mainParser = ParserForMethods(Main)

val tests = Tests {
test("simple") {
test("success") {
fooParser.constructOrThrow(Seq("-x", "1", "-y", "2")) ==> Foo(1, 2)
}
test("default") {
fooParser.constructOrThrow(Seq("-x", "0")) ==> Foo(0, 1)
}
test("missing") {
fooParser.constructRaw(Seq()) ==>
Result.Failure.MismatchedArguments(
Seq(
ArgSig(
None,
Some('x'),
None,
None,
mainargs.TokensReader.IntRead,
positional = false,
hidden = false
)
),
List(),
List(),
None
)

}
}

test("nested") {
test("success"){
mainParser.runOrThrow(Seq("-x", "1", "-y", "2", "--bool", "true")) ==> "1 2 true"
}
test("default"){
mainParser.runOrThrow(Seq("-x", "1", "-y", "2")) ==> "1 2 false"
}
test("default2"){
mainParser.runOrThrow(Seq("-x", "0")) ==> "0 1 false"
}
}
}
}
7 changes: 2 additions & 5 deletions mainargs/test/src/ParserTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,8 @@ object ParserTests extends TestSuite {
) ==> Right("xxxxx")
}
test("constructEither") {
TestUtils.scala2Only {
// default values in classes not working on Scala 3
classParser.constructEither(Array("--code", "println(1)")) ==>
Right(ClassBase(code = Some("println(1)"), other = "hello"))
}
classParser.constructEither(Array("--code", "println(1)")) ==>
Right(ClassBase(code = Some("println(1)"), other = "hello"))
}
}
}
5 changes: 5 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,11 @@ command-line friendly tool.

# Changelog

## master

- Fix handling of case class main method parameter default parameters in Scala 3
[#88](https://github.com/com-lihaoyi/mainargs/pull/88)

## 0.5.0

- Remove hard-code support for mainargs.Leftover/Flag/Subparser to support
Expand Down