diff --git a/shadowsocks/client.go b/shadowsocks/client.go index 1f291535..d06b937d 100644 --- a/shadowsocks/client.go +++ b/shadowsocks/client.go @@ -4,6 +4,7 @@ import ( "errors" "io" "net" + "time" onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/shadowsocks/go-shadowsocks2/core" @@ -46,6 +47,18 @@ type ssClient struct { cipher shadowaead.Cipher } +// This code contains an optimization to send the initial client payload along with +// the Shadowsocks handshake. This saves one packet during connection, and also +// reduces the distinctiveness of the connection pattern. +// +// Normally, the initial payload will be sent as soon as the socket is connected, +// except for delays due to inter-process communication. However, some protocols +// expect the server to send data first, in which case there is no client payload. +// We therefore use a short delay, longer than any reasonable IPC but shorter than +// typical network latency. (In an Android emulator, the 90th percentile delay +// was ~1 ms.) If no client payload is received by this time, we connect without it. +const helloWait = 10 * time.Millisecond + func (c *ssClient) DialTCP(laddr *net.TCPAddr, raddr string) (onet.DuplexConn, error) { socksTargetAddr := socks.ParseAddr(raddr) if socksTargetAddr == nil { @@ -57,11 +70,14 @@ func (c *ssClient) DialTCP(laddr *net.TCPAddr, raddr string) (onet.DuplexConn, e return nil, err } ssw := NewShadowsocksWriter(proxyConn, c.cipher) - _, err = ssw.Write(socksTargetAddr) + _, err = ssw.LazyWrite(socksTargetAddr) if err != nil { proxyConn.Close() return nil, errors.New("Failed to write target address") } + time.AfterFunc(helloWait, func() { + ssw.Flush() + }) ssr := NewShadowsocksReader(proxyConn, c.cipher) return onet.WrapConn(proxyConn, ssr, ssw), nil } diff --git a/shadowsocks/client_test.go b/shadowsocks/client_test.go index 2e316cff..b592a928 100644 --- a/shadowsocks/client_test.go +++ b/shadowsocks/client_test.go @@ -38,6 +38,73 @@ func TestShadowsocksClient_DialTCP(t *testing.T) { expectEchoPayload(conn, MakeTestPayload(1024), make([]byte, 1024), t) } +func TestShadowsocksClient_DialTCPNoPayload(t *testing.T) { + proxyAddr := startShadowsocksTCPEchoProxy(testTargetAddr, t) + proxyHost, proxyPort, err := splitHostPortNumber(proxyAddr.String()) + if err != nil { + t.Fatalf("Failed to parse proxy address: %v", err) + } + d, err := NewClient(proxyHost, proxyPort, testPassword, testCipher) + if err != nil { + t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + conn, err := d.DialTCP(nil, testTargetAddr) + if err != nil { + t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err) + } + + // Wait for more than 10 milliseconds to ensure that the target + // address is sent. + time.Sleep(20 * time.Millisecond) + // Force the echo server to verify the target address. + conn.Close() +} + +func TestShadowsocksClient_DialTCPFastClose(t *testing.T) { + // Set up a listener that verifies no data is sent. + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + + done := make(chan struct{}) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error(err) + } + buf := make([]byte, 64) + n, err := conn.Read(buf) + if n > 0 || err != io.EOF { + t.Errorf("Expected EOF, got %v, %v", buf[:n], err) + } + listener.Close() + close(done) + }() + + proxyHost, proxyPort, err := splitHostPortNumber(listener.Addr().String()) + if err != nil { + t.Fatalf("Failed to parse proxy address: %v", err) + } + d, err := NewClient(proxyHost, proxyPort, testPassword, testCipher) + if err != nil { + t.Fatalf("Failed to create ShadowsocksClient: %v", err) + } + + conn, err := d.DialTCP(nil, testTargetAddr) + if err != nil { + t.Fatalf("ShadowsocksClient.DialTCP failed: %v", err) + } + + // Wait for less than 10 milliseconds to ensure that the target + // address is not sent. + time.Sleep(1 * time.Millisecond) + // Close the connection before the target address is sent. + conn.Close() + // Wait for the listener to verify the close. + <-done +} + func TestShadowsocksClient_ListenUDP(t *testing.T) { proxyAddr := startShadowsocksUDPEchoServer(testTargetAddr, t) proxyHost, proxyPort, err := splitHostPortNumber(proxyAddr.String()) diff --git a/shadowsocks/stream.go b/shadowsocks/stream.go index 5dd30fb4..9c6b4d0e 100644 --- a/shadowsocks/stream.go +++ b/shadowsocks/stream.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "fmt" "io" + "sync" "github.com/shadowsocks/go-shadowsocks2/shadowaead" ) @@ -30,17 +31,35 @@ const payloadSizeMask = 0x3FFF // 16*1024 - 1 // Writer is an io.Writer that also implements io.ReaderFrom to // allow for piping the data without extra allocations and copies. +// The LazyWrite and Flush methods allow a header to be +// added but delayed until the first write, for concatenation. +// All methods except Flush must be called from a single thread. type Writer interface { io.Writer io.ReaderFrom + // LazyWrite queues p to be written, but doesn't send it until + // Flush() is called, a non-lazy write is made, or the buffer + // is filled. + LazyWrite(p []byte) (int, error) + // Flush sends the pending data, if any. This method is + // thread-safe. + Flush() error } type shadowsocksWriter struct { - writer io.Writer - ssCipher shadowaead.Cipher + // This type is single-threaded except when needFlush is true. + // mu protects needFlush, and also protects everything + // else while needFlush could be true. + mu sync.Mutex + // Indicates that a concurrent flush is currently allowed. + needFlush bool + writer io.Writer + ssCipher shadowaead.Cipher // Wrapper for input that arrives as a slice. byteWrapper bytes.Reader - // These are lazily initialized: + // Number of plaintext bytes that are currently buffered. + pending int + // These are populated by init(): buf []byte aead cipher.AEAD // Index of the next encrypted chunk to write. @@ -91,6 +110,42 @@ func (sw *shadowsocksWriter) Write(p []byte) (int, error) { return int(n), err } +func (sw *shadowsocksWriter) LazyWrite(p []byte) (int, error) { + if err := sw.init(); err != nil { + return 0, err + } + + // Locking is needed due to potential concurrency with the Flush() + // for a previous call to LazyWrite(). + sw.mu.Lock() + defer sw.mu.Unlock() + + queued := 0 + for { + n := sw.enqueue(p) + queued += n + p = p[n:] + if len(p) == 0 { + sw.needFlush = true + return queued, nil + } + // p didn't fit in the buffer. Flush the buffer and try + // again. + if err := sw.flush(); err != nil { + return queued, err + } + } +} + +func (sw *shadowsocksWriter) Flush() error { + sw.mu.Lock() + defer sw.mu.Unlock() + if !sw.needFlush { + return nil + } + return sw.flush() +} + func isZero(b []byte) bool { for _, v := range b { if v != 0 { @@ -100,12 +155,81 @@ func isZero(b []byte) bool { return true } +// Returns the slices of sw.buf in which to place plaintext for encryption. +func (sw *shadowsocksWriter) buffers() (sizeBuf, payloadBuf []byte) { + // sw.buf starts with the salt. + saltSize := sw.ssCipher.SaltSize() + + // Each Shadowsocks-TCP message consists of a fixed-length size block, + // followed by a variable-length payload block. + sizeBuf = sw.buf[saltSize : saltSize+2] + payloadStart := saltSize + 2 + sw.aead.Overhead() + payloadBuf = sw.buf[payloadStart : payloadStart+payloadSizeMask] + return +} + func (sw *shadowsocksWriter) ReadFrom(r io.Reader) (int64, error) { if err := sw.init(); err != nil { return 0, err } var written int64 + var err error + _, payloadBuf := sw.buffers() + // Special case: one thread-safe read, if necessary + sw.mu.Lock() + if sw.needFlush { + pending := sw.pending + + sw.mu.Unlock() + saltsize := sw.ssCipher.SaltSize() + overhead := sw.aead.Overhead() + // The first pending+overhead bytes of payloadBuf are potentially + // in use, and may be modified on the flush thread. Data after + // that is safe to use on this thread. + readBuf := sw.buf[saltsize+2+overhead+pending+overhead:] + var plaintextSize int + plaintextSize, err = r.Read(readBuf) + written = int64(plaintextSize) + sw.mu.Lock() + + sw.enqueue(readBuf[:plaintextSize]) + if flushErr := sw.flush(); flushErr != nil { + err = flushErr + } + sw.needFlush = false + } + sw.mu.Unlock() + + // Main transfer loop + for err == nil { + sw.pending, err = r.Read(payloadBuf) + written += int64(sw.pending) + if flushErr := sw.flush(); flushErr != nil { + err = flushErr + } + } + + if err == io.EOF { // ignore EOF as per io.ReaderFrom contract + return written, nil + } + return written, fmt.Errorf("Failed to read payload: %v", err) +} + +// Adds as much of `plaintext` into the buffer as will fit, and increases +// sw.pending accordingly. Returns the number of bytes consumed. +func (sw *shadowsocksWriter) enqueue(plaintext []byte) int { + _, payloadBuf := sw.buffers() + n := copy(payloadBuf[sw.pending:], plaintext) + sw.pending += n + return n +} + +// Encrypts all pending data and writes it to the output. +func (sw *shadowsocksWriter) flush() error { + if sw.pending == 0 { + return nil + } // sw.buf starts with the salt. saltSize := sw.ssCipher.SaltSize() // Normally we ignore the salt at the beginning of sw.buf. @@ -117,27 +241,13 @@ func (sw *shadowsocksWriter) ReadFrom(r io.Reader) (int64, error) { start = 0 } - // Each Shadowsocks-TCP message consists of a fixed-length size block, followed by - // a variable-length payload block. - sizeBuf := sw.buf[saltSize : saltSize+2+sw.aead.Overhead()] - payloadBuf := sw.buf[saltSize+len(sizeBuf):] - for { - plaintextSize, err := r.Read(payloadBuf[:payloadSizeMask]) - if plaintextSize > 0 { - binary.BigEndian.PutUint16(sizeBuf, uint16(plaintextSize)) - sw.encryptBlock(sizeBuf[:2]) - payloadSize := sw.encryptBlock(payloadBuf[:plaintextSize]) - _, err = sw.writer.Write(sw.buf[start : saltSize+len(sizeBuf)+payloadSize]) - written += int64(plaintextSize) - start = saltSize // Skip the salt for all writes except the first. - } - if err != nil { - if err == io.EOF { // ignore EOF as per io.ReaderFrom contract - return written, nil - } - return written, fmt.Errorf("Failed to read payload: %v", err) - } - } + sizeBuf, payloadBuf := sw.buffers() + binary.BigEndian.PutUint16(sizeBuf, uint16(sw.pending)) + sizeBlockSize := sw.encryptBlock(sizeBuf) + payloadSize := sw.encryptBlock(payloadBuf[:sw.pending]) + _, err := sw.writer.Write(sw.buf[start : saltSize+sizeBlockSize+payloadSize]) + sw.pending = 0 + return err } // ChunkReader is similar to io.Reader, except that it controls its own diff --git a/shadowsocks/stream_test.go b/shadowsocks/stream_test.go index 80bbe962..4d1bcafe 100644 --- a/shadowsocks/stream_test.go +++ b/shadowsocks/stream_test.go @@ -4,8 +4,11 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "strings" + "sync" "testing" + "time" "github.com/shadowsocks/go-shadowsocks2/shadowaead" "golang.org/x/crypto/chacha20poly1305" @@ -173,3 +176,240 @@ func TestEndToEnd(t *testing.T) { t.Fatalf("Expected output '%v'. Got '%v'", expected, output.String()) } } + +func TestLazyWriteFlush(t *testing.T) { + cipher := newTestCipher(t) + buf := new(bytes.Buffer) + writer := NewShadowsocksWriter(buf, cipher) + header := []byte{1, 2, 3, 4} + n, err := writer.LazyWrite(header) + if n != len(header) { + t.Errorf("Wrong write size: %d", n) + } + if err != nil { + t.Errorf("LazyWrite failed: %v", err) + } + if buf.Len() != 0 { + t.Errorf("LazyWrite isn't lazy: %v", buf.Bytes()) + } + if err = writer.Flush(); err != nil { + t.Errorf("Flush failed: %v", err) + } + len1 := buf.Len() + if len1 <= len(header) { + t.Errorf("Not enough bytes flushed: %d", len1) + } + + // Check that normal writes now work + body := []byte{5, 6, 7} + n, err = writer.Write(body) + if n != len(body) { + t.Errorf("Wrong write size: %d", n) + } + if err != nil { + t.Errorf("Write failed: %v", err) + } + if buf.Len() == len1 { + t.Errorf("No write observed") + } + + // Verify content arrives in two blocks + reader := NewShadowsocksReader(buf, cipher) + decrypted := make([]byte, len(header)+len(body)) + n, err = reader.Read(decrypted) + if n != len(header) { + t.Errorf("Wrong number of bytes out: %d", n) + } + if err != nil { + t.Errorf("Read failed: %v", err) + } + if !bytes.Equal(decrypted[:n], header) { + t.Errorf("Wrong final content: %v", decrypted) + } + n, err = reader.Read(decrypted[n:]) + if n != len(body) { + t.Errorf("Wrong number of bytes out: %d", n) + } + if err != nil { + t.Errorf("Read failed: %v", err) + } + if !bytes.Equal(decrypted[len(header):], body) { + t.Errorf("Wrong final content: %v", decrypted) + } +} + +func TestLazyWriteConcat(t *testing.T) { + cipher := newTestCipher(t) + buf := new(bytes.Buffer) + writer := NewShadowsocksWriter(buf, cipher) + header := []byte{1, 2, 3, 4} + n, err := writer.LazyWrite(header) + if n != len(header) { + t.Errorf("Wrong write size: %d", n) + } + if err != nil { + t.Errorf("LazyWrite failed: %v", err) + } + if buf.Len() != 0 { + t.Errorf("LazyWrite isn't lazy: %v", buf.Bytes()) + } + + // Write additional data and flush the header. + body := []byte{5, 6, 7} + n, err = writer.Write(body) + if n != len(body) { + t.Errorf("Wrong write size: %d", n) + } + if err != nil { + t.Errorf("Write failed: %v", err) + } + len1 := buf.Len() + if len1 <= len(body)+len(header) { + t.Errorf("Not enough bytes flushed: %d", len1) + } + + // Flush after write should have no effect + if err = writer.Flush(); err != nil { + t.Errorf("Flush failed: %v", err) + } + if buf.Len() != len1 { + t.Errorf("Flush should have no effect") + } + + // Verify content arrives in one block + reader := NewShadowsocksReader(buf, cipher) + decrypted := make([]byte, len(body)+len(header)) + n, err = reader.Read(decrypted) + if n != len(decrypted) { + t.Errorf("Wrong number of bytes out: %d", n) + } + if err != nil { + t.Errorf("Read failed: %v", err) + } + if !bytes.Equal(decrypted[:len(header)], header) || + !bytes.Equal(decrypted[len(header):], body) { + t.Errorf("Wrong final content: %v", decrypted) + } +} + +func TestLazyWriteOversize(t *testing.T) { + cipher := newTestCipher(t) + buf := new(bytes.Buffer) + writer := NewShadowsocksWriter(buf, cipher) + N := 25000 // More than one block, less than two. + data := make([]byte, N) + for i := range data { + data[i] = byte(i) + } + n, err := writer.LazyWrite(data) + if n != len(data) { + t.Errorf("Wrong write size: %d", n) + } + if err != nil { + t.Errorf("LazyWrite failed: %v", err) + } + if buf.Len() >= N { + t.Errorf("Too much data in first block: %d", buf.Len()) + } + if err = writer.Flush(); err != nil { + t.Errorf("Flush failed: %v", err) + } + if buf.Len() <= N { + t.Errorf("Not enough data written after flush: %d", buf.Len()) + } + + // Verify content + reader := NewShadowsocksReader(buf, cipher) + decrypted, err := ioutil.ReadAll(reader) + if len(decrypted) != N { + t.Errorf("Wrong number of bytes out: %d", len(decrypted)) + } + if err != nil { + t.Errorf("Read failed: %v", err) + } + if !bytes.Equal(decrypted, data) { + t.Errorf("Wrong final content: %v", decrypted) + } +} + +func TestLazyWriteConcurrentFlush(t *testing.T) { + cipher := newTestCipher(t) + buf := new(bytes.Buffer) + writer := NewShadowsocksWriter(buf, cipher) + header := []byte{1, 2, 3, 4} + n, err := writer.LazyWrite(header) + if n != len(header) { + t.Errorf("Wrong write size: %d", n) + } + if err != nil { + t.Errorf("LazyWrite failed: %v", err) + } + if buf.Len() != 0 { + t.Errorf("LazyWrite isn't lazy: %v", buf.Bytes()) + } + + body := []byte{5, 6, 7} + r, w := io.Pipe() + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + n, err := writer.ReadFrom(r) + if n != int64(len(body)) { + t.Errorf("ReadFrom: Wrong read size %d", n) + } + if err != nil { + t.Errorf("ReadFrom: %v", err) + } + wg.Done() + }() + + // Wait for ReadFrom to start and get blocked. + time.Sleep(20 * time.Millisecond) + + // Flush while ReadFrom is blocked. + if err := writer.Flush(); err != nil { + t.Errorf("Flush error: %v", err) + } + len1 := buf.Len() + if len1 == 0 { + t.Errorf("No bytes flushed") + } + + // Check that normal writes now work + n, err = w.Write(body) + if n != len(body) { + t.Errorf("Wrong write size: %d", n) + } + if err != nil { + t.Errorf("Write failed: %v", err) + } + w.Close() + wg.Wait() + if buf.Len() == len1 { + t.Errorf("No write observed") + } + + // Verify content arrives in two blocks + reader := NewShadowsocksReader(buf, cipher) + decrypted := make([]byte, len(header)+len(body)) + n, err = reader.Read(decrypted) + if n != len(header) { + t.Errorf("Wrong number of bytes out: %d", n) + } + if err != nil { + t.Errorf("Read failed: %v", err) + } + if !bytes.Equal(decrypted[:len(header)], header) { + t.Errorf("Wrong final content: %v", decrypted) + } + n, err = reader.Read(decrypted[len(header):]) + if n != len(body) { + t.Errorf("Wrong number of bytes out: %d", n) + } + if err != nil { + t.Errorf("Read failed: %v", err) + } + if !bytes.Equal(decrypted[len(header):], body) { + t.Errorf("Wrong final content: %v", decrypted) + } +}