Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add message reader and writer #44

Merged
merged 10 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 127 additions & 79 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package connection

import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -40,18 +39,30 @@ const (
StatusUnknown Status = ""
)

// ErrUnpack returns error with possibility to access RawMessage when
// UnpackError returns error with possibility to access RawMessage when
// connection failed to unpack message
type ErrUnpack struct {
type UnpackError struct {
Err error
RawMessage []byte
}

func (e *ErrUnpack) Error() string {
func (e *UnpackError) Error() string {
return e.Err.Error()
}

func (e *ErrUnpack) Unwrap() error {
func (e *UnpackError) Unwrap() error {
return e.Err
}

type PackError struct {
Err error
}

func (e *PackError) Error() string {
return e.Err.Error()
}

func (e *PackError) Unwrap() error {
return e.Err
}

Expand All @@ -62,7 +73,7 @@ type Connection struct {
Opts Options
conn io.ReadWriteCloser
requestsCh chan request
readResponseCh chan []byte
readResponseCh chan *iso8583.Message
done chan struct{}

// spec that will be used to unpack received messages
Expand Down Expand Up @@ -105,7 +116,7 @@ func New(addr string, spec *iso8583.MessageSpec, mlReader MessageLengthReader, m
addr: addr,
Opts: opts,
requestsCh: make(chan request),
readResponseCh: make(chan []byte),
readResponseCh: make(chan *iso8583.Message),
done: make(chan struct{}),
respMap: make(map[string]response),
spec: spec,
Expand Down Expand Up @@ -296,8 +307,8 @@ func (c *Connection) Done() <-chan struct{} {

// request represents request to the ISO 8583 server
type request struct {
// includes length header and message itself
rawMessage []byte
// message to send
message *iso8583.Message

// ID of the request (based on STAN, RRN, etc.)
requestID string
Expand Down Expand Up @@ -330,34 +341,17 @@ func (c *Connection) Send(message *iso8583.Message) (*iso8583.Message, error) {
c.mutex.Unlock()
defer c.wg.Done()

var buf bytes.Buffer
packed, err := message.Pack()
if err != nil {
return nil, fmt.Errorf("packing message: %w", err)
}

// create header
_, err = c.writeMessageLength(&buf, len(packed))
if err != nil {
return nil, fmt.Errorf("writing message header to buffer: %w", err)
}

_, err = buf.Write(packed)
if err != nil {
return nil, fmt.Errorf("writing packed message to buffer: %w", err)
}

// prepare request
reqID, err := c.Opts.RequestIDGenerator.GenerateRequestID(message)
if err != nil {
return nil, fmt.Errorf("creating request ID: %w", err)
}

req := request{
rawMessage: buf.Bytes(),
requestID: reqID,
replyCh: make(chan *iso8583.Message),
errCh: make(chan error),
message: message,
requestID: reqID,
replyCh: make(chan *iso8583.Message),
errCh: make(chan error),
}

var resp *iso8583.Message
Expand Down Expand Up @@ -399,49 +393,58 @@ func (c *Connection) Send(message *iso8583.Message) (*iso8583.Message, error) {
return resp, err
}

// Reply sends the message and does not wait for a reply to be received.
// Any reply received for message send using Reply will be handled with
// unmatchedMessageHandler
func (c *Connection) Reply(message *iso8583.Message) error {
c.mutex.Lock()
if c.closing {
c.mutex.Unlock()
return ErrConnectionClosed
func (c *Connection) writeMessage(w io.Writer, message *iso8583.Message) error {
if c.Opts.MessageWriter != nil {
return c.Opts.MessageWriter.WriteMessage(w, message)
}
// calling wg.Add(1) within mutex guarantees that it does not pass the wg.Wait() call in the Close method
// otherwise we will have data race issue
c.wg.Add(1)
c.mutex.Unlock()
defer c.wg.Done()

// prepare message for sending
var buf bytes.Buffer
// default message writer
packed, err := message.Pack()
if err != nil {
return fmt.Errorf("packing message: %w", err)
return utils.NewSafeError(&PackError{err}, "failed to pack message")
}

// create header
_, err = c.writeMessageLength(&buf, len(packed))
_, err = c.writeMessageLength(w, len(packed))
if err != nil {
return fmt.Errorf("writing message header to buffer: %w", err)
}

_, err = buf.Write(packed)
_, err = w.Write(packed)
if err != nil {
return fmt.Errorf("writing packed message to buffer: %w", err)
}

return nil
}

// Reply sends the message and does not wait for a reply to be received.
// Any reply received for message send using Reply will be handled with
// unmatchedMessageHandler
func (c *Connection) Reply(message *iso8583.Message) error {
c.mutex.Lock()
if c.closing {
c.mutex.Unlock()
return ErrConnectionClosed
}
// calling wg.Add(1) within mutex guarantees that it does not pass the wg.Wait() call in the Close method
// otherwise we will have data race issue
c.wg.Add(1)
c.mutex.Unlock()
defer c.wg.Done()

req := request{
rawMessage: buf.Bytes(),
errCh: make(chan error),
message: message,
errCh: make(chan error),
}

c.requestsCh <- req

sendTimeoutTimer := time.NewTimer(c.Opts.SendTimeout)
defer sendTimeoutTimer.Stop()

var err error

select {
case err = <-req.errCh:
case <-sendTimeoutTimer.C:
Expand Down Expand Up @@ -507,9 +510,25 @@ func (c *Connection) writeLoop() {
c.pendingRequestsMu.Unlock()
}

_, err = c.conn.Write([]byte(req.rawMessage))
err = c.writeMessage(c.conn, req.message)
if err != nil {
c.handleError(utils.NewSafeError(err, "failed to write message into connection"))
c.handleError(fmt.Errorf("writing message: %w", err))

var packErr *PackError
if errors.As(err, &packErr) {
// let caller know that his message was not not sent
// because of pack error. We don't set all type of errors to errCh
// as this case is handled by handleConnectionError(err)
// which sends the same error to all pending requests, including
// this one
req.errCh <- err

err = nil

// we can continue to write other messages
continue
}

break
}

Expand Down Expand Up @@ -539,29 +558,70 @@ func (c *Connection) writeLoop() {
// readLoop reads data from the socket (message length header and raw message)
// and runs a goroutine to handle the message
func (c *Connection) readLoop() {
var err error
var messageLength int
var outErr error

r := bufio.NewReader(c.conn)
for {
messageLength, err = c.readMessageLength(r)
message, err := c.readMessage(r)
if err != nil {
c.handleError(utils.NewSafeError(err, "failed to read message length"))
c.handleError(utils.NewSafeError(err, "failed to read message from connection"))

// if err is ErrUnpack, we can still continue reading
// from the connection
var unpackErr *UnpackError
if errors.As(err, &unpackErr) {
continue
}

outErr = err
break
}

// read the packed message
rawMessage := make([]byte, messageLength)
_, err = io.ReadFull(r, rawMessage)
if err != nil {
c.handleError(utils.NewSafeError(err, "failed to read message from connection"))
break
// if readMessage returns nil message, it means that
// it was a ping message or something else, not a regular
// iso8583 message and we can continue reading
if message == nil {
continue
}

c.readResponseCh <- rawMessage
c.readResponseCh <- message
}

c.handleConnectionError(err)
c.handleConnectionError(outErr)
}

// readMessage reads message length header and raw message from the connection
// and returns iso8583.Message and error if any
func (c *Connection) readMessage(r io.Reader) (*iso8583.Message, error) {
if c.Opts.MessageReader != nil {
return c.Opts.MessageReader.ReadMessage(r)
}

// default message reader
messageLength, err := c.readMessageLength(r)
if err != nil {
return nil, fmt.Errorf("failed to read message length: %w", err)
}

// read the packed message
rawMessage := make([]byte, messageLength)
_, err = io.ReadFull(r, rawMessage)
if err != nil {
return nil, fmt.Errorf("failed to read message from connection: %w", err)
}

// unpack the message
message := iso8583.NewMessage(c.spec)
err = message.Unpack(rawMessage)
if err != nil {
unpackErr := &UnpackError{
Err: err,
RawMessage: rawMessage,
}
return nil, fmt.Errorf("unpacking message: %w", unpackErr)
}

return message, nil
}

func (c *Connection) readResponseLoop() {
Expand All @@ -584,21 +644,9 @@ func (c *Connection) readResponseLoop() {
}
}

// handleResponse unpacks the message and then sends it to the reply channel
// that corresponds to the message ID (request ID)
func (c *Connection) handleResponse(rawMessage []byte) {
// create message
message := iso8583.NewMessage(c.spec)
err := message.Unpack(rawMessage)
if err != nil {
unpackErr := &ErrUnpack{
Err: err,
RawMessage: rawMessage,
}
c.handleError(utils.NewSafeError(unpackErr, "failed to unpack message"))
return
}

// handleResponse sends message to the reply channel that corresponds to the
// message ID (request ID)
func (c *Connection) handleResponse(message *iso8583.Message) {
if isResponse(message) {
reqID, err := c.Opts.RequestIDGenerator.GenerateRequestID(message)
if err != nil {
Expand Down
Loading