Skip to content

Commit

Permalink
internal/kernel: implement ring buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
adambabik committed Jan 24, 2023
1 parent 04b908d commit e6f0579
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 12 deletions.
153 changes: 153 additions & 0 deletions internal/kernel/ring_buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package kernel

import (
"io"
"sync"
"sync/atomic"

"github.com/pkg/errors"
)

type ringBuffer struct {
mu sync.Mutex
buf []byte
size int
r int // next position to read
w int // next position to write
isFull bool
state atomic.Bool
more chan struct{}
closed chan struct{}
}

func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}

func (b *ringBuffer) Reset() {
b.mu.Lock()
b.r = 0
b.w = 0
b.mu.Unlock()
}

func (b *ringBuffer) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}

b.mu.Lock()
n, err = b.read(p)
b.mu.Unlock()

if err != nil && errors.Is(err, io.EOF) && b.state.Load() {
select {
case <-b.more:
case <-b.closed:
return 0, io.EOF
}
return n, nil
}

return n, err
}

func (b *ringBuffer) read(p []byte) (n int, err error) {
if b.w == b.r && !b.isFull {
return 0, io.EOF
}

if b.w > b.r {
n = b.w - b.r
if n > len(p) {
n = len(p)
}
copy(p, b.buf[b.r:b.r+n])
b.r = (b.r + n) % b.size
return
}

n = b.size - b.r + b.w
if n > len(p) {
n = len(p)
}

if b.r+n <= b.size || b.isFull {
copy(p, b.buf[b.r:b.r+n])
} else {
copy(p, b.buf[b.r:b.size])
c1 := b.size - b.r
c2 := n - c1
copy(p[c1:], b.buf[0:c2])
}
b.r = (b.r + n) % b.size

return n, err
}

func (b *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
if len(p) > b.size {
return 0, errors.New("buffer is too small")
}

b.mu.Lock()
n, err = b.write(p)
b.mu.Unlock()

select {
case b.more <- struct{}{}:
default:
}

return n, err
}

func (b *ringBuffer) write(p []byte) (n int, err error) {
var avail int
if b.w >= b.r {
avail = b.size - b.w + b.r
} else {
avail = b.r - b.w
}

n = len(p)

if len(p) > avail {
b.isFull = false
b.r = b.w
copy(b.buf[b.w:], p[:b.size-b.w])
b.w = copy(b.buf[0:], p[b.size-b.w:])
return n, nil
}

if b.w >= b.r {
c1 := b.size - b.w
if c1 >= n {
copy(b.buf[b.w:], p)
b.w += n
} else {
copy(b.buf[b.w:], p[:c1])
c2 := n - c1
copy(b.buf[0:], p[c1:])
b.w = c2
}
} else {
copy(b.buf[b.w:], p)
b.w += n
}

if b.w == b.size {
b.isFull = true
b.w = 0
} else {
b.isFull = false
}

return n, err
}
54 changes: 54 additions & 0 deletions internal/kernel/ring_buffer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package kernel

import (
"testing"

"github.com/stretchr/testify/assert"
)

func Test_ringBuffer(t *testing.T) {
assertRead := func(t *testing.T, b *ringBuffer, expected []byte) {
got := make([]byte, len(expected))
n, err := b.read(got)
assert.Nil(t, err)
assert.Equal(t, len(expected), n)
assert.Equal(t, expected, got)
}

t.Run("basic", func(t *testing.T) {
data := []byte("hello")
buf := newRingBuffer(10)
n, err := buf.Write(data)
assert.Nil(t, err)
assert.Equal(t, len(data), n)
assertRead(t, buf, data)
})

t.Run("overwriting", func(t *testing.T) {
data := []byte("hello")
buf := newRingBuffer(5)
n, err := buf.Write(data)
assert.Nil(t, err)
assert.Equal(t, len(data), n)

data = []byte("world")
n, err = buf.Write(data)
assert.Nil(t, err)
assert.Equal(t, len(data), n)
assertRead(t, buf, data)
})

t.Run("wrapping", func(t *testing.T) {
data := []byte("hello")
buf := newRingBuffer(10)
n, err := buf.Write(data)
assert.Nil(t, err)
assert.Equal(t, len(data), n)

data = []byte("123world")
n, err = buf.Write(data)
assert.Nil(t, err)
assert.Equal(t, len(data), n)
assertRead(t, buf, data)
})
}
28 changes: 16 additions & 12 deletions internal/kernel/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import (
)

type limitedBuffer struct {
*bytes.Buffer // TODO: switch to a ring buffer
mu sync.Mutex
state *atomic.Uint32
more chan struct{}
closed chan struct{}
buf *ringBuffer
mu sync.Mutex
state *atomic.Uint32
more chan struct{}
closed chan struct{}
}

func (b *limitedBuffer) Close() error {
Expand All @@ -37,7 +37,7 @@ func (b *limitedBuffer) Close() error {

func (b *limitedBuffer) Read(p []byte) (int, error) {
b.mu.Lock()
n, err := b.Buffer.Read(p)
n, err := b.buf.Read(p)
b.mu.Unlock()
if err != nil && errors.Is(err, io.EOF) && b.state.Load() == 0 {
select {
Expand All @@ -50,15 +50,19 @@ func (b *limitedBuffer) Read(p []byte) (int, error) {
return n, err
}

func (b *limitedBuffer) Write(p []byte) (int, error) {
func (b *limitedBuffer) Write(p []byte) (n int, err error) {
b.mu.Lock()
n, _ := b.Buffer.Write(p)
n, err = b.buf.Write(p)
b.mu.Unlock()
select {
case b.more <- struct{}{}:
default:
}
return n, nil
return
}

func (b *limitedBuffer) Reset() {
b.buf.Reset()
}

type session struct {
Expand All @@ -82,7 +86,7 @@ func newSession(command, prompt string, logger *zap.Logger) (*session, []byte, e
}

buf := &limitedBuffer{
Buffer: bytes.NewBuffer(nil),
buf: newRingBuffer(10 * 1024 * 1024), // 10MB
state: new(atomic.Uint32),
more: make(chan struct{}),
closed: make(chan struct{}),
Expand Down Expand Up @@ -147,11 +151,11 @@ func newSession(command, prompt string, logger *zap.Logger) (*session, []byte, e
}

// Reset buffer as we don't want to send the setting changes back.
_, _ = io.Copy(io.Discard, s.output)
s.output.Reset()

// Write the matched prompt back as invitation.
_, _ = s.output.Write(match)
_, _ = s.output.WriteRune(' ')
_, _ = s.output.Write([]byte{' '})

go func() {
s.setErr(<-cmdErr)
Expand Down

0 comments on commit e6f0579

Please sign in to comment.