137 lines
3.9 KiB
Go
137 lines
3.9 KiB
Go
/*
|
|
Copyright IBM Corp. 2017 All Rights Reserved.
|
|
|
|
SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
package cluster
|
|
|
|
import (
|
|
"crypto/x509"
|
|
"sync"
|
|
|
|
"github.com/hyperledger/fabric/common/crypto"
|
|
"github.com/hyperledger/fabric/common/metrics"
|
|
"github.com/pkg/errors"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
// RemoteVerifier is a callback used to verify the connection to a remote host
// during the TLS handshake. It receives the raw certificates presented by the
// peer together with any verified chains, and returns a non-nil error to
// reject the connection.
type RemoteVerifier func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
|
|
|
//go:generate mockery --dir . --name SecureDialer --case underscore --output ./mocks/
|
|
|
|
// SecureDialer connects to a remote address
|
|
type SecureDialer interface {
|
|
Dial(address string, verifyFunc RemoteVerifier) (*grpc.ClientConn, error)
|
|
}
|
|
|
|
// ConnectionMapper maps certificates to connections
|
|
type ConnectionMapper interface {
|
|
Lookup(key []byte) (*grpc.ClientConn, bool)
|
|
Put(key []byte, conn *grpc.ClientConn)
|
|
Remove(key []byte)
|
|
Size() int
|
|
}
|
|
|
|
// ConnectionStore stores connections to remote nodes
|
|
type ConnectionStore struct {
|
|
lock sync.RWMutex
|
|
Connections ConnectionMapper
|
|
dialer SecureDialer
|
|
}
|
|
|
|
// NewConnectionStore creates a new ConnectionStore with the given SecureDialer
|
|
func NewConnectionStore(dialer SecureDialer, tlsConnectionCount metrics.Gauge) *ConnectionStore {
|
|
connMapping := &ConnectionStore{
|
|
Connections: &connMapperReporter{
|
|
ConnectionMapper: make(ConnByCertMap),
|
|
tlsConnectionCountMetrics: tlsConnectionCount,
|
|
},
|
|
dialer: dialer,
|
|
}
|
|
return connMapping
|
|
}
|
|
|
|
// verifyHandshake returns a predicate that verifies that the remote node authenticates
|
|
// itself with the given TLS certificate
|
|
func (c *ConnectionStore) verifyHandshake(endpoint string, certificate []byte) RemoteVerifier {
|
|
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
err := crypto.CertificatesWithSamePublicKey(certificate, rawCerts[0])
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
return errors.Errorf("public key of server certificate presented by %s doesn't match the expected public key",
|
|
endpoint)
|
|
}
|
|
}
|
|
|
|
// Disconnect closes the gRPC connection that is mapped to the given certificate
|
|
func (c *ConnectionStore) Disconnect(expectedServerCert []byte) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
conn, connected := c.Connections.Lookup(expectedServerCert)
|
|
if !connected {
|
|
return
|
|
}
|
|
conn.Close()
|
|
c.Connections.Remove(expectedServerCert)
|
|
}
|
|
|
|
// Connection obtains a connection to the given endpoint and expects the given server certificate
|
|
// to be presented by the remote node
|
|
func (c *ConnectionStore) Connection(endpoint string, expectedServerCert []byte) (*grpc.ClientConn, error) {
|
|
c.lock.RLock()
|
|
conn, alreadyConnected := c.Connections.Lookup(expectedServerCert)
|
|
c.lock.RUnlock()
|
|
|
|
if alreadyConnected {
|
|
return conn, nil
|
|
}
|
|
|
|
// Else, we need to connect to the remote endpoint
|
|
return c.connect(endpoint, expectedServerCert)
|
|
}
|
|
|
|
// connect connects to the given endpoint and expects the given TLS server certificate
|
|
// to be presented at the time of authentication
|
|
func (c *ConnectionStore) connect(endpoint string, expectedServerCert []byte) (*grpc.ClientConn, error) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
// Check again to see if some other goroutine has already connected while
|
|
// we were waiting on the lock
|
|
conn, alreadyConnected := c.Connections.Lookup(expectedServerCert)
|
|
if alreadyConnected {
|
|
return conn, nil
|
|
}
|
|
|
|
v := c.verifyHandshake(endpoint, expectedServerCert)
|
|
conn, err := c.dialer.Dial(endpoint, v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.Connections.Put(expectedServerCert, conn)
|
|
return conn, nil
|
|
}
|
|
|
|
type connMapperReporter struct {
|
|
tlsConnectionCountMetrics metrics.Gauge
|
|
ConnectionMapper
|
|
}
|
|
|
|
func (cmg *connMapperReporter) Put(cert []byte, conn *grpc.ClientConn) {
|
|
cmg.ConnectionMapper.Put(cert, conn)
|
|
cmg.reportSize()
|
|
}
|
|
|
|
func (cmg *connMapperReporter) Remove(cert []byte) {
|
|
cmg.ConnectionMapper.Remove(cert)
|
|
cmg.reportSize()
|
|
}
|
|
|
|
func (cmg *connMapperReporter) reportSize() {
|
|
cmg.tlsConnectionCountMetrics.Set(float64(cmg.ConnectionMapper.Size()))
|
|
}
|