diff --git a/connection_windows.go b/connection_windows.go index 70083c0b6..943e60b3c 100644 --- a/connection_windows.go +++ b/connection_windows.go @@ -23,6 +23,7 @@ package gnet import ( "net" + "sync" "github.com/panjf2000/gnet/pool/bytebuffer" prb "github.com/panjf2000/gnet/pool/ringbuffer" @@ -34,8 +35,14 @@ type stderr struct { err error } -type wakeReq struct { - c *stdConn +type signalTask struct { + run func(*stdConn) error + c *stdConn +} + +type dataTask struct { + run func([]byte) (int, error) + buf []byte } type tcpConn struct { @@ -47,6 +54,11 @@ type udpConn struct { c *stdConn } +var ( + signalTaskPool = sync.Pool{New: func() interface{} { return new(signalTask) }} + dataTaskPool = sync.Pool{New: func() interface{} { return new(dataTask) }} +) + type stdConn struct { ctx interface{} // user-defined context conn net.Conn // original connection @@ -134,6 +146,13 @@ func (c *stdConn) read() ([]byte, error) { return c.codec.Decode(c) } +func (c *stdConn) write(data []byte) (n int, err error) { + if c.conn != nil { + n, err = c.conn.Write(data) + } + return +} + // ================================= Public APIs of gnet.Conn ================================= func (c *stdConn) Read() []byte { @@ -212,12 +231,10 @@ func (c *stdConn) BufferLength() int { func (c *stdConn) AsyncWrite(buf []byte) (err error) { var encodedBuf []byte if encodedBuf, err = c.codec.Encode(c, buf); err == nil { - c.loop.ch <- func() (err error) { - if c.conn != nil { - _, err = c.conn.Write(encodedBuf) - } - return - } + task := dataTaskPool.Get().(*dataTask) + task.run = c.write + task.buf = encodedBuf + c.loop.ch <- task } return } @@ -228,14 +245,18 @@ func (c *stdConn) SendTo(buf []byte) (err error) { } func (c *stdConn) Wake() error { - c.loop.ch <- wakeReq{c} + task := signalTaskPool.Get().(*signalTask) + task.run = c.loop.loopWake + task.c = c + c.loop.ch <- task return nil } func (c *stdConn) Close() error { - c.loop.ch <- func() error { - return c.loop.loopCloseConn(c) - } + task := signalTaskPool.Get().(*signalTask) + task.run = c.loop.loopCloseConn + task.c = c + c.loop.ch <- task return nil } diff --git a/eventloop_windows.go b/eventloop_windows.go index 53d118fde..701030c56 100644 --- a/eventloop_windows.go +++ b/eventloop_windows.go @@ -74,8 +74,8 @@ func (el *eventloop) loopRun(lockOSThread bool) { el.svr.loopWG.Done() }() - for v := range el.ch { - switch v := v.(type) { + for i := range el.ch { + switch v := i.(type) { case error: err = v case *stdConn: @@ -87,10 +87,12 @@ func (el *eventloop) loopRun(lockOSThread bool) { err = el.loopReadUDP(v.c) case *stderr: err = el.loopError(v.c, v.err) - case wakeReq: - err = el.loopWake(v.c) - case func() error: - err = v() + case *signalTask: + err = v.run(v.c) + signalTaskPool.Put(i) + case *dataTask: + _, err = v.run(v.buf) + dataTaskPool.Put(i) } if err == errors.ErrServerShutdown { @@ -183,7 +185,7 @@ func (el *eventloop) loopTicker(ctx context.Context) { for { delay, action = el.eventHandler.Tick() if action == Shutdown { - el.ch <- func() error { return errors.ErrServerShutdown } + el.ch <- errors.ErrServerShutdown // logging.Debugf("stopping ticker in event-loop(%d) from Tick()", el.idx) } if timer == nil { diff --git a/server_windows.go b/server_windows.go index 4bf5b5e27..382b518b2 100644 --- a/server_windows.go +++ b/server_windows.go @@ -28,7 +28,7 @@ import ( "sync" "sync/atomic" - errors2 "github.com/panjf2000/gnet/errors" + gerrors "github.com/panjf2000/gnet/errors" "github.com/panjf2000/gnet/internal/logging" ) @@ -124,7 +124,7 @@ func (svr *server) stop(s Server) { // Notify all loops to close. svr.lb.iterate(func(i int, el *eventloop) bool { - el.ch <- errors2.ErrServerShutdown + el.ch <- gerrors.ErrServerShutdown return true })