295 lines
8.5 KiB
Go
295 lines
8.5 KiB
Go
/*
|
|
Copyright IBM Corp. All Rights Reserved.
|
|
|
|
SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
package comm
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
|
"github.com/pkg/errors"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/health"
|
|
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
|
)
|
|
|
|
type GRPCServer struct {
|
|
// Listen address for the server specified as hostname:port
|
|
address string
|
|
// Listener for handling network requests
|
|
listener net.Listener
|
|
// GRPC server
|
|
server *grpc.Server
|
|
// Certificate presented by the server for TLS communication
|
|
// stored as an atomic reference
|
|
serverCertificate atomic.Value
|
|
// lock to protect concurrent access to append / remove
|
|
lock *sync.Mutex
|
|
// TLS configuration used by the grpc server
|
|
tls *TLSConfig
|
|
// Server for gRPC Health Check Protocol.
|
|
healthServer *health.Server
|
|
}
|
|
|
|
// NewGRPCServer creates a new implementation of a GRPCServer given a
|
|
// listen address
|
|
func NewGRPCServer(address string, serverConfig ServerConfig) (*GRPCServer, error) {
|
|
if address == "" {
|
|
return nil, errors.New("missing address parameter")
|
|
}
|
|
// create our listener
|
|
lis, err := net.Listen("tcp", address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewGRPCServerFromListener(lis, serverConfig)
|
|
}
|
|
|
|
// NewGRPCServerFromListener creates a new implementation of a GRPCServer given
|
|
// an existing net.Listener instance using default keepalive
|
|
func NewGRPCServerFromListener(listener net.Listener, serverConfig ServerConfig) (*GRPCServer, error) {
|
|
grpcServer := &GRPCServer{
|
|
address: listener.Addr().String(),
|
|
listener: listener,
|
|
lock: &sync.Mutex{},
|
|
}
|
|
|
|
// set up our server options
|
|
var serverOpts []grpc.ServerOption
|
|
|
|
secureConfig := serverConfig.SecOpts
|
|
if secureConfig.UseTLS {
|
|
// both key and cert are required
|
|
if secureConfig.Key != nil && secureConfig.Certificate != nil {
|
|
// load server public and private keys
|
|
cert, err := tls.X509KeyPair(secureConfig.Certificate, secureConfig.Key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
grpcServer.serverCertificate.Store(cert)
|
|
|
|
// set up our TLS config
|
|
if len(secureConfig.CipherSuites) == 0 {
|
|
secureConfig.CipherSuites = DefaultTLSCipherSuites
|
|
}
|
|
getCert := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
cert := grpcServer.serverCertificate.Load().(tls.Certificate)
|
|
return &cert, nil
|
|
}
|
|
|
|
grpcServer.tls = NewTLSConfig(&tls.Config{
|
|
VerifyPeerCertificate: secureConfig.VerifyCertificate,
|
|
GetCertificate: getCert,
|
|
SessionTicketsDisabled: true,
|
|
CipherSuites: secureConfig.CipherSuites,
|
|
})
|
|
|
|
if serverConfig.SecOpts.TimeShift > 0 {
|
|
timeShift := serverConfig.SecOpts.TimeShift
|
|
grpcServer.tls.config.Time = func() time.Time {
|
|
return time.Now().Add((-1) * timeShift)
|
|
}
|
|
}
|
|
grpcServer.tls.config.ClientAuth = tls.RequestClientCert
|
|
// check if client authentication is required
|
|
if secureConfig.RequireClientCert {
|
|
// require TLS client auth
|
|
grpcServer.tls.config.ClientAuth = tls.RequireAndVerifyClientCert
|
|
// if we have client root CAs, create a certPool
|
|
if len(secureConfig.ClientRootCAs) > 0 {
|
|
grpcServer.tls.config.ClientCAs = x509.NewCertPool()
|
|
for _, clientRootCA := range secureConfig.ClientRootCAs {
|
|
err = grpcServer.appendClientRootCA(clientRootCA)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// create credentials and add to server options
|
|
creds := NewServerTransportCredentials(grpcServer.tls, serverConfig.Logger)
|
|
serverOpts = append(serverOpts, grpc.Creds(creds))
|
|
} else {
|
|
return nil, errors.New("serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true")
|
|
}
|
|
}
|
|
|
|
// set max send and recv msg sizes
|
|
maxSendMsgSize := DefaultMaxSendMsgSize
|
|
if serverConfig.MaxSendMsgSize != 0 {
|
|
maxSendMsgSize = serverConfig.MaxSendMsgSize
|
|
}
|
|
maxRecvMsgSize := DefaultMaxRecvMsgSize
|
|
if serverConfig.MaxRecvMsgSize != 0 {
|
|
maxRecvMsgSize = serverConfig.MaxRecvMsgSize
|
|
}
|
|
serverOpts = append(serverOpts, grpc.MaxSendMsgSize(maxSendMsgSize))
|
|
serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(maxRecvMsgSize))
|
|
// set the keepalive options
|
|
serverOpts = append(serverOpts, serverConfig.KaOpts.ServerKeepaliveOptions()...)
|
|
// set connection timeout
|
|
if serverConfig.ConnectionTimeout <= 0 {
|
|
serverConfig.ConnectionTimeout = DefaultConnectionTimeout
|
|
}
|
|
serverOpts = append(
|
|
serverOpts,
|
|
grpc.ConnectionTimeout(serverConfig.ConnectionTimeout))
|
|
// set the interceptors
|
|
if len(serverConfig.StreamInterceptors) > 0 {
|
|
serverOpts = append(
|
|
serverOpts,
|
|
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(serverConfig.StreamInterceptors...)),
|
|
)
|
|
}
|
|
|
|
if len(serverConfig.UnaryInterceptors) > 0 {
|
|
serverOpts = append(
|
|
serverOpts,
|
|
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(serverConfig.UnaryInterceptors...)),
|
|
)
|
|
}
|
|
|
|
if serverConfig.ServerStatsHandler != nil {
|
|
serverOpts = append(serverOpts, grpc.StatsHandler(serverConfig.ServerStatsHandler))
|
|
}
|
|
|
|
grpcServer.server = grpc.NewServer(serverOpts...)
|
|
|
|
if serverConfig.HealthCheckEnabled {
|
|
grpcServer.healthServer = health.NewServer()
|
|
healthpb.RegisterHealthServer(grpcServer.server, grpcServer.healthServer)
|
|
}
|
|
|
|
return grpcServer, nil
|
|
}
|
|
|
|
// SetServerCertificate assigns the current TLS certificate to be the peer's server certificate
|
|
func (gServer *GRPCServer) SetServerCertificate(cert tls.Certificate) {
|
|
gServer.serverCertificate.Store(cert)
|
|
}
|
|
|
|
// Address returns the listen address for this GRPCServer instance
|
|
func (gServer *GRPCServer) Address() string {
|
|
return gServer.address
|
|
}
|
|
|
|
// Listener returns the net.Listener for the GRPCServer instance
|
|
func (gServer *GRPCServer) Listener() net.Listener {
|
|
return gServer.listener
|
|
}
|
|
|
|
// Server returns the grpc.Server for the GRPCServer instance
|
|
func (gServer *GRPCServer) Server() *grpc.Server {
|
|
return gServer.server
|
|
}
|
|
|
|
// ServerCertificate returns the tls.Certificate used by the grpc.Server
|
|
func (gServer *GRPCServer) ServerCertificate() tls.Certificate {
|
|
return gServer.serverCertificate.Load().(tls.Certificate)
|
|
}
|
|
|
|
// TLSEnabled is a flag indicating whether or not TLS is enabled for the
|
|
// GRPCServer instance
|
|
func (gServer *GRPCServer) TLSEnabled() bool {
|
|
return gServer.tls != nil
|
|
}
|
|
|
|
// MutualTLSRequired is a flag indicating whether or not client certificates
|
|
// are required for this GRPCServer instance
|
|
func (gServer *GRPCServer) MutualTLSRequired() bool {
|
|
return gServer.TLSEnabled() &&
|
|
gServer.tls.Config().ClientAuth == tls.RequireAndVerifyClientCert
|
|
}
|
|
|
|
// Start starts the underlying grpc.Server
|
|
func (gServer *GRPCServer) Start() error {
|
|
// if health check is enabled, set the health status for all registered services
|
|
if gServer.healthServer != nil {
|
|
for name := range gServer.server.GetServiceInfo() {
|
|
gServer.healthServer.SetServingStatus(
|
|
name,
|
|
healthpb.HealthCheckResponse_SERVING,
|
|
)
|
|
}
|
|
|
|
gServer.healthServer.SetServingStatus(
|
|
"",
|
|
healthpb.HealthCheckResponse_SERVING,
|
|
)
|
|
}
|
|
return gServer.server.Serve(gServer.listener)
|
|
}
|
|
|
|
// Stop stops the underlying grpc.Server
|
|
func (gServer *GRPCServer) Stop() {
|
|
gServer.server.Stop()
|
|
}
|
|
|
|
// internal function to add a PEM-encoded clientRootCA
|
|
func (gServer *GRPCServer) appendClientRootCA(clientRoot []byte) error {
|
|
certs, err := pemToX509Certs(clientRoot)
|
|
if err != nil {
|
|
return errors.WithMessage(err, "failed to append client root certificate(s)")
|
|
}
|
|
|
|
if len(certs) < 1 {
|
|
return errors.New("no client root certificates found")
|
|
}
|
|
|
|
for _, cert := range certs {
|
|
gServer.tls.AddClientRootCA(cert)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// pemToX509Certs decodes successive PEM blocks from pemCerts and parses the
// contents of each block as an X.509 certificate, regardless of the block's
// type label.
//
// Decoding stops once no further PEM block can be found, so input with no
// PEM data yields a nil slice and a nil error. If any decoded block fails to
// parse, it returns nil and that parse error.
func pemToX509Certs(pemCerts []byte) ([]*x509.Certificate, error) {
	var parsed []*x509.Certificate

	// A single input may carry several concatenated certificates.
	for block, rest := pem.Decode(pemCerts); block != nil; block, rest = pem.Decode(rest) {
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, err
		}
		parsed = append(parsed, cert)
	}

	return parsed, nil
}
|
|
|
|
// SetClientRootCAs sets the list of authorities used to verify client
|
|
// certificates based on a list of PEM-encoded X509 certificate authorities
|
|
func (gServer *GRPCServer) SetClientRootCAs(clientRoots [][]byte) error {
|
|
gServer.lock.Lock()
|
|
defer gServer.lock.Unlock()
|
|
|
|
certPool := x509.NewCertPool()
|
|
for _, clientRoot := range clientRoots {
|
|
if !certPool.AppendCertsFromPEM(clientRoot) {
|
|
return errors.New("failed to set client root certificate(s)")
|
|
}
|
|
}
|
|
gServer.tls.SetClientCAs(certPool)
|
|
return nil
|
|
}
|