Skip to content

Commit

Permalink
Fix code to correctly serve and connect to TLS / plain endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
ineiti committed Sep 26, 2023
1 parent 03719fb commit f164e5b
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 27 deletions.
17 changes: 11 additions & 6 deletions mino/minogrpc/controller/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,20 @@ func (a tokenAction) Execute(req node.Context) error {

token := m.GenerateToken(exp)

chain := m.GetCertificateChain()
var certHash string
if m.ServeTLS() {
chain := m.GetCertificateChain()

digest, err := m.GetCertificateStore().Hash(chain)
if err != nil {
return xerrors.Errorf("couldn't hash certificate: %v", err)
digest, err := m.GetCertificateStore().Hash(chain)
if err != nil {
return xerrors.Errorf("couldn't hash certificate: %v", err)
}

certHash = fmt.Sprintf(" --cert-hash %s", base64.StdEncoding.EncodeToString(digest))
}

fmt.Fprintf(req.Out, "--token %s --cert-hash %s\n",
token, base64.StdEncoding.EncodeToString(digest))
fmt.Fprintf(req.Out, "--token %s%s\n",
token, certHash)

return nil
}
Expand Down
4 changes: 4 additions & 0 deletions mino/minogrpc/controller/actions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ type fakeJoinable struct {
err error
}

func (j fakeJoinable) ServeTLS() bool {
return true
}

func (j fakeJoinable) GetCertificateChain() certs.CertChain {
cert, _ := j.certs.Load(fake.NewAddress(0))

Expand Down
2 changes: 1 addition & 1 deletion mino/minogrpc/controller/mod.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (m miniController) SetCommands(builder node.Builder) {
cli.StringFlag{
Name: "cert-hash",
Usage: "certificate hash of the distant server",
Required: true,
Required: false,
},
)
sub.SetAction(builder.MakeAction(joinAction{}))
Expand Down
3 changes: 3 additions & 0 deletions mino/minogrpc/mod.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ var listener = net.Listen
type Joinable interface {
mino.Mino

// ServeTLS returns true if this node is running with TLS for gRPC.
ServeTLS() bool

// GetCertificateChain returns the certificate chain of the instance.
GetCertificateChain() certs.CertChain

Expand Down
73 changes: 53 additions & 20 deletions mino/minogrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ type overlay struct {
router router.Router
connMgr session.ConnectionManager
addrFactory mino.AddressFactory
serveTLS bool

// secret and public are the key pair that has generated the server
// certificate.
Expand All @@ -443,7 +444,6 @@ type overlay struct {
// Keep a text marshalled value for the overlay address so that it's not
// calculated for each request.
myAddrStr string
useTLS bool
}

func newOverlay(tmpl *minoTemplate) (*overlay, error) {
Expand Down Expand Up @@ -509,6 +509,11 @@ func newOverlay(tmpl *minoTemplate) (*overlay, error) {
return o, nil
}

// ServeTLS returns true if the gRPC server uses TLS
func (o *overlay) ServeTLS() bool {
return o.serveTLS
}

// GetCertificateChain returns the certificate of the overlay with its private key
// set. This function will panic if the overlay has the "noTLS" flag sets.
func (o *overlay) GetCertificateChain() certs.CertChain {
Expand All @@ -533,19 +538,41 @@ func (o *overlay) GetCertificateStore() certs.Storage {
return o.certs
}

// Join sends a join request to a distant node with token generated beforehands
// by the later.
// Join sends a join request to a distant node with a token generated by the
// remote node.
// The certHash is used to make sure that no man-in-the-middle intercepts the
// communication.
// If the certHash is empty, it supposes that a transparent proxy is handling
// the TLS connection and that we can trust the CAs in place.
func (o *overlay) Join(addr *url.URL, token string, certHash []byte) error {

target := session.NewAddress(addr.Host + addr.Path)
host := addr.Host
if addr.Port() == "" {
switch addr.Scheme {
case "http":
host += ":80"
case "https":
host += ":443"
default:
return xerrors.Errorf("address doesn't contain a port and uses "+
"an unknown scheme: %s", addr.Scheme)
}
}
target := session.NewAddress(host + addr.Path)

chain := o.GetCertificateChain()
chain := &ptypes.CertificateChain{
Address: []byte(o.myAddrStr),
}

// Fetch the certificate of the node we want to join. The hash is used to
// ensure that we get the right certificate.
err := o.certs.Fetch(target, certHash)
if err != nil {
return xerrors.Errorf("couldn't fetch distant certificate: %v", err)
if o.serveTLS {
chain.Value = o.GetCertificateChain()

// Fetch the certificate of the node we want to join. The hash is used to
// ensure that we get the right certificate.
err := o.certs.Fetch(target, certHash)
if err != nil {
return xerrors.Errorf("couldn't fetch distant certificate: %v", err)
}
}

conn, err := o.connMgr.Acquire(target)
Expand All @@ -554,15 +581,11 @@ func (o *overlay) Join(addr *url.URL, token string, certHash []byte) error {
}

defer o.connMgr.Release(target)

client := ptypes.NewOverlayClient(conn)

req := &ptypes.JoinRequest{
Token: token,
Chain: &ptypes.CertificateChain{
Address: []byte(o.myAddrStr),
Value: chain,
},
Chain: chain,
}

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -573,11 +596,13 @@ func (o *overlay) Join(addr *url.URL, token string, certHash []byte) error {
return xerrors.Errorf("couldn't call join: %v", err)
}

// Update the certificate store with the response from the node we just
// joined. That will allow the node to communicate with the network.
for _, raw := range resp.Peers {
from := o.addrFactory.FromText(raw.GetAddress())
o.certs.Store(from, raw.GetValue())
if o.serveTLS {
// Update the certificate store with the response from the node we just
// joined. That will allow the node to communicate with the network.
for _, raw := range resp.Peers {
from := o.addrFactory.FromText(raw.GetAddress())
o.certs.Store(from, raw.GetValue())
}
}

return nil
Expand Down Expand Up @@ -662,6 +687,14 @@ func (mgr *connManager) Acquire(to mino.Address) (grpc.ClientConnInterface, erro
}

opts = append(opts, grpc.WithTransportCredentials(ta))
} else {
// If the remote is accessible via port 443, we suppose it is TLS terminated and signed by a
// CA available on the system.
if strings.HasSuffix(addr, ":443") {
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
}

conn, err = grpc.DialContext(
Expand Down

0 comments on commit f164e5b

Please sign in to comment.