diff --git a/rpc/stdio.go b/rpc/stdio.go index 8f6b7bd4bf6b..d5dc066c99bb 100644 --- a/rpc/stdio.go +++ b/rpc/stdio.go @@ -19,6 +19,7 @@ package rpc import ( "context" "errors" + "io" "net" "os" "time" @@ -26,19 +27,30 @@ import ( // DialStdIO creates a client on stdin/stdout. func DialStdIO(ctx context.Context) (*Client, error) { + return DialIO(ctx, os.Stdin, os.Stdout) +} + +// DialIO creates a client which uses the given IO channels +func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) { return newClient(ctx, func(_ context.Context) (ServerCodec, error) { - return NewJSONCodec(stdioConn{}), nil + return NewJSONCodec(stdioConn{ + in: in, + out: out, + }), nil }) } -type stdioConn struct{} +type stdioConn struct { + in io.Reader + out io.Writer +} func (io stdioConn) Read(b []byte) (n int, err error) { - return os.Stdin.Read(b) + return io.in.Read(b) } func (io stdioConn) Write(b []byte) (n int, err error) { - return os.Stdout.Write(b) + return io.out.Write(b) } func (io stdioConn) Close() error {