Skip to content

Commit

Permalink
scrooge: Allow null values for nullable fields in StructBuilder
Browse files Browse the repository at this point in the history
Problem
-------
`c.t.scrooge.StructBuilder` methods `setField` and `setAllFields` don't accept
`null` as a valid value for nullable fields, making it impossible to use
`StructBuilder` to construct otherwise valid instances that have `null` as a value
of one or more such fields.

Solution
--------
Enhance `StructBuilder` so it allows setting `null` values for nullable fields.

JIRA Issues: CSL-12312

Differential Revision: https://phabricator.twitter.biz/D976737
  • Loading branch information
szegedi authored and jenkins committed Sep 22, 2022
1 parent 5024fe9 commit e35e184
Show file tree
Hide file tree
Showing 15 changed files with 197 additions and 93 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Runtime Behavior Changes
* scrooge: Use the util-mock for testing which depends on the 3.12.4 "org.mockito", remove the
"org.jmock" dependencies. ``PHAB_ID=D973946``

* scrooge: `c.t.scrooge.StructBuilder` now accepts `null` values for non-primitive default-required fields. ``PHAB_ID=D976737``

22.7.0
------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package com.twitter.scrooge

import com.twitter.util.Memoize
import org.apache.thrift.protocol.TType
import scala.reflect.ClassTag

/**
* A class that provides an interface for building a new ThriftStruct from an existing
* ThriftStruct or statically from T.
*
* We pass in a list of [[ClassTag]]s which describe each of the struct's field types
* so that we can validate the values we are setting at runtime.
* so that we can validate the values we are setting at runtime. We also pass in a set
* of indices of non-primitive default-required fields; those can have null as valid value.
*/
abstract class StructBuilder[T <: ThriftStruct](fieldTypes: IndexedSeq[ClassTag[_]]) {
abstract class StructBuilder[T <: ThriftStruct](
fieldTypes: IndexedSeq[ClassTag[_]],
nullableIndices: Set[Int]) {
protected val fieldArray: Array[Any] = new Array[Any](fieldTypes.size)

/**
Expand All @@ -23,6 +27,7 @@ abstract class StructBuilder[T <: ThriftStruct](fieldTypes: IndexedSeq[ClassTag[
private[this] def addOrUpdateFieldArray[A](index: Int, v: Any)(implicit tag: ClassTag[A]): Unit =
v match {
case inputValue: A => fieldArray(index) = inputValue
case null if nullableIndices(index) => fieldArray(index) = StructBuilder.ExplicitNull
case _ => throw new IllegalArgumentException(s"value at index $index must be of type $tag")
}

Expand Down Expand Up @@ -84,12 +89,44 @@ abstract class StructBuilder[T <: ThriftStruct](fieldTypes: IndexedSeq[ClassTag[
* This object provides operations to obtain `StructBuilder` instances.
*/
object StructBuilder {
// Used as an explicitly set null value for default-required fields
private[StructBuilder] object ExplicitNull

def unwrapExplicitNull(v: Any): Any =
v match {
case ExplicitNull => null
case _ => v
}

private[this] val memoizeBuilderMethod: Class[_] => () => Any = Memoize.classValue { clazz =>
val thriftCodec = ThriftStructCodec.forStructClass(clazz.asSubclass(classOf[ThriftStruct]))
val m = thriftCodec.getClass.getMethod("newBuilder")
() => m.invoke(thriftCodec)
}

/**
* Given a sequence of struct field information, returns a set of indices of non-primitive
* default-required fields.
* @param fieldInfos a sequence of field informations
* @return a set of indices for non-primitive default-required fields
*/
final def nullableIndices(fieldInfos: Seq[ThriftStructFieldInfo]): Set[Int] = {
fieldInfos.zipWithIndex.collect {
case (f, i) if isNullable(f) => i
}.toSet
}

private[this] def isNullable(f: ThriftStructFieldInfo): Boolean = {
if (f.isOptional || f.isRequired) {
false
} else {
f.tfield.`type` match {
case TType.BOOL | TType.BYTE | TType.I16 | TType.I32 | TType.I64 | TType.DOUBLE => false
case _ => true
}
}
}

/**
* For a given scrooge-generated thrift struct or union class, returns its StructBuilder.
* This can be used for building a new ThriftStruct object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ object AnotherException extends ValidatingThriftStructCodec3[AnotherException] w
)


lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos)

val structAnnotations: immutable$Map[String, String] =
immutable$Map.empty[String, String]

Expand Down Expand Up @@ -111,7 +113,7 @@ object AnotherException extends ValidatingThriftStructCodec3[AnotherException] w
)
}

def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.None, fieldTypes)
def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.None, fieldTypes, nullableIndices)

override def encode(_item: AnotherException, _oproto: TProtocol): Unit = {
_item.write(_oproto)
Expand Down Expand Up @@ -348,11 +350,11 @@ class AnotherException(
flags
)

def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.Some(this), fieldTypes)
def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices)
}

private[thriftscala] class AnotherExceptionStructBuilder(instance: _root_.scala.Option[AnotherException], fieldTypes: IndexedSeq[ClassTag[_]])
extends StructBuilder[AnotherException](fieldTypes) {
private[thriftscala] class AnotherExceptionStructBuilder(instance: _root_.scala.Option[AnotherException], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int])
extends StructBuilder[AnotherException](fieldTypes, nullableIndices) {

def build(): AnotherException = {
val _fieldArray = fieldArray // shadow variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ object CollectionId extends ValidatingThriftStructCodec3[CollectionId] with Stru
)


lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos)

val structAnnotations: immutable$Map[String, String] =
immutable$Map.empty[String, String]

Expand Down Expand Up @@ -114,7 +116,7 @@ object CollectionId extends ValidatingThriftStructCodec3[CollectionId] with Stru
)
}

def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.None, fieldTypes)
def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.None, fieldTypes, nullableIndices)

override def encode(_item: CollectionId, _oproto: TProtocol): Unit = {
_item.write(_oproto)
Expand Down Expand Up @@ -444,11 +446,11 @@ trait CollectionId

def _codec: ValidatingThriftStructCodec3[CollectionId] = CollectionId

def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.Some(this), fieldTypes)
def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices)
}

private[thriftscala] class CollectionIdStructBuilder(instance: _root_.scala.Option[CollectionId], fieldTypes: IndexedSeq[ClassTag[_]])
extends StructBuilder[CollectionId](fieldTypes) {
private[thriftscala] class CollectionIdStructBuilder(instance: _root_.scala.Option[CollectionId], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int])
extends StructBuilder[CollectionId](fieldTypes, nullableIndices) {

def build(): CollectionId = {
val _fieldArray = fieldArray // shadow variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)


lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos)

val structAnnotations: immutable$Map[String, String] =
immutable$Map.empty[String, String]

Expand Down Expand Up @@ -425,7 +427,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)
}

def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes)
def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes, nullableIndices)

override def encode(_item: Args, _oproto: TProtocol): Unit = {
_item.write(_oproto)
Expand Down Expand Up @@ -589,27 +591,27 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ

def _codec: ValidatingThriftStructCodec3[Args] = Args

def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes)
def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices)
}

private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]])
extends StructBuilder[Args](fieldTypes) {
private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int])
extends StructBuilder[Args](fieldTypes, nullableIndices) {

def build(): Args = {
val _fieldArray = fieldArray // shadow variable
if (instance.isDefined) {
val instanceValue = instance.get
Args(
if (_fieldArray(0) == null) instanceValue.request else _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request],
if (_fieldArray(1) == null) instanceValue.unionRequest else _fieldArray(1).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion],
if (_fieldArray(2) == null) instanceValue.exceptionRequest else _fieldArray(2).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException]
if (_fieldArray(0) == null) instanceValue.request else StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request],
if (_fieldArray(1) == null) instanceValue.unionRequest else StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion],
if (_fieldArray(2) == null) instanceValue.exceptionRequest else StructBuilder.unwrapExplicitNull(_fieldArray(2)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException]
)
} else {
if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("Args"))
Args(
_fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request],
_fieldArray(1).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion],
_fieldArray(2).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException]
StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request],
StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion],
StructBuilder.unwrapExplicitNull(_fieldArray(2)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException]
)
}
}
Expand Down Expand Up @@ -655,6 +657,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)


lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos)

val structAnnotations: immutable$Map[String, String] =
immutable$Map.empty[String, String]

Expand Down Expand Up @@ -739,7 +743,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)
}

def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes)
def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes, nullableIndices)

override def encode(_item: Result, _oproto: TProtocol): Unit = {
_item.write(_oproto)
Expand Down Expand Up @@ -886,11 +890,11 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ

def _codec: ValidatingThriftStructCodec3[Result] = Result

def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes)
def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices)
}

private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]])
extends StructBuilder[Result](fieldTypes) {
private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int])
extends StructBuilder[Result](fieldTypes, nullableIndices) {

def build(): Result = {
val _fieldArray = fieldArray // shadow variable
Expand Down Expand Up @@ -971,6 +975,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)


lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos)

val structAnnotations: immutable$Map[String, String] =
immutable$Map.empty[String, String]

Expand Down Expand Up @@ -1035,7 +1041,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)
}

def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes)
def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes, nullableIndices)

override def encode(_item: Args, _oproto: TProtocol): Unit = {
_item.write(_oproto)
Expand Down Expand Up @@ -1156,23 +1162,23 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ

def _codec: ValidatingThriftStructCodec3[Args] = Args

def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes)
def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices)
}

private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]])
extends StructBuilder[Args](fieldTypes) {
private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int])
extends StructBuilder[Args](fieldTypes, nullableIndices) {

def build(): Args = {
val _fieldArray = fieldArray // shadow variable
if (instance.isDefined) {
val instanceValue = instance.get
Args(
if (_fieldArray(0) == null) instanceValue.request else _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request]
if (_fieldArray(0) == null) instanceValue.request else StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request]
)
} else {
if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("Args"))
Args(
_fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request]
StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request]
)
}
}
Expand Down Expand Up @@ -1204,6 +1210,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)


lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos)

val structAnnotations: immutable$Map[String, String] =
immutable$Map.empty[String, String]

Expand Down Expand Up @@ -1270,7 +1278,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ
)
}

def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes)
def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes, nullableIndices)

override def encode(_item: Result, _oproto: TProtocol): Unit = {
_item.write(_oproto)
Expand Down Expand Up @@ -1393,11 +1401,11 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ

def _codec: ValidatingThriftStructCodec3[Result] = Result

def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes)
def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices)
}

private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]])
extends StructBuilder[Result](fieldTypes) {
private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int])
extends StructBuilder[Result](fieldTypes, nullableIndices) {

def build(): Result = {
val _fieldArray = fieldArray // shadow variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ object OverCapacityException extends ValidatingThriftStructCodec3[OverCapacityEx
)


lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos)

lazy val structAnnotations: immutable$Map[String, String] =
immutable$Map[String, String](
("e.annotation", "true")
Expand Down Expand Up @@ -115,7 +117,7 @@ object OverCapacityException extends ValidatingThriftStructCodec3[OverCapacityEx
)
}

def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.None, fieldTypes)
def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.None, fieldTypes, nullableIndices)

override def encode(_item: OverCapacityException, _oproto: TProtocol): Unit = {
_item.write(_oproto)
Expand Down Expand Up @@ -352,11 +354,11 @@ class OverCapacityException(
flags
)

def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.Some(this), fieldTypes)
def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices)
}

private[thriftscala] class OverCapacityExceptionStructBuilder(instance: _root_.scala.Option[OverCapacityException], fieldTypes: IndexedSeq[ClassTag[_]])
extends StructBuilder[OverCapacityException](fieldTypes) {
private[thriftscala] class OverCapacityExceptionStructBuilder(instance: _root_.scala.Option[OverCapacityException], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int])
extends StructBuilder[OverCapacityException](fieldTypes, nullableIndices) {

def build(): OverCapacityException = {
val _fieldArray = fieldArray // shadow variable
Expand Down
Loading

0 comments on commit e35e184

Please sign in to comment.