diff --git a/cmd/server/cmd.go b/cmd/server/cmd.go index 3b042dbd..aab02b2d 100644 --- a/cmd/server/cmd.go +++ b/cmd/server/cmd.go @@ -30,8 +30,9 @@ import ( var ( conf = server.Config{} - peerTLS = security.TLSOption{} - serverTLS = security.TLSOption{} + peerTLS = security.TLSOption{} + serverTLS = security.TLSOption{} + internalServerTLS = security.TLSOption{} Cmd = &cobra.Command{ Use: "server", @@ -61,6 +62,15 @@ func init() { Cmd.Flags().BoolVar(&serverTLS.InsecureSkipVerify, "tls-insecure-skip-verify", false, "Tls insecure skip verify") Cmd.Flags().BoolVar(&serverTLS.ClientAuth, "tls-client-auth", false, "Tls client auth") + // internal server TLS section + Cmd.Flags().StringVar(&internalServerTLS.CertFile, "internal-tls-cert-file", "", "Internal server tls certificate file") + Cmd.Flags().StringVar(&internalServerTLS.KeyFile, "internal-tls-key-file", "", "Internal server tls key file") + Cmd.Flags().Uint16Var(&internalServerTLS.MinVersion, "internal-tls-min-version", 0, "Internal server tls minimum version") + Cmd.Flags().Uint16Var(&internalServerTLS.MaxVersion, "internal-tls-max-version", 0, "Internal server tls maximum version") + Cmd.Flags().StringVar(&internalServerTLS.TrustedCaFile, "internal-tls-trusted-ca-file", "", "Internal server tls trusted ca file") + Cmd.Flags().BoolVar(&internalServerTLS.InsecureSkipVerify, "internal-tls-insecure-skip-verify", false, "Internal server tls insecure skip verify") + Cmd.Flags().BoolVar(&internalServerTLS.ClientAuth, "internal-tls-client-auth", false, "Internal server tls client auth") + // peer client TLS section Cmd.Flags().StringVar(&peerTLS.CertFile, "peer-tls-cert-file", "", "Peer tls certificate file") Cmd.Flags().StringVar(&peerTLS.KeyFile, "peer-tls-key-file", "", "Peer tls key file") @@ -73,17 +83,29 @@ func init() { func exec(*cobra.Command, []string) { common.RunProcess(func() (io.Closer, error) { - var err error - if serverTLS.IsConfigured() { - if conf.ServerTLS, err = serverTLS.MakeServerTLSConf(); err != nil { - return nil, err - } - } - if peerTLS.IsConfigured() { - if conf.PeerTLS, err = peerTLS.MakeClientTLSConf(); err != nil { - return nil, err - } + if err := configureTLS(); err != nil { + return nil, err } return server.New(conf) }) } + +func configureTLS() error { + var err error + if serverTLS.IsConfigured() { + if conf.ServerTLS, err = serverTLS.MakeServerTLSConf(); err != nil { + return err + } + } + if peerTLS.IsConfigured() { + if conf.PeerTLS, err = peerTLS.MakeClientTLSConf(); err != nil { + return err + } + } + if internalServerTLS.IsConfigured() { + if conf.InternalServerTLS, err = internalServerTLS.MakeServerTLSConf(); err != nil { + return err + } + } + return nil +} diff --git a/server/server.go b/server/server.go index d6b5af9c..eb34a499 100644 --- a/server/server.go +++ b/server/server.go @@ -33,6 +33,7 @@ type Config struct { InternalServiceAddr string PeerTLS *tls.Config ServerTLS *tls.Config + InternalServerTLS *tls.Config MetricsServiceAddr string DataDir string WalDir string @@ -92,7 +93,7 @@ func NewWithGrpcProvider(config Config, provider container.GrpcProvider, replica s.shardAssignmentDispatcher = NewShardAssignmentDispatcher(s.healthServer) s.internalRpcServer, err = newInternalRpcServer(provider, config.InternalServiceAddr, - s.shardsDirector, s.shardAssignmentDispatcher, s.healthServer, config.ServerTLS) + s.shardsDirector, s.shardAssignmentDispatcher, s.healthServer, config.InternalServerTLS) if err != nil { return nil, err } diff --git a/tests/security/tls/tls_encryption_test.go b/tests/security/tls/tls_encryption_test.go index 4580628f..e2a43d37 100644 --- a/tests/security/tls/tls_encryption_test.go +++ b/tests/security/tls/tls_encryption_test.go @@ -68,15 +68,23 @@ func getClientTLSOption() (*security.TLSOption, error) { } func newTLSServer(t *testing.T) (s *server.Server, addr model.ServerAddress) { + t.Helper() + return newTLSServerWithInterceptor(t, func(config *server.Config) { + + }) +} + +func newTLSServerWithInterceptor(t *testing.T, interceptor func(config *server.Config)) (s *server.Server, addr model.ServerAddress) { t.Helper() option, err := getPeerTLSOption() assert.NoError(t, err) serverTLSConf, err := option.MakeServerTLSConf() assert.NoError(t, err) + peerTLSConf, err := option.MakeClientTLSConf() assert.NoError(t, err) - s, err = server.New(server.Config{ + config := server.Config{ PublicServiceAddr: "localhost:0", InternalServiceAddr: "localhost:0", MetricsServiceAddr: "", // Disable metrics to avoid conflict @@ -85,7 +93,12 @@ func newTLSServer(t *testing.T) (s *server.Server, addr model.ServerAddress) { NotificationsRetentionTime: 1 * time.Minute, PeerTLS: peerTLSConf, ServerTLS: serverTLSConf, - }) + InternalServerTLS: serverTLSConf, + } + + interceptor(&config) + + s, err = server.New(config) assert.NoError(t, err) @@ -279,3 +292,47 @@ func TestClientHandshakeSuccess(t *testing.T) { assert.NoError(t, err) client.Close() } + +func TestOnlyEnablePublicTls(t *testing.T) { + disableInternalTLS := func(config *server.Config) { + config.InternalServerTLS = nil + config.PeerTLS = nil + } + s1, sa1 := newTLSServerWithInterceptor(t, disableInternalTLS) + defer s1.Close() + s2, sa2 := newTLSServerWithInterceptor(t, disableInternalTLS) + defer s2.Close() + s3, sa3 := newTLSServerWithInterceptor(t, disableInternalTLS) + defer s3.Close() + + metadataProvider := impl.NewMetadataProviderMemory() + clusterConfig := model.ClusterConfig{ + Namespaces: []model.NamespaceConfig{{ + Name: common.DefaultNamespace, + ReplicationFactor: 3, + InitialShardCount: 1, + }}, + Servers: []model.ServerAddress{sa1, sa2, sa3}, + } + clientPool := common.NewClientPool(nil) + defer clientPool.Close() + + coordinator, err := impl.NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, impl.NewRpcProvider(clientPool)) + assert.NoError(t, err) + defer coordinator.Close() + + // failed without cert + + client, err := oxia.NewSyncClient(sa1.Public, oxia.WithRequestTimeout(1*time.Second)) + assert.Error(t, err) + assert.Nil(t, client) + + // success with cert + tlsOption, err := getClientTLSOption() + assert.NoError(t, err) + tlsConf, err := tlsOption.MakeClientTLSConf() + assert.NoError(t, err) + client, err = oxia.NewSyncClient(sa1.Public, oxia.WithTLS(tlsConf)) + assert.NoError(t, err) + client.Close() +}