340 lines
8.5 KiB
Go
340 lines
8.5 KiB
Go
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/
|
|
|
|
package comm_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/hyperledger/fabric/internal/pkg/comm"
|
|
"github.com/hyperledger/fabric/internal/pkg/comm/testpb"
|
|
"github.com/pkg/errors"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
)
|
|
|
|
// testTimeout is the dial/RPC timeout used by the tests in this file. One
// second is a conservative allowance for connections to 127.0.0.1.
const testTimeout = time.Second
|
|
|
|
type echoServer struct{}
|
|
|
|
func (es *echoServer) EchoCall(ctx context.Context,
|
|
echo *testpb.Echo) (*testpb.Echo, error) {
|
|
return echo, nil
|
|
}
|
|
|
|
func TestClientConfigDial(t *testing.T) {
|
|
t.Parallel()
|
|
testCerts := comm.LoadTestCerts(t)
|
|
|
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
badAddress := l.Addr().String()
|
|
defer l.Close()
|
|
|
|
certPool := x509.NewCertPool()
|
|
ok := certPool.AppendCertsFromPEM(testCerts.CAPEM)
|
|
if !ok {
|
|
t.Fatal("failed to create test root cert pool")
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
clientAddress string
|
|
config comm.ClientConfig
|
|
serverTLS *tls.Config
|
|
success bool
|
|
errorMsg string
|
|
}{
|
|
{
|
|
name: "client / server same port",
|
|
config: comm.ClientConfig{
|
|
DialTimeout: testTimeout,
|
|
},
|
|
success: true,
|
|
},
|
|
{
|
|
name: "client / server wrong port",
|
|
clientAddress: badAddress,
|
|
config: comm.ClientConfig{
|
|
DialTimeout: time.Second,
|
|
},
|
|
success: false,
|
|
errorMsg: "(connection refused|context deadline exceeded)",
|
|
},
|
|
{
|
|
name: "client / server wrong port but with asynchronous should succeed",
|
|
clientAddress: badAddress,
|
|
config: comm.ClientConfig{
|
|
AsyncConnect: true,
|
|
DialTimeout: testTimeout,
|
|
},
|
|
success: true,
|
|
},
|
|
{
|
|
name: "client TLS / server no TLS",
|
|
config: comm.ClientConfig{
|
|
SecOpts: comm.SecureOptions{
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
RequireClientCert: true,
|
|
},
|
|
DialTimeout: time.Second,
|
|
},
|
|
success: false,
|
|
errorMsg: "context deadline exceeded",
|
|
},
|
|
{
|
|
name: "client TLS / server TLS match",
|
|
config: comm.ClientConfig{
|
|
SecOpts: comm.SecureOptions{
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
},
|
|
DialTimeout: testTimeout,
|
|
},
|
|
serverTLS: &tls.Config{
|
|
Certificates: []tls.Certificate{testCerts.ServerCert},
|
|
},
|
|
success: true,
|
|
},
|
|
{
|
|
name: "client TLS / server TLS no server roots",
|
|
config: comm.ClientConfig{
|
|
SecOpts: comm.SecureOptions{
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{},
|
|
},
|
|
DialTimeout: testTimeout,
|
|
},
|
|
serverTLS: &tls.Config{
|
|
Certificates: []tls.Certificate{testCerts.ServerCert},
|
|
},
|
|
success: false,
|
|
errorMsg: "context deadline exceeded",
|
|
},
|
|
{
|
|
name: "client TLS / server TLS missing client cert",
|
|
config: comm.ClientConfig{
|
|
SecOpts: comm.SecureOptions{
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
},
|
|
DialTimeout: testTimeout,
|
|
},
|
|
serverTLS: &tls.Config{
|
|
Certificates: []tls.Certificate{testCerts.ServerCert},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
MaxVersion: tls.VersionTLS12, // https://github.com/golang/go/issues/33368
|
|
},
|
|
success: false,
|
|
errorMsg: "tls: bad certificate",
|
|
},
|
|
{
|
|
name: "client TLS / server TLS client cert",
|
|
config: comm.ClientConfig{
|
|
SecOpts: comm.SecureOptions{
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
RequireClientCert: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
},
|
|
DialTimeout: testTimeout,
|
|
},
|
|
serverTLS: &tls.Config{
|
|
Certificates: []tls.Certificate{testCerts.ServerCert},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: certPool,
|
|
},
|
|
success: true,
|
|
},
|
|
{
|
|
name: "server TLS pinning success",
|
|
config: comm.ClientConfig{
|
|
SecOpts: comm.SecureOptions{
|
|
VerifyCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
if bytes.Equal(rawCerts[0], testCerts.ServerCert.Certificate[0]) {
|
|
return nil
|
|
}
|
|
panic("mismatched certificate")
|
|
},
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
RequireClientCert: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
},
|
|
DialTimeout: testTimeout,
|
|
},
|
|
serverTLS: &tls.Config{
|
|
Certificates: []tls.Certificate{testCerts.ServerCert},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: certPool,
|
|
},
|
|
success: true,
|
|
},
|
|
{
|
|
name: "server TLS pinning failure",
|
|
config: comm.ClientConfig{
|
|
SecOpts: comm.SecureOptions{
|
|
VerifyCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
|
return errors.New("TLS certificate mismatch")
|
|
},
|
|
Certificate: testCerts.CertPEM,
|
|
Key: testCerts.KeyPEM,
|
|
UseTLS: true,
|
|
RequireClientCert: true,
|
|
ServerRootCAs: [][]byte{testCerts.CAPEM},
|
|
},
|
|
DialTimeout: testTimeout,
|
|
},
|
|
serverTLS: &tls.Config{
|
|
Certificates: []tls.Certificate{testCerts.ServerCert},
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: certPool,
|
|
},
|
|
success: false,
|
|
errorMsg: "context deadline exceeded",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
test := test
|
|
t.Run(test.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("error creating server for test: %v", err)
|
|
}
|
|
defer lis.Close()
|
|
serverOpts := []grpc.ServerOption{}
|
|
if test.serverTLS != nil {
|
|
serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(test.serverTLS)))
|
|
}
|
|
srv := grpc.NewServer(serverOpts...)
|
|
defer srv.Stop()
|
|
go srv.Serve(lis)
|
|
address := lis.Addr().String()
|
|
if test.clientAddress != "" {
|
|
address = test.clientAddress
|
|
}
|
|
conn, err := test.config.Dial(address)
|
|
if test.success {
|
|
require.NoError(t, err)
|
|
require.NotNil(t, conn)
|
|
} else {
|
|
t.Log(errors.WithStack(err))
|
|
require.Regexp(t, test.errorMsg, err.Error())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetMessageSize(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// setup test server
|
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("failed to create listener for test server: %v", err)
|
|
}
|
|
srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{})
|
|
if err != nil {
|
|
t.Fatalf("failed to create test server: %v", err)
|
|
}
|
|
testpb.RegisterEchoServiceServer(srv.Server(), &echoServer{})
|
|
defer srv.Stop()
|
|
go srv.Start()
|
|
|
|
tests := []struct {
|
|
name string
|
|
maxRecvSize int
|
|
maxSendSize int
|
|
failRecv bool
|
|
failSend bool
|
|
}{
|
|
{
|
|
name: "defaults should pass",
|
|
failRecv: false,
|
|
failSend: false,
|
|
},
|
|
{
|
|
name: "non-defaults should pass",
|
|
failRecv: false,
|
|
failSend: false,
|
|
maxRecvSize: 20,
|
|
maxSendSize: 20,
|
|
},
|
|
{
|
|
name: "recv should fail",
|
|
failRecv: true,
|
|
failSend: false,
|
|
maxRecvSize: 1,
|
|
},
|
|
{
|
|
name: "send should fail",
|
|
failRecv: false,
|
|
failSend: true,
|
|
maxSendSize: 1,
|
|
},
|
|
}
|
|
|
|
// run tests
|
|
for _, test := range tests {
|
|
test := test
|
|
address := lis.Addr().String()
|
|
t.Run(test.name, func(t *testing.T) {
|
|
t.Log(test.name)
|
|
config := comm.ClientConfig{
|
|
DialTimeout: testTimeout,
|
|
MaxRecvMsgSize: test.maxRecvSize,
|
|
MaxSendMsgSize: test.maxSendSize,
|
|
}
|
|
conn, err := config.Dial(address)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
// create service client from conn
|
|
svcClient := testpb.NewEchoServiceClient(conn)
|
|
callCtx := context.Background()
|
|
callCtx, cancel := context.WithTimeout(callCtx, testTimeout)
|
|
defer cancel()
|
|
// invoke service
|
|
echo := &testpb.Echo{
|
|
Payload: []byte{0, 0, 0, 0, 0},
|
|
}
|
|
resp, err := svcClient.EchoCall(callCtx, echo)
|
|
if !test.failRecv && !test.failSend {
|
|
require.NoError(t, err)
|
|
require.True(t, proto.Equal(echo, resp))
|
|
}
|
|
if test.failSend {
|
|
t.Logf("send error: %v", err)
|
|
require.Contains(t, err.Error(), "trying to send message larger than max")
|
|
}
|
|
if test.failRecv {
|
|
t.Logf("recv error: %v", err)
|
|
require.Contains(t, err.Error(), "received message larger than max")
|
|
}
|
|
})
|
|
}
|
|
}
|