-
Notifications
You must be signed in to change notification settings - Fork 3
/
Llama2Runner.scala
163 lines (141 loc) · 5.41 KB
/
Llama2Runner.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package net.virtualvoid.llama2
import scala.annotation.tailrec
import scala.util.Random
/** Strategy for choosing the next token id given the model's output logits. */
trait Sampler {
  /**
   * Selects a token index from `logits`.
   * Implementations may mutate `logits` in place (e.g. temperature scaling / softmax).
   */
  def sample(logits: Tensor1DMut): Int
}
/** Deterministic sampler: always picks the index of the largest logit (greedy decoding). */
object ArgmaxSampler extends Sampler {
  def sample(logits: Tensor1DMut): Int = {
    val values = logits.toFloatArray
    // maxBy returns the first maximal index, matching argmax-with-first-tie semantics
    values.indices.maxBy(values(_))
  }
}
/**
 * Sampler that draws from the softmax distribution of the logits after dividing them by
 * `temperature`. With `topp < 1f`, sampling is restricted to the smallest set of
 * highest-probability tokens whose cumulative probability exceeds `topp` (top-p /
 * nucleus sampling).
 *
 * @param temperature divisor applied to the logits before softmax; 0 falls back to argmax
 * @param topp        nucleus threshold; 1f disables top-p filtering
 * @param random      randomness source for drawing from the distribution
 * @param vocab       vocabulary — NOTE(review): not referenced in this class body; confirm whether it is needed
 */
class TemperatureSampling(temperature: Float, topp: Float = 1f, random: Random = new Random, vocab: Vocab) extends Sampler {
  override def sample(logits: Tensor1DMut): Int = {
    // Draws an index from `tensor1DMut` interpreted as a probability distribution:
    // scan the cumulative sum and return the index of the first element where it reaches p.
    // NOTE(review): assumes the values sum to (at least) 1 — if float rounding keeps the
    // cumulative sum below p, `next()` would throw NoSuchElementException.
    def sample(tensor1DMut: Tensor1DMut): Int = {
      val p = random.nextFloat()
      tensor1DMut.toFloatArray
        .iterator
        .zipWithIndex
        .scanLeft((0f, 0)) { case (sum, i) => (sum._1 + i._1, i._2) }
        .dropWhile(_._1 < p)
        .next()._2
    }

    if (temperature == 0f)
      logits.toFloatArray.zipWithIndex.maxBy(_._1)._2 // argmax
    else {
      // in-place: scale by temperature, then convert logits to probabilities
      logits /= temperature
      logits.softmaxMut()
      if (topp != 1f) {
        // Reference top-p implementation: sort all probabilities descending, keep the
        // shortest prefix whose cumulative sum exceeds topp, renormalize, and sample from it.
        def selectBySorting(): Int = {
          val sorted = logits.toFloatArray.zipWithIndex.sortBy(-_._1).toVector
          // cumSum starts with a leading 0 (scanLeft seed), so `idx` equals the number of
          // prefix entries needed to exceed topp
          val cumSum = sorted.iterator.scanLeft(0f)(_ + _._1)
          val idx = cumSum.indexWhere(_ > topp)
          val interesting = Tensor1DMut(sorted.take(idx).map(_._1).toArray, idx)
          interesting /= interesting.sum // renormalize the kept prefix to sum to 1
          sorted(sample(interesting))._2 // map prefix position back to the original token id
        }
        /**
         * This method tries to avoid full on sorting of the big logits array.
         *
         * The idea is to exploit properties of the distribution array:
         * 1. we expect that the top-k values are much larger than the rest (i.e. a power-law-like distribution)
         * 2. values need to add up to 1
         *
         * 1. means that the top-p entries are few, so scanning is reasonable compared to sorting
         * 2. each scan can be aborted early if we found a value that accounts for more than half of the remaining probability
         */
        def selectByScanning(): Int = {
          // FIXME: needs work if there are multiple values of prevMax
          // Linear scan for the largest value strictly below `prevMax`; aborts early once
          // the current max exceeds half of the remaining probability mass (no larger
          // value below prevMax can then exist, since all values sum to 1).
          def maxLessThan(ls: Array[Float], remaining: Float, prevMax: Float): Int = {
            //println(f"prevMax: $prevMax remaining: $remaining")
            val halfRemaining = remaining / 2
            var maxIndex = 0
            var max = ls(0)
            var i = 1
            while (i < ls.length && max < halfRemaining) {
              val v = ls(i)
              if (v > max && v < prevMax) {
                max = v
                maxIndex = i
              }
              i += 1
            }
            maxIndex
          }
          val ls = logits.toFloatArray
          // collect at most `scanAttempts` top entries before giving up and sorting instead
          val scanAttempts = 100
          val idxBuffer = new Array[Int](scanAttempts)
          val pBuffer = new Array[Float](scanAttempts)
          // Repeatedly extract the next-largest probability until the collected mass
          // exceeds topp. Returns the number of entries found, or -1 if the buffers
          // were exhausted (fall back to the sorting path).
          def collect(numFound: Int, sum: Float, prevMax: Float): Int =
            if (sum > topp) numFound
            else if (numFound >= idxBuffer.size) -1
            else {
              val maxIdx = maxLessThan(ls, 1f - sum, prevMax)
              idxBuffer(numFound) = maxIdx
              pBuffer(numFound) = ls(maxIdx)
              collect(numFound + 1, sum + ls(maxIdx), ls(maxIdx))
            }
          // prevMax = 2f: larger than any probability, so the first scan finds the global max
          val numFound = collect(0, 0f, 2f)
          if (numFound == -1) selectBySorting()
          else {
            val interesting = Tensor1DMut(pBuffer.take(numFound), numFound)
            interesting /= interesting.sum // renormalize the nucleus
            val sid = sample(interesting)
            idxBuffer(sid) // translate nucleus position back to the original token id
          }
        }
        selectByScanning()
      } else
        sample(logits)
    }
  }
}
/**
 * Drives a [[Llama2Transformer]] step by step to generate text from a prompt,
 * decoding generated token ids back to strings via the model's vocabulary.
 */
class Llama2Runner(transformer: Llama2Transformer, model: Llama2Model) {
  import model.vocab

  /**
   * Returns an iterator over decoded token strings, generating at most `steps` tokens.
   * Prompt tokens are forced in order; after the prompt is consumed, `sampler` picks
   * each next token from the transformer's logits. Iteration stops early when token 0
   * is produced.
   */
  def iterate(steps: Int, sampler: Sampler = ArgmaxSampler, prompt: String = ""): Iterator[String] = new Iterator[String] {
    val promptTokens = bpeEncode(prompt)
    var pos = 0
    // start token id 1 — NOTE(review): presumably the BOS token; confirm against the vocab
    var token = 1
    def hasNext: Boolean = pos < steps && token != 0
    def next(): String = {
      val logits = transformer.step(token, pos)
      /* interesting to see what possible completions are
      if (pos == -1) {
        val myLogits = Tensor1DMut.zero(logits.size)
        myLogits := Tensor1DMut(logits, logits.size)
        myLogits.softmaxMut()
        println()
        myLogits.toFloatArray.zipWithIndex.sortBy(-_._1).take(10).foreach {
          case (p, t) =>
            println(f"${vocab.tokenScores(t)._1}%-20s ${p * 100}%5.2f %%")
        }
      }*/
      // while still inside the prompt, force the next prompt token; otherwise sample
      val next =
        if (pos < promptTokens.length) promptTokens(pos)
        else
          sampler.sample(Tensor1DMut(logits, logits.size))
      val tok = vocab.tokenScores(next)._1
      // Strip the leading space after the start token.
      // NOTE(review): `tok == " "` only matches a token that IS a single space (drop(1)
      // yields ""); llama2.c strips a leading space from ANY first token here — confirm
      // whether `tok.startsWith(" ")` was intended.
      val tokenStr = if (token == 1 && tok == " ") tok.drop(1) else tok
      token = next
      pos += 1
      tokenStr
    }
  }

  /**
   * Greedy BPE encoding of `prompt`: start with one token per character, then repeatedly
   * merge the adjacent pair whose concatenation is a vocabulary token with the highest
   * score, until no adjacent pair can be merged.
   */
  def bpeEncode(prompt: String): Seq[Int] = {
    // token string -> (token id, merge score)
    val tokMap = vocab.tokenScores.zipWithIndex.map { case ((t, s), i) => t -> (i, s) }.toMap
    // initial tokenization: one token per character
    // NOTE(review): assumes every prompt character exists as a single-character token;
    // an unknown character makes tokMap(...) throw NoSuchElementException
    var toks: Vector[Int] = prompt.map(x => tokMap(x.toString)._1).toVector
    @tailrec def mergeTokensStep(toks: Vector[Int]): Vector[Int] = {
      // all adjacent pairs whose concatenation is itself a known token: (position, merged id, score)
      val candidates =
        toks.sliding(2).zipWithIndex.flatMap {
          case (Seq(t1: Int, t2: Int), i: Int) =>
            val tok = vocab.tokenScores(t1)._1 + vocab.tokenScores(t2)._1
            val id = tokMap.get(tok)
            id.map { case (newTok, score) => (i, newTok, score) }
        }
      if (candidates.isEmpty) toks
      else {
        // merge the best-scoring pair in place and try again
        val (idx, newTok, _) = candidates.maxBy(_._3)
        mergeTokensStep(toks.take(idx) ++ Seq(newTok) ++ toks.drop(idx + 2))
      }
    }
    mergeTokensStep(toks)
  }
}