// go_study/fabric-main/internal/pkg/comm/creds.go
// 174 lines, 4.9 KiB, Go

/*
Copyright IBM Corp. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/
package comm
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"sync"
"time"
"github.com/hyperledger/fabric/common/flogging"
"google.golang.org/grpc/credentials"
)
// Sentinel errors returned by the credential implementations in this file.
var (
	// ErrClientHandshakeNotImplemented is returned by serverCreds.ClientHandshake.
	ErrClientHandshakeNotImplemented = errors.New("core/comm: client handshakes are not implemented with serverCreds")

	// ErrServerHandshakeNotImplemented is returned by
	// DynamicClientCredentials.ServerHandshake.
	ErrServerHandshakeNotImplemented = errors.New("core/comm: server handshakes are not implemented with clientCreds")

	// ErrOverrideHostnameNotSupported is returned by serverCreds.OverrideServerName.
	ErrOverrideHostnameNotSupported = errors.New("core/comm: OverrideServerName is not supported")
)

var (
	// alpnProtoStr lists the ALPN application protocols advertised for gRPC.
	alpnProtoStr = []string{"h2"}

	// tlsClientLogger ("comm.tls") logs client TLS handshakes. It is also the
	// fallback logger for server credentials created without one.
	tlsClientLogger = flogging.MustGetLogger("comm.tls")
)
// NewServerTransportCredentials returns a grpc/credentials.TransportCredentials
// that performs server-side TLS handshakes using serverConfig.
//
// The wrapped tls.Config is modified in place:
//   - NextProtos is set to the gRPC ALPN protocols.
//   - MinVersion is raised to TLS 1.2 if it is lower. A stricter minimum that
//     the caller already configured (e.g. TLS 1.3) is preserved.
//
// If logger is nil, the package "comm.tls" logger is used.
func NewServerTransportCredentials(
	serverConfig *TLSConfig,
	logger *flogging.FabricLogger) credentials.TransportCredentials {
	// NOTE: unlike the default grpc/credentials implementation, we do not
	// clone the tls.Config which allows us to update it dynamically
	serverConfig.applyServerDefaults()
	if logger == nil {
		logger = tlsClientLogger
	}
	return &serverCreds{
		serverConfig: serverConfig,
		logger:       logger,
	}
}

// applyServerDefaults installs the gRPC ALPN protocols and a TLS 1.2 floor on
// the wrapped tls.Config. It holds the write lock, like the other TLSConfig
// mutators, so the update cannot race with Config() snapshots or the setters.
func (t *TLSConfig) applyServerDefaults() {
	t.lock.Lock()
	defer t.lock.Unlock()
	t.config.NextProtos = alpnProtoStr
	// Only raise the minimum; never downgrade an explicitly stricter setting.
	if t.config.MinVersion < tls.VersionTLS12 {
		t.config.MinVersion = tls.VersionTLS12
	}
}
// serverCreds is an implementation of grpc/credentials.TransportCredentials.
type serverCreds struct {
serverConfig *TLSConfig
logger *flogging.FabricLogger
}
// TLSConfig wraps a *tls.Config behind a read/write lock.
//
// Config takes the read lock and returns a copy of the wrapped configuration.
// The mutators (AddClientRootCA, SetClientCAs) take the write lock.
type TLSConfig struct {
	// lock guards config.
	lock sync.RWMutex

	// config is the wrapped configuration. It may be nil, in which case
	// Config returns a zero tls.Config.
	config *tls.Config
}
// NewTLSConfig wraps config in a TLSConfig.
//
// The pointer is stored as-is rather than cloned. Updates made through the
// returned TLSConfig are therefore also visible through config.
func NewTLSConfig(config *tls.Config) *TLSConfig {
	wrapped := new(TLSConfig)
	wrapped.config = config
	return wrapped
}
// Config returns a copy of the current configuration, taken under the read
// lock. The copy is produced by tls.Config.Clone, which is shallow, so
// pointer fields such as ClientCAs are shared with the wrapped config. A nil
// wrapped config yields a zero tls.Config.
func (t *TLSConfig) Config() tls.Config {
	t.lock.RLock()
	defer t.lock.RUnlock()

	if t.config == nil {
		return tls.Config{}
	}
	return *t.config.Clone()
}
// AddClientRootCA adds cert to the pool of certificate authorities used to
// verify client certificates.
//
// If no client CA pool is configured yet, a new pool is created first.
// Calling AddCert on a nil *x509.CertPool would otherwise panic.
//
// NOTE(review): the existing pool is updated in place. Config() returns a
// shallow clone that shares this *x509.CertPool, and handshakes read it
// without holding this lock. Consider copy-on-write (x509.CertPool.Clone,
// Go 1.19+) if concurrent handshakes are possible.
func (t *TLSConfig) AddClientRootCA(cert *x509.Certificate) {
	t.lock.Lock()
	defer t.lock.Unlock()
	if t.config.ClientCAs == nil {
		t.config.ClientCAs = x509.NewCertPool()
	}
	t.config.ClientCAs.AddCert(cert)
}
// SetClientCAs replaces the pool of certificate authorities used to verify
// client certificates. The assignment is made while holding the write lock.
func (t *TLSConfig) SetClientCAs(certPool *x509.CertPool) {
	t.lock.Lock()
	defer t.lock.Unlock()

	cfg := t.config
	cfg.ClientCAs = certPool
}
// ClientHandshake is not implemented for serverCreds. It always returns
// ErrClientHandshakeNotImplemented with a nil connection and nil AuthInfo.
func (sc *serverCreds) ClientHandshake(context.Context, string, net.Conn) (net.Conn, credentials.AuthInfo, error) {
	var noConn net.Conn
	var noAuth credentials.AuthInfo
	return noConn, noAuth, ErrClientHandshakeNotImplemented
}
// ServerHandshake performs the server side of a TLS handshake on rawConn.
//
// Each call works on a fresh copy of the current TLS configuration, so
// updates made through the TLSConfig mutators apply to later handshakes.
// The handshake duration is logged at debug level on success and at error
// level on failure. On failure the handshake error is returned with a nil
// connection and nil AuthInfo.
func (sc *serverCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	snapshot := sc.serverConfig.Config()
	tlsConn := tls.Server(rawConn, &snapshot)
	connLogger := sc.logger.With("remote address", tlsConn.RemoteAddr().String())

	began := time.Now()
	handshakeErr := tlsConn.Handshake()
	if handshakeErr != nil {
		connLogger.Errorf("Server TLS handshake failed in %s with error %s", time.Since(began), handshakeErr)
		return nil, nil, handshakeErr
	}

	connLogger.Debugf("Server TLS handshake completed in %s", time.Since(began))
	state := tlsConn.ConnectionState()
	return tlsConn, credentials.TLSInfo{State: state}, nil
}
// Info reports the protocol information for these credentials.
// SecurityProtocol is "tls". SecurityVersion is the fixed string "1.2" and
// does not reflect the TLS version negotiated on any particular connection.
func (sc *serverCreds) Info() credentials.ProtocolInfo {
	var info credentials.ProtocolInfo
	info.SecurityProtocol = "tls"
	info.SecurityVersion = "1.2"
	return info
}
// Clone returns new server credentials that share the same logger and are
// built from a copy of the current TLS configuration.
//
// Field reassignments on the original (e.g. SetClientCAs) are not reflected
// in the clone. Pointer fields such as ClientCAs are still shared, because
// the copy is a shallow tls.Config.Clone.
func (sc *serverCreds) Clone() credentials.TransportCredentials {
	snapshot := sc.serverConfig.Config()
	return NewServerTransportCredentials(NewTLSConfig(&snapshot), sc.logger)
}
// OverrideServerName is not supported by server credentials. It always
// returns ErrOverrideHostnameNotSupported and changes nothing.
func (sc *serverCreds) OverrideServerName(string) error {
	unsupported := ErrOverrideHostnameNotSupported
	return unsupported
}
// DynamicClientCredentials is a client-side grpc/credentials.TransportCredentials
// that copies TLSConfig at the start of every handshake. Changes made to
// TLSConfig between handshakes therefore take effect on the next connection.
type DynamicClientCredentials struct {
	TLSConfig *tls.Config
}

// latestConfig returns a shallow copy (tls.Config.Clone) of the current
// TLSConfig. A nil TLSConfig yields nil.
func (dtc *DynamicClientCredentials) latestConfig() *tls.Config {
	current := dtc.TLSConfig
	return current.Clone()
}
// ClientHandshake performs a client-side TLS handshake on rawConn for
// authority. It uses gRPC's standard TLS credentials built from a fresh copy
// of TLSConfig. The handshake duration is logged at debug level on success
// and at error level on failure. The underlying handshake's results are
// returned unchanged.
func (dtc *DynamicClientCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	connLogger := tlsClientLogger.With("remote address", rawConn.RemoteAddr().String())
	tlsCreds := credentials.NewTLS(dtc.latestConfig())

	began := time.Now()
	secureConn, authInfo, err := tlsCreds.ClientHandshake(ctx, authority, rawConn)
	if err != nil {
		connLogger.Errorf("Client TLS handshake failed after %s with error: %s", time.Since(began), err)
		return secureConn, authInfo, err
	}

	connLogger.Debugf("Client TLS handshake completed in %s", time.Since(began))
	return secureConn, authInfo, nil
}
// ServerHandshake is not supported by client credentials. It always returns
// ErrServerHandshakeNotImplemented with a nil connection and nil AuthInfo.
func (dtc *DynamicClientCredentials) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) {
	var noConn net.Conn
	var noAuth credentials.AuthInfo
	return noConn, noAuth, ErrServerHandshakeNotImplemented
}
// Info returns the ProtocolInfo reported by gRPC's standard TLS credentials
// built from a copy of the current TLSConfig.
func (dtc *DynamicClientCredentials) Info() credentials.ProtocolInfo {
	tlsCreds := credentials.NewTLS(dtc.latestConfig())
	return tlsCreds.Info()
}
// Clone returns an independent copy of these credentials.
//
// The copy is again a *DynamicClientCredentials, so its handshakes keep the
// same per-handshake logging. It wraps a shallow clone (tls.Config.Clone) of
// the current TLSConfig, so later field assignments on either config do not
// affect the other. Previously Clone returned plain gRPC TLS credentials,
// which changed the type and dropped the logging.
func (dtc *DynamicClientCredentials) Clone() credentials.TransportCredentials {
	return &DynamicClientCredentials{TLSConfig: dtc.latestConfig()}
}
// OverrideServerName sets the server name used to verify the hostname on
// certificates presented by the server.
//
// The name is written into TLSConfig, so it applies to handshakes that start
// after this call. If TLSConfig is nil, an empty tls.Config is allocated
// first instead of panicking; the other methods already tolerate a nil
// TLSConfig. It always returns nil.
//
// NOTE(review): TLSConfig is mutated without synchronization. Calling this
// concurrently with ClientHandshake, Info or Clone is a data race.
func (dtc *DynamicClientCredentials) OverrideServerName(name string) error {
	if dtc.TLSConfig == nil {
		dtc.TLSConfig = &tls.Config{}
	}
	dtc.TLSConfig.ServerName = name
	return nil
}