Skip to content

Commit

Permalink
pvm memory optimization (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
qiweiii authored Sep 11, 2024
1 parent 5368305 commit d11dc12
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 154 deletions.
1 change: 1 addition & 0 deletions PolkaVM/Sources/PolkaVM/Instruction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ extension Instruction {
} catch let e as Memory.Error {
// this passes test vector
context.state.consumeGas(gasCost())
logger.debug("memory error: \(e)")
return .exit(.pageFault(e.address))
} catch let e {
// other unknown errors
Expand Down
325 changes: 171 additions & 154 deletions PolkaVM/Sources/PolkaVM/Memory.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,64 @@
import Foundation

public class MemorySection {
/// lowest address bound
public let startAddressBound: UInt32
/// highest address bound
public let endAddressBound: UInt32
/// is the section writable
public let isWritable: Bool
/// allocated data
fileprivate var data: Data

/// current data end address, also the start address of empty space
public var currentEnd: UInt32 {
startAddressBound + UInt32(data.count)
}

public init(startAddressBound: UInt32, endAddressBound: UInt32, data: Data, isWritable: Bool) {
self.startAddressBound = startAddressBound
self.endAddressBound = endAddressBound
self.data = data
self.isWritable = isWritable
}
}

extension MemorySection {
public func read(address: UInt32, length: Int) throws(Memory.Error) -> Data {
guard startAddressBound <= address, address + UInt32(length) < endAddressBound else {
throw Memory.Error.pageFault(address)
}
let start = address - startAddressBound
let end = start + UInt32(length)

let validCount = min(end, UInt32(data.count))
let dataToRead = data[start ..< validCount]

let zeroCount = max(0, Int(end - validCount))
let zeros = Data(repeating: 0, count: zeroCount)

return dataToRead + zeros
}

public func write(address: UInt32, values: some Sequence<UInt8>) throws(Memory.Error) {
let valuesData = Data(values)
guard isWritable else {
throw Memory.Error.notWritable(address)
}
guard startAddressBound <= address, address + UInt32(valuesData.count) < endAddressBound else {
throw Memory.Error.notWritable(address)
}

let start = address - startAddressBound
let end = start + UInt32(valuesData.count)
guard end < data.count else {
throw Memory.Error.notWritable(address)
}

data[start ..< end] = valuesData
}
}

public class Memory {
public enum Error: Swift.Error {
case pageFault(UInt32)
Expand All @@ -18,18 +77,43 @@ public class Memory {
}
}

private let pageMap: [(address: UInt32, length: UInt32, writable: Bool)]
private var chunks: [(address: UInt32, data: Data)]
private let heapStart: UInt32
private var heapEnd: UInt32 // start idx of unallocated heap
private let heapLimit: UInt32
// standard program sections
private var readOnly: MemorySection?
private var heap: MemorySection?
private var stack: MemorySection?
private var argument: MemorySection?

// general program sections
private var memorySections: [MemorySection] = []

/// General program init with a fixed page map and some initial data
public init(pageMap: [(address: UInt32, length: UInt32, writable: Bool)], chunks: [(address: UInt32, data: Data)]) {
self.pageMap = pageMap
self.chunks = chunks
heapStart = pageMap.first(where: { $0.writable })?.address ?? 0
heapLimit = UInt32.max
heapEnd = chunks.reduce(0) { max($0, $1.address + UInt32($1.data.count)) }
readOnly = nil
heap = nil
stack = nil
argument = nil
memorySections = []

let sortedPageMap = pageMap.sorted(by: { $0.address < $1.address })
let sortedChunks = chunks.sorted(by: { $0.address < $1.address })

for (address, length, writable) in sortedPageMap {
var data = Data(repeating: 0, count: Int(length))
if sortedChunks.count != 0 {
let chunkIndex = Memory.binarySearch(array: sortedChunks.map(\.address), value: address)
let chunk = sortedChunks[chunkIndex]
if address <= chunk.address, chunk.address + UInt32(chunk.data.count) <= address + length {
data = chunk.data
}
}
let section = MemorySection(
startAddressBound: address,
endAddressBound: address + length,
data: data,
isWritable: writable
)
memorySections.append(section)
}
}

/// Standard Program init
Expand All @@ -45,174 +129,107 @@ public class Memory {
let argumentDataLen = UInt32(argumentData.count)

let heapStart = 2 * ZQ + Q(readOnlyLen, config)

pageMap = [
(ZQ, readOnlyLen, false),
(ZQ + readOnlyLen, P(readOnlyLen, config) - readOnlyLen, false),
(heapStart, readWriteLen, true), // heap
(heapStart + readWriteLen, P(readWriteLen, config) + heapEmptyPagesSize - readWriteLen, true), // heap
(UInt32(config.pvmProgramInitStackBaseAddress) - P(stackSize, config), stackSize, true), // stack
(UInt32(config.pvmProgramInitInputStartAddress), argumentDataLen, false), // argument
(UInt32(config.pvmProgramInitInputStartAddress) + argumentDataLen, P(argumentDataLen, config) - argumentDataLen, false),
]

chunks = [
(ZQ, readOnlyData),
(heapStart, readWriteData),
(UInt32(config.pvmProgramInitInputStartAddress), argumentData),
]

self.heapStart = heapStart
heapLimit = heapStart + P(readWriteLen, config) + heapEmptyPagesSize
heapEnd = heapStart + readWriteLen
let stackPageAlignedSize = P(stackSize, config)

readOnly = MemorySection(
startAddressBound: ZQ,
endAddressBound: ZQ + P(readOnlyLen, config),
data: readWriteData,
isWritable: false
)
heap = MemorySection(
startAddressBound: heapStart,
endAddressBound: heapStart + P(readWriteLen, config) + heapEmptyPagesSize,
data: readWriteData,
isWritable: true
)
stack = MemorySection(
startAddressBound: UInt32(config.pvmProgramInitStackBaseAddress) - stackPageAlignedSize,
endAddressBound: UInt32(config.pvmProgramInitStackBaseAddress),
// TODO: check is this necessary
data: Data(repeating: 0, count: Int(stackPageAlignedSize)),
isWritable: true
)
argument = MemorySection(
startAddressBound: UInt32(config.pvmProgramInitInputStartAddress),
endAddressBound: UInt32(config.pvmProgramInitInputStartAddress) + P(argumentDataLen, config),
data: argumentData,
isWritable: false
)
}

public func isWritable(address: UInt32) -> Bool {
// check heap range
guard heapStart <= address, address < heapLimit else {
return false
}

// TODO: optimize
for page in pageMap {
if page.address <= address, address < page.address + page.length {
return page.writable
/// if value not in array, return the index of the previous element or 0
static func binarySearch(array: [UInt32], value: UInt32) -> Int {
var low = 0
var high = array.count - 1
while low <= high {
let mid = (low + high) / 2
if array[mid] < value {
low = mid + 1
} else if array[mid] > value {
high = mid - 1
} else {
return mid
}
}

return false
return max(0, low - 1)
}

public func read(address: UInt32) throws(Error) -> UInt8 {
// TODO: optimize this
// check for chunks
for chunk in chunks {
if chunk.address <= address, address < chunk.address + UInt32(chunk.data.count) {
return chunk.data[Int(address - chunk.address)]
private func getSection(forAddress address: UInt32) throws(Error) -> MemorySection {
if memorySections.count != 0 {
return memorySections[Memory.binarySearch(array: memorySections.map(\.startAddressBound), value: address)]
} else if let readOnly {
if address >= readOnly.startAddressBound, address < readOnly.endAddressBound {
return readOnly
}
}
// check for page map
for page in pageMap {
if page.address <= address, address < page.address + page.length {
return 0
} else if let heap {
if address >= heap.startAddressBound, address < heap.endAddressBound {
return heap
}
} else if let stack {
if address >= stack.startAddressBound, address < stack.endAddressBound {
return stack
}
} else if let argument {
if address >= argument.startAddressBound, address < argument.endAddressBound {
return argument
}
}
throw Error.pageFault(address)
}

public func read(address: UInt32) throws(Error) -> UInt8 {
try getSection(forAddress: address).read(address: address, length: 1).first ?? 0
}

public func read(address: UInt32, length: Int) throws -> Data {
// TODO: optimize this
// check for chunks
for chunk in chunks {
if chunk.address <= address, address < chunk.address + UInt32(chunk.data.count) {
let startIndex = Int(address - chunk.address)
let endIndex = min(startIndex + length, chunk.data.endIndex)
let res = chunk.data[startIndex ..< endIndex]
let remaining = length - res.count
if remaining == 0 {
return res
} else {
let startAddress = chunk.address &+ UInt32(chunk.data.count) // wrapped add
let remainingData = try read(address: startAddress, length: remaining)
return res + remainingData
}
}
}
// check for page map
for page in pageMap {
if page.address <= address, address < page.address + page.length {
// TODO: handle reads that cross page boundaries
return Data(repeating: 0, count: length)
}
}
throw Error.pageFault(address)
try getSection(forAddress: address).read(address: address, length: length)
}

public func write(address: UInt32, value: UInt8) throws(Error) {
guard isWritable(address: address) else {
throw Error.notWritable(address)
}

// TODO: optimize this
// check for chunks
for i in 0 ..< chunks.count {
var chunk = chunks[i]
if chunk.address <= address, address < chunk.address + UInt32(chunk.data.count) {
chunk.data[Int(address - chunk.address)] = value
chunks[i] = chunk
return
}
}
// check for page map
for page in pageMap {
if page.address <= address, address < page.address + page.length {
var newChunk = (address: address, data: Data(repeating: 0, count: Int(page.length)))
newChunk.data[Int(address - page.address)] = value
chunks.append(newChunk)
heapEnd = max(heapEnd, address + 1)
return
}
}
throw Error.notWritable(address)
try getSection(forAddress: address).write(address: address, values: Data([value]))
}

public func write(address: UInt32, values: some Sequence<UInt8>) throws(Error) {
guard isWritable(address: address) else {
throw Error.notWritable(address)
}

// TODO: optimize this
// check for chunks
for i in 0 ..< chunks.count {
var chunk = chunks[i]
if chunk.address <= address, address < chunk.address + UInt32(chunk.data.count) {
var idx = Int(address - chunk.address)
for v in values {
if idx == chunk.data.endIndex {
chunk.data.append(v)
} else {
chunk.data[idx] = v
}
idx += 1
}
chunks[i] = chunk
return
}
}
// check for page map
for page in pageMap {
if page.address <= address, address < page.address + page.length {
var newChunk = (address: address, data: Data(repeating: 0, count: Int(page.length)))
var idx = Int(address - page.address)
for v in values {
if idx == newChunk.data.endIndex {
throw Error.notWritable(address)
} else {
newChunk.data[idx] = v
}
idx += 1
}
chunks.append(newChunk)
heapEnd = max(heapEnd, UInt32(idx))
return
}
}
throw Error.notWritable(address)
try getSection(forAddress: address).write(address: address, values: values)
}

public func sbrk(_ increment: UInt32) throws -> UInt32 {
// TODO: optimize
for page in pageMap {
let pageEnd = page.address + page.length
if page.writable, heapEnd >= page.address, heapEnd + increment < pageEnd {
let newChunk = (address: heapEnd, data: Data(repeating: 0, count: Int(increment)))
chunks.append(newChunk)
heapEnd += increment
return heapEnd
}
var section: MemorySection
if let heap {
section = heap
} else if memorySections.count != 0 {
section = memorySections.last!
} else {
throw Error.pageFault(0)
}

throw Error.outOfMemory(heapEnd)
let oldSectionEnd = section.currentEnd
guard section.isWritable, oldSectionEnd + increment < section.endAddressBound else {
throw Error.outOfMemory(oldSectionEnd)
}
section.data.append(Data(repeating: 0, count: Int(increment)))
return oldSectionEnd
}
}

Expand Down

0 comments on commit d11dc12

Please sign in to comment.