349 lines
9.7 KiB
Go
349 lines
9.7 KiB
Go
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/
|
|
|
|
package comm
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"io/ioutil"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hyperledger/fabric/common/crypto/tlsgen"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/keepalive"
|
|
)
|
|
|
|
func TestServerKeepaliveOptions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
kap := keepalive.ServerParameters{
|
|
Time: DefaultKeepaliveOptions.ServerInterval,
|
|
Timeout: DefaultKeepaliveOptions.ServerTimeout,
|
|
}
|
|
kep := keepalive.EnforcementPolicy{
|
|
MinTime: DefaultKeepaliveOptions.ServerMinInterval,
|
|
PermitWithoutStream: true,
|
|
}
|
|
expectedOpts := []grpc.ServerOption{
|
|
grpc.KeepaliveParams(kap),
|
|
grpc.KeepaliveEnforcementPolicy(kep),
|
|
}
|
|
opts := DefaultKeepaliveOptions.ServerKeepaliveOptions()
|
|
|
|
// Unable to test equality of options since the option methods return
|
|
// functions and each instance is a different func.
|
|
// Unable to test the equality of applying the options to the server
|
|
// implementation because the server embeds channels.
|
|
// Fallback to a sanity check.
|
|
require.Len(t, opts, len(expectedOpts))
|
|
for i := range opts {
|
|
require.IsType(t, expectedOpts[i], opts[i])
|
|
}
|
|
}
|
|
|
|
func TestClientKeepaliveOptions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
kap := keepalive.ClientParameters{
|
|
Time: DefaultKeepaliveOptions.ClientInterval,
|
|
Timeout: DefaultKeepaliveOptions.ClientTimeout,
|
|
PermitWithoutStream: true,
|
|
}
|
|
expectedOpts := []grpc.DialOption{grpc.WithKeepaliveParams(kap)}
|
|
opts := DefaultKeepaliveOptions.ClientKeepaliveOptions()
|
|
|
|
// Unable to test equality of options since the option methods return
|
|
// functions and each instance is a different func.
|
|
// Fallback to a sanity check.
|
|
require.Len(t, opts, len(expectedOpts))
|
|
for i := range opts {
|
|
require.IsType(t, expectedOpts[i], opts[i])
|
|
}
|
|
}
|
|
|
|
func TestClientConfigClone(t *testing.T) {
|
|
origin := ClientConfig{
|
|
KaOpts: KeepaliveOptions{
|
|
ClientInterval: time.Second,
|
|
},
|
|
SecOpts: SecureOptions{
|
|
Key: []byte{1, 2, 3},
|
|
},
|
|
DialTimeout: time.Second,
|
|
AsyncConnect: true,
|
|
}
|
|
|
|
clone := origin
|
|
|
|
// Same content, different inner fields references.
|
|
require.Equal(t, origin, clone)
|
|
|
|
// We change the contents of the fields and ensure it doesn't
|
|
// propagate across instances.
|
|
origin.AsyncConnect = false
|
|
origin.KaOpts.ServerInterval = time.Second
|
|
origin.KaOpts.ClientInterval = time.Hour
|
|
origin.SecOpts.Certificate = []byte{1, 2, 3}
|
|
origin.SecOpts.Key = []byte{5, 4, 6}
|
|
origin.DialTimeout = time.Second * 2
|
|
|
|
clone.SecOpts.UseTLS = true
|
|
clone.KaOpts.ServerMinInterval = time.Hour
|
|
|
|
expectedOriginState := ClientConfig{
|
|
KaOpts: KeepaliveOptions{
|
|
ClientInterval: time.Hour,
|
|
ServerInterval: time.Second,
|
|
},
|
|
SecOpts: SecureOptions{
|
|
Key: []byte{5, 4, 6},
|
|
Certificate: []byte{1, 2, 3},
|
|
},
|
|
DialTimeout: time.Second * 2,
|
|
}
|
|
|
|
expectedCloneState := ClientConfig{
|
|
KaOpts: KeepaliveOptions{
|
|
ClientInterval: time.Second,
|
|
ServerMinInterval: time.Hour,
|
|
},
|
|
SecOpts: SecureOptions{
|
|
Key: []byte{1, 2, 3},
|
|
UseTLS: true,
|
|
},
|
|
DialTimeout: time.Second,
|
|
AsyncConnect: true,
|
|
}
|
|
|
|
require.Equal(t, expectedOriginState, origin)
|
|
require.Equal(t, expectedCloneState, clone)
|
|
}
|
|
|
|
func TestSecureOptionsTLSConfig(t *testing.T) {
|
|
ca1, err := tlsgen.NewCA()
|
|
require.NoError(t, err, "failed to create CA1")
|
|
ca2, err := tlsgen.NewCA()
|
|
require.NoError(t, err, "failed to create CA2")
|
|
ckp, err := ca1.NewClientCertKeyPair()
|
|
require.NoError(t, err, "failed to create client key pair")
|
|
clientCert, err := tls.X509KeyPair(ckp.Cert, ckp.Key)
|
|
require.NoError(t, err, "failed to create client certificate")
|
|
|
|
newCertPool := func(cas ...tlsgen.CA) *x509.CertPool {
|
|
cp := x509.NewCertPool()
|
|
for _, ca := range cas {
|
|
ok := cp.AppendCertsFromPEM(ca.CertBytes())
|
|
require.True(t, ok, "failed to add cert to pool")
|
|
}
|
|
return cp
|
|
}
|
|
|
|
tests := []struct {
|
|
desc string
|
|
so SecureOptions
|
|
tc *tls.Config
|
|
expectedErr string
|
|
}{
|
|
{desc: "TLSDisabled"},
|
|
{desc: "TLSEnabled", so: SecureOptions{UseTLS: true}, tc: &tls.Config{MinVersion: tls.VersionTLS12}},
|
|
{
|
|
desc: "ServerNameOverride",
|
|
so: SecureOptions{UseTLS: true, ServerNameOverride: "bob"},
|
|
tc: &tls.Config{MinVersion: tls.VersionTLS12, ServerName: "bob"},
|
|
},
|
|
{
|
|
desc: "WithServerRootCAs",
|
|
so: SecureOptions{UseTLS: true, ServerRootCAs: [][]byte{ca1.CertBytes(), ca2.CertBytes()}},
|
|
tc: &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: newCertPool(ca1, ca2)},
|
|
},
|
|
{
|
|
desc: "BadServerRootCertificate",
|
|
so: SecureOptions{
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{
|
|
[]byte("-----BEGIN CERTIFICATE-----\nYm9ndXM=\n-----END CERTIFICATE-----"),
|
|
},
|
|
},
|
|
expectedErr: "error adding root certificate",
|
|
},
|
|
{
|
|
desc: "WithRequiredClientKeyPair",
|
|
so: SecureOptions{UseTLS: true, RequireClientCert: true, Key: ckp.Key, Certificate: ckp.Cert},
|
|
tc: &tls.Config{MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{clientCert}},
|
|
},
|
|
{
|
|
desc: "MissingClientKey",
|
|
so: SecureOptions{UseTLS: true, RequireClientCert: true, Certificate: ckp.Cert},
|
|
expectedErr: "both Key and Certificate are required when using mutual TLS",
|
|
},
|
|
{
|
|
desc: "MissingClientCert",
|
|
so: SecureOptions{UseTLS: true, RequireClientCert: true, Key: ckp.Key},
|
|
expectedErr: "both Key and Certificate are required when using mutual TLS",
|
|
},
|
|
{
|
|
desc: "WithTimeShift",
|
|
so: SecureOptions{UseTLS: true, TimeShift: 2 * time.Hour},
|
|
tc: &tls.Config{MinVersion: tls.VersionTLS12},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.desc, func(t *testing.T) {
|
|
tc, err := tt.so.TLSConfig()
|
|
if tt.expectedErr != "" {
|
|
require.ErrorContainsf(t, err, tt.expectedErr, "got %v, want %s", err, tt.expectedErr)
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
|
|
if len(tt.so.ServerRootCAs) != 0 {
|
|
require.NotNil(t, tc.RootCAs)
|
|
require.Len(t, tc.RootCAs.Subjects(), len(tt.so.ServerRootCAs))
|
|
for _, subj := range tt.tc.RootCAs.Subjects() {
|
|
require.Contains(t, tc.RootCAs.Subjects(), subj, "missing subject %x", subj)
|
|
}
|
|
tt.tc.RootCAs, tc.RootCAs = nil, nil
|
|
}
|
|
|
|
if tt.so.TimeShift != 0 {
|
|
require.NotNil(t, tc.Time)
|
|
require.WithinDuration(t, time.Now().Add(-1*tt.so.TimeShift), tc.Time(), 10*time.Second)
|
|
tc.Time = nil
|
|
}
|
|
|
|
require.Equal(t, tt.tc, tc)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClientConfigDialOptions_GoodConfig(t *testing.T) {
|
|
testCerts := LoadTestCerts(t)
|
|
|
|
config := ClientConfig{}
|
|
opts, err := config.DialOptions()
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, opts)
|
|
|
|
secOpts := SecureOptions{
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
RequireClientCert: false,
|
|
}
|
|
config.SecOpts = secOpts
|
|
opts, err = config.DialOptions()
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, opts)
|
|
|
|
secOpts = SecureOptions{
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
RequireClientCert: true,
|
|
}
|
|
clientCert, err := secOpts.ClientCertificate()
|
|
require.NoError(t, err)
|
|
require.Equal(t, testCerts.ClientCert, clientCert)
|
|
config.SecOpts = secOpts
|
|
opts, err = config.DialOptions()
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, opts)
|
|
}
|
|
|
|
func TestClientConfigDialOptions_BadConfig(t *testing.T) {
|
|
testCerts := LoadTestCerts(t)
|
|
|
|
// bad root cert
|
|
config := ClientConfig{
|
|
SecOpts: SecureOptions{
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{[]byte(badPEM)},
|
|
},
|
|
}
|
|
_, err := config.DialOptions()
|
|
require.ErrorContains(t, err, "error adding root certificate")
|
|
|
|
// missing key
|
|
config.SecOpts = SecureOptions{
|
|
Certificate: []byte("cert"),
|
|
UseTLS: true,
|
|
RequireClientCert: true,
|
|
}
|
|
_, err = config.DialOptions()
|
|
require.ErrorContains(t, err, "both Key and Certificate are required when using mutual TLS")
|
|
|
|
// missing cert
|
|
config.SecOpts = SecureOptions{
|
|
Key: []byte("key"),
|
|
UseTLS: true,
|
|
RequireClientCert: true,
|
|
}
|
|
_, err = config.DialOptions()
|
|
require.ErrorContains(t, err, "both Key and Certificate are required when using mutual TLS")
|
|
|
|
// bad key
|
|
config.SecOpts = SecureOptions{
|
|
Certificate: testCerts.CertPEM,
|
|
Key: []byte(badPEM),
|
|
UseTLS: true,
|
|
RequireClientCert: true,
|
|
}
|
|
_, err = config.DialOptions()
|
|
require.ErrorContains(t, err, "failed to load client certificate")
|
|
|
|
// bad cert
|
|
config.SecOpts = SecureOptions{
|
|
Certificate: []byte(badPEM),
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
RequireClientCert: true,
|
|
}
|
|
_, err = config.DialOptions()
|
|
require.ErrorContains(t, err, "failed to load client certificate")
|
|
}
|
|
|
|
// TestCerts holds the test certificate material that LoadTestCerts reads from
// testdata/certs, both as raw PEM and as parsed key pairs.
type TestCerts struct {
	// CAPEM, CertPEM and KeyPEM are the raw PEM bytes of the Org1 CA
	// certificate, the Org1 client1 certificate and the client1 private key.
	CAPEM, CertPEM, KeyPEM []byte

	// ClientCert is the parsed client1 key pair. ServerCert is the parsed
	// server1 key pair.
	ClientCert, ServerCert tls.Certificate
}
|
|
|
|
func LoadTestCerts(t *testing.T) TestCerts {
|
|
t.Helper()
|
|
|
|
var certs TestCerts
|
|
var err error
|
|
certs.CAPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error reading root cert for test: %v", err)
|
|
}
|
|
certs.CertPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-client1-cert.pem"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error reading cert for test: %v", err)
|
|
}
|
|
certs.KeyPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-client1-key.pem"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error reading key for test: %v", err)
|
|
}
|
|
certs.ClientCert, err = tls.X509KeyPair(certs.CertPEM, certs.KeyPEM)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error loading certificate for test: %v", err)
|
|
}
|
|
certs.ServerCert, err = tls.LoadX509KeyPair(
|
|
filepath.Join("testdata", "certs", "Org1-server1-cert.pem"),
|
|
filepath.Join("testdata", "certs", "Org1-server1-key.pem"),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
return certs
|
|
}
|