Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Instructions refactor #89

Merged
merged 2 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Blockchain/Sources/Blockchain/Config/ProtocolConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ public struct ProtocolConfig: Sendable {
public typealias ProtocolConfigRef = Ref<ProtocolConfig>

extension ProtocolConfig: PvmConfig {}
// silence the warning about cross module conformances as we owns all the code
extension Ref: @retroactive PvmConfig where T == ProtocolConfig {
public var pvmDynamicAddressAlignmentFactor: Int { value.pvmDynamicAddressAlignmentFactor }
public var pvmProgramInitInputDataSize: Int { value.pvmProgramInitInputDataSize }
public var pvmProgramInitPageSize: Int { value.pvmProgramInitPageSize }
public var pvmProgramInitSegmentSize: Int { value.pvmProgramInitSegmentSize }
}

extension ProtocolConfig {
public enum AuditTranchePeriod: ReadInt {
Expand Down
4 changes: 1 addition & 3 deletions Blockchain/Sources/Blockchain/Safrole.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ func generateFallbackIndices(entropy: Data32, count: Int, length: Int) throws ->
hasher.update(Data(bytes))
let hash = hasher.finalize()
let hash4 = hash.data[0 ..< 4]
let idx: UInt32 = hash4.withUnsafeBytes { ptr in
ptr.loadUnaligned(as: UInt32.self)
}
let idx = hash4.decode(UInt32.self)
return Int(idx % UInt32(length))
}
}
Expand Down
2 changes: 1 addition & 1 deletion JAMTests/Tests/JAMTests/PVMTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct PVMTests {
gas: UInt64(testCase.initialGas),
memory: memory
)
let engine = Engine()
let engine = Engine(config: DefaultPvmConfig())
let exitReason = engine.execute(program: program, state: vmState)
let exitReason2: Status = switch exitReason {
case .halt:
Expand Down
21 changes: 10 additions & 11 deletions PolkaVM/Sources/PolkaVM/Engine.swift
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import Foundation

public class Engine {
public enum Constants {
public static let exitAddress: UInt32 = 0xFFFF_0000
}
let config: PvmConfig

public init() {}
public init(config: PvmConfig) {
self.config = config
}

public func execute(program: ProgramCode, state: VMState) -> ExitReason {
let context = ExecutionContext(state: state, config: config)
while true {
guard state.getGas() > 0 else {
return .outOfGas
}
if case let .exit(reason) = step(program: program, state: state) {
if case let .exit(reason) = step(program: program, context: context) {
return reason
}
}
}

public func step(program: ProgramCode, state: VMState) -> ExecOutcome {
let pc = state.pc
let skip = program.skip(state.pc)
public func step(program: ProgramCode, context: ExecutionContext) -> ExecOutcome {
let pc = context.state.pc
let skip = program.skip(pc)
let startIndex = program.code.startIndex + Int(pc)
let endIndex = startIndex + 1 + Int(skip)
let data = if endIndex <= program.code.endIndex {
Expand All @@ -32,8 +33,6 @@ public class Engine {
return .exit(.panic(.invalidInstruction))
}

let res = inst.execute(state: state, skip: skip)

return res
return inst.execute(context: context, skip: skip)
}
}
26 changes: 18 additions & 8 deletions PolkaVM/Sources/PolkaVM/Instruction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,31 @@ public protocol Instruction {
init(data: Data) throws

func gasCost() -> UInt64
func updatePC(state: VMState, skip: UInt32) -> ExecOutcome
func updatePC(context: ExecutionContext, skip: UInt32) -> ExecOutcome

// protected method
func _executeImpl(state: VMState) throws -> ExecOutcome
func _executeImpl(context: ExecutionContext) throws -> ExecOutcome
}

public class ExecutionContext {
public let state: VMState
public let config: PvmConfig

public init(state: VMState, config: PvmConfig) {
self.state = state
self.config = config
}
}

extension Instruction {
public func execute(state: VMState, skip: UInt32) -> ExecOutcome {
state.consumeGas(gasCost())
public func execute(context: ExecutionContext, skip: UInt32) -> ExecOutcome {
context.state.consumeGas(gasCost())
do {
let execRes = try _executeImpl(state: state)
let execRes = try _executeImpl(context: context)
if case .exit = execRes {
return execRes
}
return updatePC(state: state, skip: skip)
return updatePC(context: context, skip: skip)
} catch let e as Memory.Error {
return .exit(.pageFault(e.address))
} catch {
Expand All @@ -34,8 +44,8 @@ extension Instruction {
1
}

public func updatePC(state: VMState, skip: UInt32) -> ExecOutcome {
state.increasePC(skip + 1)
public func updatePC(context: ExecutionContext, skip: UInt32) -> ExecOutcome {
context.state.increasePC(skip + 1)
return .continued
}
}
39 changes: 39 additions & 0 deletions PolkaVM/Sources/PolkaVM/Instructions/BranchCompare.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
protocol BranchCompare {
static func compare(a: UInt32, b: UInt32) -> Bool
}

struct CompareEq: BranchCompare {
static func compare(a: UInt32, b: UInt32) -> Bool {
Int32(bitPattern: a) == Int32(bitPattern: b)
}
}

struct CompareNe: BranchCompare {
static func compare(a: UInt32, b: UInt32) -> Bool {
Int32(bitPattern: a) != Int32(bitPattern: b)
}
}

struct CompareLt: BranchCompare {
static func compare(a: UInt32, b: UInt32) -> Bool {
a < b
}
}

struct CompareLe: BranchCompare {
static func compare(a: UInt32, b: UInt32) -> Bool {
a <= b
}
}

struct CompareGe: BranchCompare {
static func compare(a: UInt32, b: UInt32) -> Bool {
a >= b
}
}

struct CompareGt: BranchCompare {
static func compare(a: UInt32, b: UInt32) -> Bool {
a > b
}
}
79 changes: 79 additions & 0 deletions PolkaVM/Sources/PolkaVM/Instructions/BranchInstructionBase.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import Foundation

// for branch in A.5.7
protocol BranchInstructionBase<Compare>: Instruction {
associatedtype Compare: BranchCompare

var register: Registers.Index { get set }
var value: UInt32 { get set }
var offset: UInt32 { get set }

func condition(state: VMState) -> Bool
}

extension BranchInstructionBase {
public static func parse(data: Data) throws -> (Registers.Index, UInt32, UInt32) {
let register = try Registers.Index(ra: data.at(relative: 0))
let (value, offset) = try Instructions.decodeImmediate2(data, divideBy: 16)
return (register, value, offset)
}

public func _executeImpl(context _: ExecutionContext) throws -> ExecOutcome { .continued }

public func updatePC(context: ExecutionContext, skip: UInt32) -> ExecOutcome {
guard Instructions.isBranchValid(context: context, offset: offset) else {
return .exit(.panic(.invalidBranch))
}
if condition(state: context.state) {
context.state.increasePC(offset)
} else {
context.state.increasePC(skip + 1)
}
return .continued
}

public func condition(state: VMState) -> Bool {
let regVal = state.readRegister(register)
return Compare.compare(a: regVal, b: value)
}
}

// for branch in A.5.10
protocol BranchInstructionBase2<Compare>: Instruction {
associatedtype Compare: BranchCompare

var r1: Registers.Index { get set }
var r2: Registers.Index { get set }
var offset: UInt32 { get set }

func condition(state: VMState) -> Bool
}

extension BranchInstructionBase2 {
public static func parse(data: Data) throws -> (Registers.Index, Registers.Index, UInt32) {
let offset = try Instructions.decodeImmediate(data.at(relative: 1...))
let r1 = try Registers.Index(ra: data.at(relative: 0))
let r2 = try Registers.Index(rb: data.at(relative: 0))
return (r1, r2, offset)
}

public func _executeImpl(context _: ExecutionContext) throws -> ExecOutcome { .continued }

public func updatePC(context: ExecutionContext, skip: UInt32) -> ExecOutcome {
guard Instructions.isBranchValid(context: context, offset: offset) else {
return .exit(.panic(.invalidBranch))
}
if condition(state: context.state) {
context.state.increasePC(offset)
} else {
context.state.increasePC(skip + 1)
}
return .continued
}

public func condition(state: VMState) -> Bool {
let r1Val = state.readRegister(r1)
let r2Val = state.readRegister(r2)
return Compare.compare(a: r1Val, b: r2Val)
}
}
79 changes: 79 additions & 0 deletions PolkaVM/Sources/PolkaVM/Instructions/Instructions+Helpers.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import Foundation

extension Instructions {
enum Constants {
static let djumpHaltAddress: UInt32 = 0xFFFF_0000
}

static func decodeImmediate(_ data: Data) -> UInt32 {
let len = min(data.count, 4)
if len == 0 {
return 0
}
var value: UInt32 = 0
for i in 0 ..< len {
value = value | (UInt32(data[relative: i]) << (8 * i))
}
let shift = (4 - len) * 8
// shift left so that the MSB is the sign bit
// and then do signed shift right to fill the empty bits using the sign bit
// and then convert back to UInt32
return UInt32(bitPattern: Int32(bitPattern: value << shift) >> shift)
}

static func decodeImmediate2(_ data: Data, divideBy: UInt8 = 1, minus: Int = 1) throws -> (UInt32, UInt32) {
let lX1 = try Int((data.at(relative: 0) / divideBy) & 0b111)
let lX = min(4, lX1)
let lY = min(4, max(0, data.count - Int(lX) - minus))

let vX = try decodeImmediate(data.at(relative: 1 ..< 1 + lX))
let vY = try decodeImmediate(data.at(relative: (1 + lX) ..< (1 + lX + lY)))
return (vX, vY)
}

static func isBranchValid(context: ExecutionContext, offset: UInt32) -> Bool {
context.state.program.basicBlockIndices.contains(context.state.pc &+ offset)
}

static func isDjumpValid(context: ExecutionContext, target a: UInt32, targetAligned: UInt32) -> Bool {
let za = context.config.pvmDynamicAddressAlignmentFactor
return !(a == 0 ||
a > context.state.program.jumpTable.count * za ||
Int(a) % za != 0 ||
context.state.program.basicBlockIndices.contains(targetAligned))
}

static func djump(context: ExecutionContext, target: UInt32) -> ExecOutcome {
if target == Constants.djumpHaltAddress {
return .exit(.halt)
}

let entrySize = Int(context.state.program.jumpTableEntrySize)
let start = ((Int(target) / context.config.pvmDynamicAddressAlignmentFactor) - 1) * entrySize
let end = start + entrySize
var targetAlignedData = context.state.program.jumpTable[relative: start ..< end]
guard let targetAligned = targetAlignedData.decode() else {
fatalError("unreachable: jump table entry should be valid")
}

guard isDjumpValid(context: context, target: target, targetAligned: UInt32(truncatingIfNeeded: targetAligned)) else {
return .exit(.panic(.invalidDynamicJump))
}

context.state.updatePC(UInt32(targetAligned))
return .continued
}

static func deocdeRegisters(_ data: Data) throws -> (Registers.Index, Registers.Index) {
let ra = try Registers.Index(ra: data.at(relative: 0))
let rb = try Registers.Index(rb: data.at(relative: 0))
return (ra, rb)
}

static func deocdeRegisters(_ data: Data) throws -> (Registers.Index, Registers.Index, Registers.Index) {
let ra = try Registers.Index(ra: data.at(relative: 0))
let rb = try Registers.Index(rb: data.at(relative: 0))
let rd = try Registers.Index(rd: data.at(relative: 1))
return (ra, rb, rd)
}
}
Loading
Loading