-
Notifications
You must be signed in to change notification settings - Fork 24
/
DualResnetModel.scala
44 lines (37 loc) · 1.49 KB
/
DualResnetModel.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
package org.deeplearning4j.scalphagozero.models
import org.deeplearning4j.nn.conf.CacheMode
import org.deeplearning4j.nn.graph.ComputationGraph
/**
* Define and load an AlphaGo Zero dual ResNet architecture
* into DL4J.
*
* The dual residual architecture is the strongest
* of the architectures tested by DeepMind for AlphaGo
* Zero. It consists of an initial convolution layer block,
* followed by a number (40 for the strongest, 20 as
* baseline) of residual blocks. The network is topped
* off by two "heads", one to predict policies and one
* for value functions.
*
* @author Max Pumperla
*/
object DualResnetModel {

  /**
   * Assemble the dual-headed residual network and return it initialised.
   *
   * The graph is wired in this order: one input, an initial conv/batch-norm
   * block, the residual tower, then a policy head and a value head that both
   * read from the tower output.
   *
   * @param numBlocks forwarded to the builder's residual tower
   * @param numPlanes forwarded to the initial conv/batch-norm block
   *                  (presumably the number of input feature planes; confirm
   *                  against DL4JAlphaGoZeroBuilder)
   * @param boardSize used to construct the DL4JAlphaGoZeroBuilder
   * @return a ComputationGraph on which `init()` has been called and whose
   *         cache mode is set to `CacheMode.HOST`; its outputs are registered
   *         as policy first, value second
   */
  def apply(numBlocks: Int, numPlanes: Int, boardSize: Int): ComputationGraph = {
    val graphBuilder = new DL4JAlphaGoZeroBuilder(boardSize)

    // Single named input feeding the initial convolution block.
    val inputName = "in"
    graphBuilder.addInputs(inputName)
    val stemOut = graphBuilder.addConvBatchNormBlock("init", inputName, numPlanes)

    // Residual tower stacked on top of the initial block.
    val towerOut: String = graphBuilder.addResidualTower(numBlocks, stemOut)

    // Both heads branch off the same tower output; the output order is policy, value.
    val policyHead = graphBuilder.addPolicyHead(towerOut)
    val valueHead = graphBuilder.addValueHead(towerOut)
    graphBuilder.addOutputs(List(policyHead, valueHead))

    val graph = new ComputationGraph(graphBuilder.buildAndReturn())
    graph.init()
    graph.setCacheMode(CacheMode.HOST)

    // Uncomment to get an indication of the model's memory usage:
    //val report = graph.getConfiguration.getMemoryReport(graphBuilder.inputTypes)
    //println("Memory report for DualResnetModel: \n" + report.toJson)

    graph
  }
}