From 36b74d0b6879e6af0c366cb9624d2227bcbeefd4 Mon Sep 17 00:00:00 2001 From: ljun20160606 Date: Tue, 16 May 2023 10:21:33 +0800 Subject: [PATCH] feat: add protocol option --- cluster.go | 2 ++ cluster_test.go | 38 +++++++++++++++++++++++++++ options.go | 4 +++ options_test.go | 3 +++ redis.go | 7 ++++- redis_test.go | 27 +++++++++++++++++++ ring.go | 4 ++- ring_test.go | 40 +++++++++++++++++++++++++++++ sentinel.go | 3 +++ sentinel_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++ universal.go | 4 +++ 11 files changed, 197 insertions(+), 2 deletions(-) diff --git a/cluster.go b/cluster.go index 12417d1f94..d20295c7d9 100644 --- a/cluster.go +++ b/cluster.go @@ -62,6 +62,7 @@ type ClusterOptions struct { OnConnect func(ctx context.Context, cn *Conn) error + Protocol int Username string Password string @@ -263,6 +264,7 @@ func (opt *ClusterOptions) clientOptions() *Options { Dialer: opt.Dialer, OnConnect: opt.OnConnect, + Protocol: opt.Protocol, Username: opt.Username, Password: opt.Password, diff --git a/cluster_test.go b/cluster_test.go index be24942416..75ea40df11 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -583,6 +583,35 @@ var _ = Describe("ClusterClient", func() { }) } + Describe("ClusterClient PROTO 2", func() { + BeforeEach(func() { + opt = redisClusterOptions() + opt.Protocol = 2 + client = cluster.newClusterClient(ctx, opt) + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + return master.FlushDB(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + _ = client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + return master.FlushDB(ctx).Err() + }) + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should CLUSTER PROTO 2", func() { + _ = client.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + val, err := c.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(ContainElements("proto", int64(2))) + return nil + }) + }) + }) + Describe("ClusterClient", func() { BeforeEach(func() { opt = redisClusterOptions() @@ -746,6 +775,15 @@ var _ = Describe("ClusterClient", func() { }) }) + It("should CLUSTER PROTO 3", func() { + _ = client.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + val, err := c.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(HaveKeyWithValue("proto", int64(3))) + return nil + }) + }) + It("should CLUSTER MYSHARDID", func() { shardID, err := client.ClusterMyShardID(ctx).Result() Expect(err).NotTo(HaveOccurred()) diff --git a/options.go b/options.go index a4af5884fe..bb4816b274 100644 --- a/options.go +++ b/options.go @@ -45,6 +45,9 @@ type Options struct { // Hook that is called when new connection is established. OnConnect func(ctx context.Context, cn *Conn) error + // Protocol 2 or 3. Use the version to negotiate RESP version with redis-server. + // Default is 3. + Protocol int // Use the specified Username to authenticate the current connection // with one of the connections defined in the ACL list when connecting // to a Redis 6.0 instance, or greater, that is using the Redis ACL system. @@ -437,6 +440,7 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) { o.DB = db } + o.Protocol = q.int("protocol") o.ClientName = q.string("client_name") o.MaxRetries = q.int("max_retries") o.MinRetryBackoff = q.duration("min_retry_backoff") diff --git a/options_test.go b/options_test.go index 4ad9175344..fa9ac6c9ee 100644 --- a/options_test.go +++ b/options_test.go @@ -62,6 +62,9 @@ func TestParseURL(t *testing.T) { }, { url: "redis://localhost:123/?db=2&client_name=hi", // client name o: &Options{Addr: "localhost:123", DB: 2, ClientName: "hi"}, + }, { + url: "redis://localhost:123/?db=2&protocol=2", // RESP Protocol + o: &Options{Addr: "localhost:123", DB: 2, Protocol: 2}, }, { url: "unix:///tmp/redis.sock", o: &Options{Addr: "/tmp/redis.sock"}, diff --git a/redis.go b/redis.go index cae12f8c9a..c7fbd0de8a 100644 --- a/redis.go +++ b/redis.go @@ -279,10 +279,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { conn := newConn(c.opt, connPool) var auth bool + protocol := c.opt.Protocol + // By default, use RESP3 in current version. + if protocol < 2 { + protocol = 3 + } // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. - if err := conn.Hello(ctx, 3, username, password, "").Err(); err == nil { + if err := conn.Hello(ctx, protocol, username, password, "").Err(); err == nil { auth = true } else if !isRedisError(err) { // When the server responds with the RESP protocol and the result is not a normal diff --git a/redis_test.go b/redis_test.go index 92b24c7298..c9d8df6bf0 100644 --- a/redis_test.go +++ b/redis_test.go @@ -185,6 +185,33 @@ var _ = Describe("Client", func() { Expect(val).Should(ContainSubstring("name=hi")) }) + It("should client PROTO 2", func() { + opt := redisOptions() + opt.Protocol = 2 + db := redis.NewClient(opt) + + defer func() { + Expect(db.Close()).NotTo(HaveOccurred()) + }() + + val, err := db.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(ContainElements("proto", int64(2))) + }) + + It("should client PROTO 3", func() { + opt := redisOptions() + db := redis.NewClient(opt) + + defer func() { + Expect(db.Close()).NotTo(HaveOccurred()) + }() + + val, err := db.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(HaveKeyWithValue("proto", int64(3))) + }) + It("processes custom commands", func() { cmd := redis.NewCmd(ctx, "PING") _ = client.Process(ctx, cmd) diff --git a/ring.go b/ring.go index f924ac0ad0..0572ba3460 100644 --- a/ring.go +++ b/ring.go @@ -12,7 +12,7 @@ import ( "time" "github.com/cespare/xxhash/v2" - rendezvous "github.com/dgryski/go-rendezvous" //nolint + "github.com/dgryski/go-rendezvous" //nolint "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hashtag" @@ -70,6 +70,7 @@ type RingOptions struct { Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(ctx context.Context, cn *Conn) error + Protocol int Username string Password string DB int @@ -136,6 +137,7 @@ func (opt *RingOptions) clientOptions() *Options { Dialer: opt.Dialer, OnConnect: opt.OnConnect, + Protocol: opt.Protocol, Username: opt.Username, Password: opt.Password, DB: opt.DB, diff --git a/ring_test.go b/ring_test.go index 73bb2fc49d..b349059669 100644 --- a/ring_test.go +++ b/ring_test.go @@ -15,6 +15,37 @@ import ( "github.com/redis/go-redis/v9" ) +var _ = Describe("Redis Ring PROTO 2", func() { + const heartbeat = 100 * time.Millisecond + + var ring *redis.Ring + + BeforeEach(func() { + opt := redisRingOptions() + opt.Protocol = 2 + opt.HeartbeatFrequency = heartbeat + ring = redis.NewRing(opt) + + err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error { + return cl.FlushDB(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(ring.Close()).NotTo(HaveOccurred()) + }) + + It("should PROTO 2", func() { + _ = ring.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + val, err := c.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(ContainElements("proto", int64(2))) + return nil + }) + }) +}) + var _ = Describe("Redis Ring", func() { const heartbeat = 100 * time.Millisecond @@ -65,6 +96,15 @@ var _ = Describe("Redis Ring", func() { }) }) + It("should ring PROTO 3", func() { + _ = ring.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + val, err := c.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(HaveKeyWithValue("proto", int64(3))) + return nil + }) + }) + It("distributes keys", func() { setRingKeys() diff --git a/sentinel.go b/sentinel.go index 5ea41f17b8..dbff406039 100644 --- a/sentinel.go +++ b/sentinel.go @@ -54,6 +54,7 @@ type FailoverOptions struct { Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(ctx context.Context, cn *Conn) error + Protocol int Username string Password string DB int @@ -88,6 +89,7 @@ func (opt *FailoverOptions) clientOptions() *Options { OnConnect: opt.OnConnect, DB: opt.DB, + Protocol: opt.Protocol, Username: opt.Username, Password: opt.Password, @@ -151,6 +153,7 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { Dialer: opt.Dialer, OnConnect: opt.OnConnect, + Protocol: opt.Protocol, Username: opt.Username, Password: opt.Password, diff --git a/sentinel_test.go b/sentinel_test.go index fc9dbcbc5a..705b6a0946 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -10,6 +10,30 @@ import ( "github.com/redis/go-redis/v9" ) +var _ = Describe("Sentinel PROTO 2", func() { + var client *redis.Client + + BeforeEach(func() { + client = redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: sentinelName, + SentinelAddrs: sentinelAddrs, + MaxRetries: -1, + Protocol: 2, + }) + Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + _ = client.Close() + }) + + It("should sentinel client PROTO 2", func() { + val, err := client.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(ContainElements("proto", int64(2))) + }) +}) + var _ = Describe("Sentinel", func() { var client *redis.Client var master *redis.Client @@ -134,6 +158,40 @@ var _ = Describe("Sentinel", func() { Expect(err).NotTo(HaveOccurred()) Expect(val).Should(ContainSubstring("name=sentinel_hi")) }) + + It("should sentinel client PROTO 3", func() { + val, err := client.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(HaveKeyWithValue("proto", int64(3))) + }) +}) + +var _ = Describe("NewFailoverClusterClient PROTO 2", func() { + var client *redis.ClusterClient + + BeforeEach(func() { + client = redis.NewFailoverClusterClient(&redis.FailoverOptions{ + MasterName: sentinelName, + SentinelAddrs: sentinelAddrs, + Protocol: 2, + + RouteRandomly: true, + }) + Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + _ = client.Close() + }) + + It("should sentinel cluster PROTO 2", func() { + _ = client.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + val, err := client.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(ContainElements("proto", int64(2))) + return nil + }) + }) }) var _ = Describe("NewFailoverClusterClient", func() { @@ -237,6 +295,15 @@ var _ = Describe("NewFailoverClusterClient", func() { return nil }) }) + + It("should sentinel cluster PROTO 3", func() { + _ = client.ForEachShard(ctx, func(ctx context.Context, c *redis.Client) error { + val, err := client.Do(ctx, "HELLO").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).Should(HaveKeyWithValue("proto", int64(3))) + return nil + }) + }) }) var _ = Describe("SentinelAclAuth", func() { diff --git a/universal.go b/universal.go index 9d1a8520a6..53ece18565 100644 --- a/universal.go +++ b/universal.go @@ -26,6 +26,7 @@ type UniversalOptions struct { Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(ctx context.Context, cn *Conn) error + Protocol int Username string Password string SentinelUsername string @@ -77,6 +78,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions { Dialer: o.Dialer, OnConnect: o.OnConnect, + Protocol: o.Protocol, Username: o.Username, Password: o.Password, @@ -122,6 +124,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions { OnConnect: o.OnConnect, DB: o.DB, + Protocol: o.Protocol, Username: o.Username, Password: o.Password, SentinelUsername: o.SentinelUsername, @@ -162,6 +165,7 @@ func (o *UniversalOptions) Simple() *Options { OnConnect: o.OnConnect, DB: o.DB, + Protocol: o.Protocol, Username: o.Username, Password: o.Password,