Skip to content

Commit

Permalink
Snitch frep extension support (#50)
Browse files Browse the repository at this point in the history
This patch adds the frep extension to the snitch emulator. Register staggering is not supported currently, but could be added at a later date.

This patch also adds f64 support to the xssr extension.
  • Loading branch information
AntonLydike authored Oct 11, 2023
1 parent 5a939ee commit f50b3a9
Show file tree
Hide file tree
Showing 10 changed files with 296 additions and 103 deletions.
10 changes: 4 additions & 6 deletions riscemu/core/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
"""

from collections import defaultdict
from typing import Union, Type
from typing import Type

from ..helpers import *

from . import Int32, BaseFloat, Float32, Float64
from . import Int32, BaseFloat


class Registers:
Expand Down Expand Up @@ -222,14 +222,12 @@ def get(self, reg: str, mark_read: bool = True) -> Int32:
self.last_read = reg
return self.vals[reg]

def get_f(self, reg: str, mark_read: bool = True) -> BaseFloat:
def get_f(self, reg: str) -> BaseFloat:
if not self.infinite_regs and reg not in self.float_regs:
raise RuntimeError("Invalid float register: {}".format(reg))
if mark_read:
self.last_read = reg
return self.float_vals[reg]

def set_f(self, reg: str, val: Union[float, BaseFloat]):
def set_f(self, reg: str, val: BaseFloat):
if not self.infinite_regs and reg not in self.float_regs:
raise RuntimeError("Invalid float register: {}".format(reg))
self.float_vals[reg] = self._float_type.bitcast(val)
Expand Down
4 changes: 1 addition & 3 deletions riscemu/core/usermode_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def step(self, verbose: bool = False):
self.cycle += 1
ins = self.mmu.read_ins(self.pc)
if verbose:
print(
FMT_CPU + " Running 0x{:08X}:{} {}".format(self.pc, FMT_NONE, ins)
)
print(FMT_CPU + " 0x{:08X}:{} {}".format(self.pc, FMT_NONE, ins))
self.pc += self.INS_XLEN
self.run_instruction(ins)
except RiscemuBaseException as ex:
Expand Down
12 changes: 6 additions & 6 deletions snitch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
"""
import sys

from .regs import StreamingRegs
from .xssr import RV32_Xssr_pseudo
from .xssr import Xssr_pseudo
from .frep import FrepEnabledCpu, Xfrep
from riscemu.riscemu_main import RiscemuMain


class SnitchMain(RiscemuMain):
def configure_cpu(self):
super().configure_cpu()
self.cpu.regs = StreamingRegs(self.cpu.mmu)
def instantiate_cpu(self):
self.cpu = FrepEnabledCpu(self.selected_ins_sets, self.cfg)
self.configure_cpu()

def register_all_isas(self):
super().register_all_isas()
self.available_ins_sets.update({"Xssr": RV32_Xssr_pseudo})
self.available_ins_sets.update({"Xssr": Xssr_pseudo, "Xfrep": Xfrep})


if __name__ == "__main__":
Expand Down
103 changes: 103 additions & 0 deletions snitch/frep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import List, Type, Union, Set, Literal

from riscemu.colors import FMT_CPU, FMT_NONE
from riscemu.config import RunConfig
from riscemu.core import UserModeCPU
from riscemu.instructions import InstructionSet, Instruction, RV32F, RV32D

from dataclasses import dataclass

from snitch.regs import StreamingRegs


@dataclass(frozen=True)
class FrepState:
rep_count: int
ins_count: int
mode: Literal["inner", "outer"]


class FrepEnabledCpu(UserModeCPU):
repeat: Union[FrepState, None]
allowed_ins: Set[str]

def __init__(self, instruction_sets: List[Type["InstructionSet"]], conf: RunConfig):
super().__init__(instruction_sets, conf)
self.regs = StreamingRegs(
mem=self.mmu, infinite_regs=conf.unlimited_registers, flen=conf.flen
)
self.repeats = None
# only floating point instructions are allowed inside an frep!
self.allowed_ins = set(x for x, y in RV32F(self).get_instructions())
if conf.flen > 32:
self.allowed_ins.union(x for x, y in RV32D(self).get_instructions())

def step(self, verbose: bool = False):
if self.repeats is None:
super().step(verbose=verbose)
return
# get the spec
spec: FrepState = self.repeats
self.repeats = None

instructions = [
self.mmu.read_ins(self.pc + i * self.INS_XLEN)
for i in range(spec.ins_count)
]

for ins in instructions:
if ins.name not in self.allowed_ins:
# TODO: wrap in a nicer error type
raise RuntimeError(
"Forbidden instruction inside frep loop: {}".format(ins)
)

if verbose:
print(
FMT_CPU
+ "┌────── floating point repetition ({}) {} times".format(
spec.mode, spec.rep_count + 1
)
)
for i, ins in enumerate(instructions):
print(
FMT_CPU
+ "│ 0x{:08X}:{} {}".format(
self.pc + i * self.INS_XLEN, FMT_NONE, ins
)
)
print(FMT_CPU + "└────── end of floating point repetition" + FMT_NONE)

pc = self.pc
if spec.mode == "outer":
for _ in range(spec.rep_count + 1):
for ins in instructions:
self.run_instruction(ins)
elif spec.mode == "inner":
for ins in instructions:
for _ in range(spec.rep_count + 1):
self.run_instruction(ins)
else:
raise RuntimeError(f"Unknown frep mode: {spec.mode}")
self.cycle += (spec.rep_count + 1) * spec.ins_count
self.pc = pc + (spec.ins_count * self.INS_XLEN)


class Xfrep(InstructionSet):
def instruction_frep_o(self, ins: Instruction):
self.frep(ins, "outer")

def instruction_frep_i(self, ins: Instruction):
self.frep(ins, "inner")

def frep(self, ins: Instruction, mode: Literal["inner", "outer"]):
assert isinstance(self.cpu, FrepEnabledCpu)
assert len(ins.args) == 4
assert ins.get_imm(2).abs_value.value == 0, "staggering not supported yet"
assert ins.get_imm(3).abs_value.value == 0, "staggering not supported yet"

self.cpu.repeats = FrepState(
rep_count=self.regs.get(ins.get_reg(0)).unsigned_value,
ins_count=ins.get_imm(1).abs_value.value,
mode=mode,
)
22 changes: 9 additions & 13 deletions snitch/regs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
mem: MMU,
xssr_regs: Tuple[str] = ("ft0", "ft1", "ft2"),
infinite_regs: bool = False,
flen: int = 64,
):
self.mem = mem
self.enabled = False
Expand All @@ -61,11 +62,11 @@ def __init__(
stream_def = StreamDef()
self.dm_by_id.append(stream_def)
self.streams[reg] = stream_def
super().__init__(infinite_regs)
super().__init__(infinite_regs=infinite_regs, flen=flen)

def get_f(self, reg, mark_read=True) -> "BaseFloat":
def get_f(self, reg) -> "BaseFloat":
if not self.enabled or reg not in self.streams:
return super().get_f(reg, mark_read)
return super().get_f(reg)

# do the streaming stuff:
stream = self.streams[reg]
Expand All @@ -74,26 +75,21 @@ def get_f(self, reg, mark_read=True) -> "BaseFloat":
# TODO: Check overflow
# TODO: repetition
addr = stream.base + (stream.pos * stream.stride)
val = self.mem.read_float(addr)
val = self._float_type(self.mem.read(addr, self.flen // 8))
# increment pos
print(
"stream: got val {} from addr 0x{:x}, stream {}".format(val, addr, stream)
)
stream.pos += 1
return val

def set_f(self, reg, val: "BaseFloat", mark_set=True) -> bool:
def set_f(self, reg, val: "BaseFloat") -> bool:
if not self.enabled or reg not in self.streams:
return super().set_f(reg, mark_set)
return super().set_f(reg, val)

stream = self.streams[reg]
assert stream.mode is StreamMode.WRITE

addr = stream.base + (stream.pos * stream.stride)
self.mem.write(addr, 4, bytearray(val.bytes))
data = val.bytes
self.mem.write(addr + (self.flen // 8) - len(data), len(data), bytearray(data))

print(
"stream: wrote val {} into addr 0x{:x}, stream {}".format(val, addr, stream)
)
stream.pos += 1
return True
2 changes: 1 addition & 1 deletion snitch/xssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .regs import StreamingRegs, StreamDef, StreamMode


class RV32_Xssr_pseudo(InstructionSet):
class Xssr_pseudo(InstructionSet):
def instruction_ssr_enable(self, ins: Instruction):
self._stream.enabled = True

Expand Down
24 changes: 24 additions & 0 deletions test/filecheck/snitch/frep_only.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: python3 -m snitch %s -o libc -v | filecheck %s

.text
.globl main
main:
// load constants
li t0, 0
fcvt.s.w ft0, t0
li t0, 1
fcvt.s.w ft1, t0

// repeat 100 times
li t0, 99
frep.i t0, 1, 0, 0
fadd.s ft0, ft0, ft1 // add one

// print result to stdout
printf "100 * 1 = {:f32}", ft0
// CHECK: 100 * 1 = 100.0
// return 0
li a0, 0
ret

// CHECK-NEXT: [CPU] Program exited with code 0
75 changes: 75 additions & 0 deletions test/filecheck/snitch/ssr_frep.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: python3 -m snitch %s -o libc -v --flen 32 | filecheck %s

.data

vec0:
.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000
vec1:
.word 0x0, 0x3f800000, 0x40000000, 0x40400000, 0x40800000, 0x40a00000, 0x40c00000, 0x40e00000, 0x41000000, 0x41100000
dest:
.space 40
expected:
.word 0x0, 0x3e800000, 0x3f800000, 0x40100000, 0x40800000, 0x40c80000, 0x41100000, 0x41440000, 0x41800000, 0x41a20000

.text
.globl main

main:
// ssr config
ssr.configure 0, 10, 4
ssr.configure 1, 10, 4
ssr.configure 2, 10, 4

// ft0 streams from vec0
la a0, vec0
ssr.read a0, 0, 0

// ft1 streams from vec1
la a0, vec1
ssr.read a0, 1, 0

// ft2 streams to dest
la a0, dest
ssr.write a0, 2, 0

li a0, 9
// some constant to divide by
li t0, 4
fcvt.s.w ft3, t0
ssr.enable

frep.o a0, 2, 0, 0
fmul.s ft4, ft0, ft1 // ft3 = vec0[i] * vec1[i]
fdiv.s ft2, ft4, ft3 // dest[i] = ft3 / 4

// stop ssr
ssr.disable

// check values were written correctly:
la t0, dest
la t1, expected
li a0, 36
loop:
add s0, t0, a0
add s1, t1, a0

// load vec0, vec1 and dest elements
flw ft0, 0(s0)
flw ft1, 0(s1)

// assert ft0 == ft1 (expected[i] == dest[i])
feq.s s0, ft0, ft1
beq zero, s0, fail

addi a0, a0, -4
bge a0, zero loop

li a0, 0
ret

fail:
printf "Assertion failure: {} != {} (at {})", ft0, ft1, a0
li a0, -1
ret

// CHECK: [CPU] Program exited with code 0
Loading

0 comments on commit f50b3a9

Please sign in to comment.