// File: go_study/fabric-main/gossip/comm/comm_test.go (listing metadata: 1100 lines, 32 KiB, Go)

/*
Copyright IBM Corp. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/
package comm
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/tls"
"errors"
"fmt"
"io"
"math/rand"
"net"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
cb "github.com/hyperledger/fabric-protos-go/common"
proto "github.com/hyperledger/fabric-protos-go/gossip"
"github.com/hyperledger/fabric/bccsp/factory"
"github.com/hyperledger/fabric/common/flogging"
"github.com/hyperledger/fabric/common/metrics/disabled"
"github.com/hyperledger/fabric/gossip/api"
"github.com/hyperledger/fabric/gossip/api/mocks"
gmocks "github.com/hyperledger/fabric/gossip/comm/mocks"
"github.com/hyperledger/fabric/gossip/common"
"github.com/hyperledger/fabric/gossip/identity"
"github.com/hyperledger/fabric/gossip/metrics"
"github.com/hyperledger/fabric/gossip/protoext"
"github.com/hyperledger/fabric/gossip/util"
"github.com/hyperledger/fabric/internal/pkg/comm"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// init prepares shared test state: test logging, a time-based seed for
// math/rand (used for message nonces in this file), the default BCCSP
// factories, and a catch-all OrgByPeerIdentity expectation on naiveSec that
// maps any identity to an empty organization.
func init() {
	util.SetupTestLogging()
	seed := time.Now().UnixNano()
	rand.Seed(seed)
	factory.InitFactories(nil)
	emptyOrg := api.OrgIdentityType{}
	naiveSec.On("OrgByPeerIdentity", mock.Anything).Return(emptyOrg)
}
// testCommConfig is the comm configuration used by every instance created in
// this file: a short 300ms dial timeout, with the Def* constants for the
// connection timeout and the send/receive buffer sizes.
var testCommConfig = CommConfig{
	ConnTimeout:  DefConnTimeout,
	DialTimeout:  300 * time.Millisecond,
	SendBuffSize: DefSendBuffSize,
	RecvBuffSize: DefRecvBuffSize,
}
// acceptAll is a message predicate that matches every message, whatever its
// type or value.
func acceptAll(_ interface{}) bool {
	const matchEverything = true
	return matchEverything
}
// noopPurgeIdentity is an identity-purge callback that ignores both of its
// arguments; it is passed to identity.NewIdentityMapper in these tests.
var noopPurgeIdentity = func(common.PKIidType, api.PeerIdentityType) {}
// Shared fixtures for the tests in this file.
var (
	// hmacKey is the key for the HMAC-SHA256 "signatures" used in these tests.
	hmacKey = []byte{0, 0, 0}
	// naiveSec is the default security provider; init registers a catch-all
	// OrgByPeerIdentity expectation on it.
	naiveSec = &naiveSecProvider{}
	// disabledMetrics are comm metrics backed by the disabled provider.
	disabledMetrics = metrics.NewGossipMetrics(&disabled.Provider{}).CommMetrics
)
// naiveSecProvider is a permissive security provider for tests. It embeds the
// SecurityAdvisor mock so that OrgByPeerIdentity can be scripted with .On(...).
// Identity and block checks always succeed, and signatures are an HMAC-SHA256
// over hmacKey, regardless of the signer's identity.
type naiveSecProvider struct {
	mocks.SecurityAdvisor
}

// OrgByPeerIdentity returns the organization that the embedded mock was
// configured to return for identity.
func (nsp *naiveSecProvider) OrgByPeerIdentity(identity api.PeerIdentityType) api.OrgIdentityType {
	ret := nsp.SecurityAdvisor.Called(identity)
	return ret.Get(0).(api.OrgIdentityType)
}

// Expiration reports every identity as expiring one hour from now.
func (*naiveSecProvider) Expiration(peerIdentity api.PeerIdentityType) (time.Time, error) {
	expiresAt := time.Now().Add(time.Hour)
	return expiresAt, nil
}

// ValidateIdentity accepts every identity.
func (*naiveSecProvider) ValidateIdentity(peerIdentity api.PeerIdentityType) error {
	return nil
}

// GetPKIidOfCert uses the raw identity bytes as the PKI-ID.
func (*naiveSecProvider) GetPKIidOfCert(peerIdentity api.PeerIdentityType) common.PKIidType {
	return common.PKIidType(peerIdentity)
}

// VerifyBlock accepts every block.
func (*naiveSecProvider) VerifyBlock(channelID common.ChannelID, seqNum uint64, signedBlock *cb.Block) error {
	return nil
}

// VerifyBlockAttestation accepts every block attestation.
func (*naiveSecProvider) VerifyBlockAttestation(channelID string, signedBlock *cb.Block) error {
	return nil
}

// hmacOf computes the HMAC-SHA256 of data under hmacKey.
func (*naiveSecProvider) hmacOf(data []byte) []byte {
	h := hmac.New(sha256.New, hmacKey)
	h.Write(data)
	return h.Sum(nil)
}

// Sign returns the HMAC-SHA256 of msg under hmacKey; it never fails.
func (nsp *naiveSecProvider) Sign(msg []byte) ([]byte, error) {
	return nsp.hmacOf(msg), nil
}

// Verify succeeds when signature equals the HMAC-SHA256 of message under
// hmacKey. peerIdentity is ignored: every peer shares the same key.
func (nsp *naiveSecProvider) Verify(peerIdentity api.PeerIdentityType, signature, message []byte) error {
	if bytes.Equal(signature, nsp.hmacOf(message)) {
		return nil
	}
	return fmt.Errorf("Wrong certificate:%v, %v", signature, message)
}

// VerifyByChannel accepts every signature on every channel.
func (*naiveSecProvider) VerifyByChannel(_ common.ChannelID, _ api.PeerIdentityType, _, _ []byte) error {
	return nil
}
// newCommInstanceOnlyWithMetrics builds a comm instance on an existing gRPC
// server and starts that server in the background. It uses sec for identity
// and crypto, and commMetrics for metrics. The instance's identity (and hence
// its PKI-ID under naiveSecProvider) is "127.0.0.1:<port>", with the port
// taken from gRPCServer's address. The returned value is a *commGRPC, whose
// Stop also stops the identity mapper and the gRPC server.
func newCommInstanceOnlyWithMetrics(t *testing.T, commMetrics *metrics.CommMetrics, sec *naiveSecProvider,
	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
	_, portString, err := net.SplitHostPort(gRPCServer.Address())
	require.NoError(t, err)

	endpoint := fmt.Sprintf("127.0.0.1:%s", portString)
	id := []byte(endpoint)
	identityMapper := identity.NewIdentityMapper(sec, id, noopPurgeIdentity, sec)

	commInst, err := NewCommInstance(gRPCServer.Server(), certs, identityMapper, id, secureDialOpts,
		sec, commMetrics, testCommConfig, dialOpts...)
	require.NoError(t, err)

	go func() {
		// require.NoError would call t.FailNow, which may only be called from
		// the goroutine running the test. Report with t.Errorf instead. Also
		// tolerate the error returned when the server was stopped before it
		// started serving (e.g. a very short test).
		if err := gRPCServer.Start(); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Errorf("gRPC server failed: %v", err)
		}
	}()

	return &commGRPC{commInst.(*commImpl), gRPCServer}
}
// commGRPC couples a commImpl with the gRPC server it is registered on, so
// that stopping the comm also tears down its identity mapper and server.
type commGRPC struct {
	*commImpl
	gRPCServer *comm.GRPCServer
}

// Stop shuts down, in this order: the comm instance, its identity mapper and
// the underlying gRPC server.
func (c *commGRPC) Stop() {
	impl := c.commImpl
	impl.Stop()
	impl.idMapper.Stop()
	c.gRPCServer.Stop()
}
// newCommInstanceOnly is newCommInstanceOnlyWithMetrics with metrics disabled.
func newCommInstanceOnly(t *testing.T, sec *naiveSecProvider,
	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
	commMetrics := disabledMetrics
	return newCommInstanceOnlyWithMetrics(t, commMetrics, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
}
// newCommInstance creates a fresh gRPC layer via util.CreateGRPCLayer and a
// comm instance served on it. It returns the instance and the layer's port.
func newCommInstance(t *testing.T, sec *naiveSecProvider) (c Comm, port int) {
	port, gRPCServer, certs, secureDialOpts, dialOpts := util.CreateGRPCLayer()
	c = newCommInstanceOnly(t, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
	return c, port
}
// msgMutator rewrites a connection message before handshaker sends it, which
// lets tests exercise the handshake's negative paths.
type msgMutator func(*protoext.SignedGossipMessage) *protoext.SignedGossipMessage

// tlsType selects how handshaker secures its client connection.
type tlsType int

const (
	none      tlsType = iota // plaintext gRPC (grpc.WithInsecure)
	oneWayTLS                // TLS without a client certificate
	mutualTLS                // TLS presenting a client certificate
)
// handshaker dials the comm instance listening on port and performs the
// client side of the gossip handshake. It impersonates a peer whose PKI-ID and
// identity are both endpoint. connType chooses plaintext, one-way TLS or
// mutual TLS. The function returns the accept channel that it registers on
// comm before dialing.
//
// connMutator is applied to the outgoing connection message, so callers can
// tamper with it to exercise negative paths. If reading the remote side's
// connection message fails (the handshake was rejected or timed out), the
// accept channel is returned immediately so the caller can assert that
// nothing was accepted. Otherwise the remote connection message is validated
// and one gossip message is sent asynchronously.
//
// NOTE(review): the client connection is never closed. This is presumably so
// the asynchronous send below can complete after return. Confirm before
// adding a Close.
func handshaker(port int, endpoint string, comm Comm, t *testing.T, connMutator msgMutator, connType tlsType) <-chan protoext.ReceivedMessage {
	c := &commImpl{}
	cert := GenerateCertificatesOrPanic()
	tlsCfg := &tls.Config{
		InsecureSkipVerify: true,
	}
	if connType == mutualTLS {
		tlsCfg.Certificates = []tls.Certificate{cert}
	}
	secureOpts := grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg))
	if connType == none {
		secureOpts = grpc.WithInsecure()
	}

	acceptChan := comm.Accept(acceptAll)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	target := fmt.Sprintf("127.0.0.1:%d", port)
	conn, err := grpc.DialContext(ctx, target, secureOpts, grpc.WithBlock())
	require.NoError(t, err, "%v", err)
	cl := proto.NewGossipClient(conn)
	stream, err := cl.GossipStream(context.Background())
	require.NoError(t, err, "%v", err)

	// Only a mutual-TLS client presents a certificate, so only then is there a
	// certificate hash to bind into the connection message.
	var clientCertHash []byte
	if len(tlsCfg.Certificates) > 0 {
		clientCertHash = certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
	}
	pkiID := common.PKIidType(endpoint)
	msg, err := c.createConnectionMsg(pkiID, clientCertHash, []byte(endpoint), func(msg []byte) ([]byte, error) {
		mac := hmac.New(sha256.New, hmacKey)
		mac.Write(msg)
		return mac.Sum(nil), nil
	}, false)
	require.NoError(t, err, "%v", err)

	// Mutate connection message to test negative paths
	msg = connMutator(msg)
	// Send your own connection message. The error is deliberately not checked:
	// on negative paths the remote side may already be tearing the stream
	// down, and any failure surfaces through Recv below.
	stream.Send(msg.Envelope)

	// Wait for connection message from the other side
	envelope, err := stream.Recv()
	if err != nil {
		return acceptChan
	}
	msg, err = protoext.EnvelopeToGossipMessage(envelope)
	require.NoError(t, err, "%v", err)
	require.Equal(t, []byte(target), msg.GetConn().PkiId)
	require.Equal(t, extractCertificateHashFromContext(stream.Context()), msg.GetConn().TlsCertHash)

	msg2Send := createGossipMsg()
	msg2Send.Nonce = uint64(rand.Int())
	go stream.Send(msg2Send.Envelope)
	return acceptChan
}
// TestMutualParallelSendWithAck has two instances concurrently fire many
// (1000) SendWithAck calls at each other at the same time. It waits until
// comm2 has received all of them, while comm1's incoming data is drained in
// the background.
func TestMutualParallelSendWithAck(t *testing.T) {
	const msgNum = 1000

	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, port2 := newCommInstance(t, naiveSec)
	defer comm1.Stop()
	defer comm2.Stop()

	onlyData := func(o interface{}) bool {
		received := o.(protoext.ReceivedMessage)
		return protoext.IsDataMsg(received.GetGossipMessage().GossipMessage)
	}
	inc1 := comm1.Accept(onlyData)
	inc2 := comm2.Accept(onlyData)

	// Warm-up: one message from comm1 establishes a connection before the burst.
	comm1.Send(createGossipMsg(), remotePeer(port2))
	<-inc2

	for n := 0; n < msgNum; n++ {
		go comm1.SendWithAck(createGossipMsg(), 5*time.Second, 1, remotePeer(port2))
	}
	for n := 0; n < msgNum; n++ {
		go comm2.SendWithAck(createGossipMsg(), 5*time.Second, 1, remotePeer(port1))
	}

	go func() {
		for n := 0; n < msgNum; n++ {
			<-inc1
		}
	}()
	for n := 0; n < msgNum; n++ {
		<-inc2
	}
}
// getAvailablePort binds a TCP listener to an ephemeral loopback port. It
// returns the port, the "host:port" endpoint and the still-open listener,
// which the caller is responsible for closing.
func getAvailablePort(t *testing.T) (port int, endpoint string, ll net.Listener) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	addr := listener.Addr().String()
	_, rawPort, err := net.SplitHostPort(addr)
	require.NoError(t, err)

	p, err := strconv.Atoi(rawPort)
	require.NoError(t, err)
	return p, addr, listener
}
// TestHandshake exercises the server side of the gossip connection handshake.
// Two positive paths must surface the initiator's message with the expected
// connection info: plaintext gRPC, and mutual TLS with a client certificate.
// Each negative path must leave the accept channel empty: no client
// certificate, bad signature, PKI-ID/identity mismatch, wrong TLS cert hash,
// missing PKI-ID, wrong message type, and a late connection message.
func TestHandshake(t *testing.T) {
	signer := func(msg []byte) ([]byte, error) {
		mac := hmac.New(sha256.New, hmacKey)
		mac.Write(msg)
		return mac.Sum(nil), nil
	}
	mutator := func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		return msg
	}
	// assertPositivePath checks that msg came from the peer at endpoint and
	// carries a correctly signed authentication record.
	assertPositivePath := func(msg protoext.ReceivedMessage, endpoint string) {
		expectedPKIID := common.PKIidType(endpoint)
		require.Equal(t, expectedPKIID, msg.GetConnectionInfo().ID)
		require.Equal(t, api.PeerIdentityType(endpoint), msg.GetConnectionInfo().Identity)
		require.NotNil(t, msg.GetConnectionInfo().Auth)
		sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
		require.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
	}

	// Positive path 1 - check authentication without TLS
	port, endpoint, ll := getAvailablePort(t)
	s := grpc.NewServer()
	id := []byte(endpoint)
	idMapper := identity.NewIdentityMapper(naiveSec, id, noopPurgeIdentity, naiveSec)
	inst, err := NewCommInstance(s, nil, idMapper, api.PeerIdentityType(endpoint), func() []grpc.DialOption {
		return []grpc.DialOption{grpc.WithInsecure()}
	}, naiveSec, disabledMetrics, testCommConfig)
	// Check the construction error before serving anything on the listener.
	require.NoError(t, err)
	go s.Serve(ll)

	var msg protoext.ReceivedMessage
	_, tempEndpoint, tempL := getAvailablePort(t)
	acceptChan := handshaker(port, tempEndpoint, inst, t, mutator, none)
	select {
	case <-time.After(4 * time.Second):
		require.FailNow(t, "Didn't receive a message, seems like handshake failed")
	case msg = <-acceptChan:
	}
	// Same checks as the TLS path. This also fails cleanly, rather than with a
	// nil dereference, if no authentication info was attached.
	assertPositivePath(msg, tempEndpoint)
	inst.Stop()
	// commGRPC.Stop stops the identity mapper separately from the comm
	// instance; do the same for this hand-built instance.
	idMapper.Stop()
	s.Stop()
	ll.Close()
	tempL.Close()
	time.Sleep(time.Second)

	commInst, port := newCommInstance(t, naiveSec)
	defer commInst.Stop()

	// assertRejected handshakes as peerEndpoint, with mutate applied to the
	// outgoing connection message, and requires that nothing is accepted.
	assertRejected := func(peerEndpoint string, mutate msgMutator, connType tlsType) {
		acceptChan := handshaker(port, peerEndpoint, commInst, t, mutate, connType)
		time.Sleep(time.Second)
		require.Equal(t, 0, len(acceptChan))
	}

	// Positive path 2: initiating peer sends its own certificate
	_, tempEndpoint, tempL = getAvailablePort(t)
	acceptChan = handshaker(port, tempEndpoint, commInst, t, mutator, mutualTLS)
	select {
	case <-time.After(2 * time.Second):
		require.FailNow(t, "Didn't receive a message, seems like handshake failed")
	case msg = <-acceptChan:
	}
	assertPositivePath(msg, tempEndpoint)
	tempL.Close()

	// Negative path: initiating peer doesn't send its own certificate
	_, tempEndpoint, tempL = getAvailablePort(t)
	assertRejected(tempEndpoint, mutator, oneWayTLS)
	tempL.Close()

	// Negative path, signature is wrong
	_, tempEndpoint, tempL = getAvailablePort(t)
	assertRejected(tempEndpoint, func(m *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		m.Signature = append(m.Signature, 0)
		return m
	}, mutualTLS)
	tempL.Close()

	// Negative path, the PKIid doesn't match the identity
	_, tempEndpoint, tempL = getAvailablePort(t)
	_, tempEndpoint2, tempL2 := getAvailablePort(t)
	assertRejected(tempEndpoint2, func(m *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		m.GetConn().PkiId = []byte(tempEndpoint)
		// Sign the message again
		m.Sign(signer)
		return m
	}, mutualTLS)
	tempL.Close()
	tempL2.Close()

	// Negative path, the cert hash isn't what is expected
	_, tempEndpoint, tempL = getAvailablePort(t)
	assertRejected(tempEndpoint, func(m *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		m.GetConn().TlsCertHash = append(m.GetConn().TlsCertHash, 0)
		m.Sign(signer)
		return m
	}, mutualTLS)
	tempL.Close()

	// Negative path, no PKI-ID was sent
	_, tempEndpoint, tempL = getAvailablePort(t)
	assertRejected(tempEndpoint, func(m *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		m.GetConn().PkiId = nil
		m.Sign(signer)
		return m
	}, mutualTLS)
	tempL.Close()

	// Negative path, connection message is of a different type
	_, tempEndpoint, tempL = getAvailablePort(t)
	assertRejected(tempEndpoint, func(m *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		m.Content = &proto.GossipMessage_Empty{
			Empty: &proto.Empty{},
		}
		m.Sign(signer)
		return m
	}, mutualTLS)
	tempL.Close()

	// Negative path, the peer didn't respond to the handshake in due time
	_, tempEndpoint, tempL = getAvailablePort(t)
	assertRejected(tempEndpoint, func(m *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		time.Sleep(time.Second * 5)
		return m
	}, mutualTLS)
	tempL.Close()
}
// TestConnectUnexpectedPeer has comm1 send to comm2 (org "A") or comm3 (org
// "B") while expecting the PKI-ID of comm4 (org "A"). The message should only
// arrive when the contacted peer is in the same organization as comm4.
func TestConnectUnexpectedPeer(t *testing.T) {
	identityByPort := func(port int) api.PeerIdentityType {
		return api.PeerIdentityType(fmt.Sprintf("127.0.0.1:%d", port))
	}
	customNaiveSec := &naiveSecProvider{}
	setOrg := func(port int, org string) {
		customNaiveSec.On("OrgByPeerIdentity", identityByPort(port)).Return(api.OrgIdentityType(org))
	}

	comm1Port, gRPCServer1, certs1, secureDialOpts1, dialOpts1 := util.CreateGRPCLayer()
	comm2Port, gRPCServer2, certs2, secureDialOpts2, dialOpts2 := util.CreateGRPCLayer()
	comm3Port, gRPCServer3, certs3, secureDialOpts3, dialOpts3 := util.CreateGRPCLayer()
	comm4Port, gRPCServer4, certs4, secureDialOpts4, dialOpts4 := util.CreateGRPCLayer()

	setOrg(comm1Port, "O")
	setOrg(comm2Port, "A")
	setOrg(comm3Port, "B")
	setOrg(comm4Port, "A")

	comm1 := newCommInstanceOnly(t, customNaiveSec, gRPCServer1, certs1, secureDialOpts1, dialOpts1...)
	comm2 := newCommInstanceOnly(t, naiveSec, gRPCServer2, certs2, secureDialOpts2, dialOpts2...)
	comm3 := newCommInstanceOnly(t, naiveSec, gRPCServer3, certs3, secureDialOpts3, dialOpts3...)
	comm4 := newCommInstanceOnly(t, naiveSec, gRPCServer4, certs4, secureDialOpts4, dialOpts4...)
	defer comm1.Stop()
	defer comm2.Stop()
	defer comm3.Stop()
	defer comm4.Stop()

	messagesForComm1 := comm1.Accept(acceptAll)
	messagesForComm2 := comm2.Accept(acceptAll)
	messagesForComm3 := comm3.Accept(acceptAll)

	// A message from comm4 lets comm1 learn comm4's identity and organization;
	// the connection with comm4 is then closed.
	comm4.Send(createGossipMsg(), remotePeer(comm1Port))
	<-messagesForComm1
	comm1.CloseConn(remotePeer(comm4Port))

	// impersonating pairs the address of the peer on port with comm4's PKI-ID.
	impersonating := func(port int) *RemotePeer {
		peer := remotePeer(port)
		peer.PKIID = remotePeer(comm4Port).PKIID
		return peer
	}

	t.Run("Same organization", func(t *testing.T) {
		comm1.Send(createGossipMsg(), impersonating(comm2Port))
		select {
		case <-messagesForComm2:
		case <-time.After(5 * time.Second):
			require.Fail(t, "Didn't receive a message within a timely manner")
			util.PrintStackTrace()
		}
	})

	t.Run("Unexpected organization", func(t *testing.T) {
		comm1.Send(createGossipMsg(), impersonating(comm3Port))
		select {
		case <-messagesForComm3:
			require.Fail(t, "Message shouldn't have been received")
		case <-time.After(5 * time.Second):
		}
	})
}
// TestGetConnectionInfo checks that a received message reports the sender's
// PKI-ID in its connection info and carries its source envelope.
func TestGetConnectionInfo(t *testing.T) {
	receiver, receiverPort := newCommInstance(t, naiveSec)
	sender, _ := newCommInstance(t, naiveSec)
	defer receiver.Stop()
	defer sender.Stop()

	incoming := receiver.Accept(acceptAll)
	sender.Send(createGossipMsg(), remotePeer(receiverPort))

	select {
	case msg := <-incoming:
		require.Equal(t, sender.GetPKIid(), msg.GetConnectionInfo().ID)
		require.NotNil(t, msg.GetSourceEnvelope())
	case <-time.After(10 * time.Second):
		t.Fatal("Didn't receive a message in time")
	}
}
// TestCloseConn opens a raw gossip stream to comm1 and completes the
// handshake as peer "pkiID". It sends one message, then has comm1 close the
// connection to that peer. Afterwards, repeatedly sending large messages on
// the stream must eventually fail.
func TestCloseConn(t *testing.T) {
	comm1, port1 := newCommInstance(t, naiveSec)
	defer comm1.Stop()
	acceptChan := comm1.Accept(acceptAll)

	cert := GenerateCertificatesOrPanic()
	tlsCfg := &tls.Config{
		InsecureSkipVerify: true,
		Certificates:       []tls.Certificate{cert},
	}
	ta := credentials.NewTLS(tlsCfg)
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	target := fmt.Sprintf("127.0.0.1:%d", port1)
	conn, err := grpc.DialContext(ctx, target, grpc.WithTransportCredentials(ta), grpc.WithBlock())
	require.NoError(t, err, "%v", err)
	// Release the client connection and its goroutines when the test ends.
	defer conn.Close()
	cl := proto.NewGossipClient(conn)
	stream, err := cl.GossipStream(context.Background())
	require.NoError(t, err, "%v", err)

	c := &commImpl{}
	tlsCertHash := certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
	connMsg, err := c.createConnectionMsg(common.PKIidType("pkiID"), tlsCertHash, api.PeerIdentityType("pkiID"), func(msg []byte) ([]byte, error) {
		mac := hmac.New(sha256.New, hmacKey)
		mac.Write(msg)
		return mac.Sum(nil), nil
	}, false)
	require.NoError(t, err)
	require.NoError(t, stream.Send(connMsg.Envelope))
	require.NoError(t, stream.Send(createGossipMsg().Envelope))
	select {
	case <-acceptChan:
	case <-time.After(time.Second):
		require.Fail(t, "Didn't receive a message within a timely period")
	}

	comm1.CloseConn(&RemotePeer{PKIID: common.PKIidType("pkiID")})
	time.Sleep(time.Second * 10)

	// Large payloads fill the transport quickly, so a closed connection shows
	// up as a Send error within DefRecvBuffSize attempts.
	msg2Send := createGossipMsg()
	msg2Send.GetDataMsg().Payload = &proto.Payload{
		Data: make([]byte, 1024*1024),
	}
	protoext.NoopSign(msg2Send.GossipMessage)
	gotErr := false
	for i := 0; i < DefRecvBuffSize; i++ {
		if err := stream.Send(msg2Send.Envelope); err != nil {
			gotErr = true
			break
		}
	}
	require.True(t, gotErr, "Should have failed because connection is closed")
}
// TestCommSend makes sure that enough messages get through
// eventually. Comm.Send() is both asynchronous and best-effort, so this test
// case assumes some will fail, but that eventually enough messages will get
// through that the test will end.
func TestCommSend(t *testing.T) {
	// sendMessages sends a fresh message to peer every millisecond until
	// stopChan is closed.
	sendMessages := func(c Comm, peer *RemotePeer, stopChan <-chan struct{}) {
		ticker := time.NewTicker(time.Millisecond)
		defer ticker.Stop()
		for {
			emptyMsg := createGossipMsg()
			select {
			case <-stopChan:
				return
			case <-ticker.C:
				c.Send(emptyMsg, peer)
			}
		}
	}

	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, port2 := newCommInstance(t, naiveSec)
	defer comm1.Stop()
	defer comm2.Stop()

	// Create the receive channel before sending the messages
	ch1 := comm1.Accept(acceptAll)
	ch2 := comm2.Accept(acceptAll)

	// Control channels for the background senders. Each close is wrapped in a
	// sync.Once so that it can happen either when the peer has received enough
	// or, via defer, when the test exits early (e.g. on timeout). The deferred
	// closes run before the comm instances are stopped (defers run LIFO).
	stopch1 := make(chan struct{})
	stopch2 := make(chan struct{})
	var stopOnce1, stopOnce2 sync.Once
	stopSender1 := func() { stopOnce1.Do(func() { close(stopch1) }) }
	stopSender2 := func() { stopOnce2.Do(func() { close(stopch2) }) }
	defer stopSender1()
	defer stopSender2()

	go sendMessages(comm1, remotePeer(port2), stopch1)
	go sendMessages(comm2, remotePeer(port1), stopch2)

	c1received := 0
	c2received := 0
	// hopefully in some runs we'll fill both send and receive buffers and
	// drop overflowing messages, but still finish, because the endless
	// stream of messages inexorably gets through unless something is very
	// broken.
	totalMessagesReceived := (DefSendBuffSize + DefRecvBuffSize) * 2
	timer := time.NewTimer(30 * time.Second)
	defer timer.Stop()

	// Block on the channels instead of spinning through a default case; the
	// loop condition re-checks completion after every receive.
	for c1received < totalMessagesReceived || c2received < totalMessagesReceived {
		select {
		case <-ch1:
			c1received++
			if c1received == totalMessagesReceived {
				stopSender2()
			}
		case <-ch2:
			c2received++
			if c2received == totalMessagesReceived {
				stopSender1()
			}
		case <-timer.C:
			t.Fatalf("timed out waiting for messages to be received.\nc1 got %d messages\nc2 got %d messages", c1received, c2received)
		}
	}
	t.Logf("c1 got %d messages\nc2 got %d messages", c1received, c2received)
}
// nonResponsivePeer is a gossip server whose Ping stalls for 15 seconds and
// whose GossipStream returns immediately.
type nonResponsivePeer struct {
	*grpc.Server
	port int
}

// newNonResponsivePeer registers a nonResponsivePeer on a freshly created gRPC
// layer and returns it.
// NOTE(review): no Start/Serve call is made on the server here, so nothing
// serves gRPC on port. This is presumably intentional, since that is
// non-responsive too — confirm.
func newNonResponsivePeer(t *testing.T) *nonResponsivePeer {
	port, gRPCServer, _, _, _ := util.CreateGRPCLayer()
	srv := gRPCServer.Server()
	peer := &nonResponsivePeer{Server: srv, port: port}
	proto.RegisterGossipServer(srv, peer)
	return peer
}

// Ping stalls for 15 seconds before replying with an empty message.
func (p *nonResponsivePeer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
	time.Sleep(15 * time.Second)
	return &proto.Empty{}, nil
}

// GossipStream ends every stream immediately without error.
func (p *nonResponsivePeer) GossipStream(proto.Gossip_GossipStreamServer) error {
	return nil
}

// stop stops the underlying gRPC server.
func (p *nonResponsivePeer) stop() {
	p.Server.Stop()
}
// TestNonResponsivePing checks that probing a peer that never answers returns
// within 10 seconds instead of hanging.
func TestNonResponsivePing(t *testing.T) {
	c, _ := newCommInstance(t, naiveSec)
	defer c.Stop()
	nonRespPeer := newNonResponsivePeer(t)
	defer nonRespPeer.stop()

	// Buffered so the prober goroutine can always deliver its signal and exit,
	// even after the test has stopped waiting for it.
	s := make(chan struct{}, 1)
	go func() {
		// The probe is expected to fail; only its termination matters here.
		c.Probe(remotePeer(nonRespPeer.port))
		s <- struct{}{}
	}()
	select {
	case <-time.After(time.Second * 10):
		require.Fail(t, "Request wasn't cancelled on time")
	case <-s:
	}
}
// TestResponses checks the response round trip. comm1 answers every incoming
// message via Respond with a message whose nonce is the original nonce + 1,
// and comm2 must receive that reply within 10 seconds.
func TestResponses(t *testing.T) {
	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, _ := newCommInstance(t, naiveSec)
	defer comm1.Stop()
	defer comm2.Stop()

	wg := sync.WaitGroup{}
	msg := createGossipMsg()

	// The responder signals through wg once its accept channel is registered,
	// so the message below cannot arrive before anyone is listening.
	wg.Add(1)
	go func() {
		inChan := comm1.Accept(acceptAll)
		wg.Done()
		for m := range inChan {
			reply := createGossipMsg()
			reply.Nonce = m.GetGossipMessage().Nonce + 1
			m.Respond(reply.GossipMessage)
		}
	}()

	expectedNOnce := uint64(msg.Nonce + 1)
	responsesFromComm1 := comm2.Accept(acceptAll)
	// One-shot timeout, released on every exit path.
	timeout := time.NewTimer(10 * time.Second)
	defer timeout.Stop()
	wg.Wait()
	comm2.Send(msg, remotePeer(port1))
	select {
	case <-timeout.C:
		require.Fail(t, "Haven't got response from comm1 within a timely manner")
	case resp := <-responsesFromComm1:
		require.Equal(t, expectedNOnce, resp.GetGossipMessage().Nonce)
	}
}
// TestAccept makes sure that accept filters work: two filters on the same
// instance, one for even nonces and one for odd nonces, must each see only
// the messages they select. Nonces are random, so the probability that all of
// them share the same parity is very low; both result sets are expected to be
// non-empty.
func TestAccept(t *testing.T) {
	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, _ := newCommInstance(t, naiveSec)

	nonceOf := func(m interface{}) uint64 {
		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce
	}
	evenNONCES := comm1.Accept(func(m interface{}) bool { return nonceOf(m)%2 == 0 })
	oddNONCES := comm1.Accept(func(m interface{}) bool { return nonceOf(m)%2 != 0 })

	var evenResults, oddResults []uint64
	out := make(chan uint64)
	sem := make(chan struct{})

	// collect appends every nonce from ch to *dst and forwards it to out when
	// someone is listening. It signals sem once ch is closed.
	collect := func(dst *[]uint64, ch <-chan protoext.ReceivedMessage) {
		for m := range ch {
			nonce := m.GetGossipMessage().Nonce
			*dst = append(*dst, nonce)
			select {
			case out <- nonce:
			default: // avoid blocking when we stop reading from out
			}
		}
		sem <- struct{}{}
	}
	go collect(&evenResults, evenNONCES)
	go collect(&oddResults, oddNONCES)

	stopSend := make(chan struct{})
	go func() {
		for {
			select {
			case <-stopSend:
				return
			default:
				comm2.Send(createGossipMsg(), remotePeer(port1))
			}
		}
	}()

	waitForMessages(t, out, (DefSendBuffSize+DefRecvBuffSize)*2, "Didn't receive all messages sent")
	close(stopSend)
	comm1.Stop()
	comm2.Stop()
	// Both collectors must have drained their channels before the slices are read.
	<-sem
	<-sem

	t.Logf("%d even nonces received", len(evenResults))
	t.Logf("%d odd nonces received", len(oddResults))
	require.NotEmpty(t, evenResults)
	require.NotEmpty(t, oddResults)

	for _, n := range evenResults {
		require.Equal(t, n%2, uint64(0))
	}
	for _, n := range oddResults {
		require.Equal(t, n%2, uint64(1))
	}
}
// TestReConnections checks that comm2 can reach comm1 after comm1 is stopped
// and replaced by a new instance.
func TestReConnections(t *testing.T) {
	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, port2 := newCommInstance(t, naiveSec)

	// relay forwards the nonce of every message from in to out. It returns
	// when in yields nil, which includes in being closed.
	relay := func(out chan uint64, in <-chan protoext.ReceivedMessage) {
		for m := range in {
			if m == nil {
				return
			}
			out <- m.GetGossipMessage().Nonce
		}
	}

	out1 := make(chan uint64, 10)
	out2 := make(chan uint64, 10)
	go relay(out1, comm1.Accept(acceptAll))
	go relay(out2, comm2.Accept(acceptAll))

	// comm1 connects to comm2
	comm1.Send(createGossipMsg(), remotePeer(port2))
	waitForMessages(t, out2, 1, "Comm2 didn't receive a message from comm1 in a timely manner")
	// comm2 sends to comm1
	comm2.Send(createGossipMsg(), remotePeer(port1))
	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")

	// Replace comm1 with a fresh instance; comm2 must still be able to reach it.
	comm1.Stop()
	comm1, port1 = newCommInstance(t, naiveSec)
	out1 = make(chan uint64, 1)
	go relay(out1, comm1.Accept(acceptAll))
	comm2.Send(createGossipMsg(), remotePeer(port1))
	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")

	comm1.Stop()
	comm2.Stop()
}
// TestProbe checks Probe and Handshake against a live peer, a port without a
// gossip service, a stopped peer and a restarted peer. It also checks a
// handshake with a mismatching PKI-ID and one with no PKI-ID at all.
func TestProbe(t *testing.T) {
	comm1, port1 := newCommInstance(t, naiveSec)
	defer comm1.Stop()
	comm2, port2 := newCommInstance(t, naiveSec)
	time.Sleep(time.Second)

	// reachable requires that from can both probe and handshake the peer on port.
	reachable := func(from Comm, port int) {
		require.NoError(t, from.Probe(remotePeer(port)))
		_, err := from.Handshake(remotePeer(port))
		require.NoError(t, err)
	}
	// unreachable requires that both the probe and the handshake fail.
	unreachable := func(from Comm, port int) {
		require.Error(t, from.Probe(remotePeer(port)))
		_, err := from.Handshake(remotePeer(port))
		require.Error(t, err)
	}

	reachable(comm1, port2)

	// A port held by a bare TCP listener.
	tempPort, _, ll := getAvailablePort(t)
	defer ll.Close()
	unreachable(comm1, tempPort)

	// A stopped instance.
	comm2.Stop()
	time.Sleep(time.Second)
	unreachable(comm1, port2)

	// A fresh instance, in both directions.
	comm2, port2 = newCommInstance(t, naiveSec)
	defer comm2.Stop()
	time.Sleep(time.Second)
	reachable(comm2, port1)
	reachable(comm1, port2)

	// Now try a deep probe with an expected PKI-ID that doesn't match
	wrongRemotePeer := remotePeer(port2)
	if wrongRemotePeer.PKIID[0] != 0 {
		wrongRemotePeer.PKIID[0] = 0
	} else {
		wrongRemotePeer.PKIID[0] = 1
	}
	_, err := comm1.Handshake(wrongRemotePeer)
	require.Error(t, err)

	// Try a deep probe with a nil PKI-ID
	endpoint := fmt.Sprintf("127.0.0.1:%d", port2)
	id, err := comm1.Handshake(&RemotePeer{Endpoint: endpoint})
	require.NoError(t, err)
	require.Equal(t, api.PeerIdentityType(endpoint), id)
}
// TestPresumedDead checks that once a connected peer (comm2) is stopped,
// further sends to it make comm1 emit on its PresumedDead channel within
// 3 seconds.
func TestPresumedDead(t *testing.T) {
	comm1, _ := newCommInstance(t, naiveSec)
	// comm1 was previously never stopped, leaking its goroutines and server.
	defer comm1.Stop()
	comm2, port2 := newCommInstance(t, naiveSec)

	// The sender waits on wg until comm2's accept channel is registered, so
	// the first message cannot be missed.
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		wg.Wait()
		comm1.Send(createGossipMsg(), remotePeer(port2))
	}()
	acceptCh := comm2.Accept(acceptAll)
	wg.Done()
	select {
	case <-acceptCh:
	case <-time.After(10 * time.Second):
		require.Fail(t, "Didn't get first message")
	}

	comm2.Stop()
	go func() {
		for i := 0; i < 5; i++ {
			comm1.Send(createGossipMsg(), remotePeer(port2))
			time.Sleep(time.Millisecond * 200)
		}
	}()

	select {
	case <-time.After(3 * time.Second):
		require.Fail(t, "Didn't get a presumed dead message within a timely manner")
	case <-comm1.PresumedDead():
	}
}
// TestReadFromStream feeds readFromStream a malformed envelope followed by a
// stream error. Within 100ms no message may be delivered, and exactly one
// error must be reported.
func TestReadFromStream(t *testing.T) {
	stream := &gmocks.MockStream{}
	stream.On("CloseSend").Return(nil)
	// First Recv: an envelope with an invalid payload; second Recv: an error.
	stream.On("Recv").Return(&proto.Envelope{Payload: []byte{1}}, nil).Once()
	stream.On("Recv").Return(nil, errors.New("stream closed")).Once()

	conn := newConnection(nil, nil, stream, disabledMetrics, ConnConfig{1, 1})
	conn.logger = flogging.MustGetLogger("test")

	errChan := make(chan error, 2)
	msgChan := make(chan *protoext.SignedGossipMessage, 1)
	readerDone := make(chan struct{})
	go func() {
		defer close(readerDone)
		conn.readFromStream(errChan, msgChan)
	}()

	select {
	case <-time.After(100 * time.Millisecond):
		require.Len(t, errChan, 1)
	case <-msgChan:
		require.Fail(t, "malformed message shouldn't have been received")
	}

	conn.close()
	<-readerDone
}
// TestSendBadEnvelope checks that a well-formed message gets through. It then
// checks that repeatedly sending a corrupted envelope makes the remote peer
// close the stream, which is observed as io.EOF on Send, within 30 seconds.
func TestSendBadEnvelope(t *testing.T) {
	comm1, port := newCommInstance(t, naiveSec)
	defer comm1.Stop()

	stream, err := establishSession(t, port)
	require.NoError(t, err)

	inc := comm1.Accept(acceptAll)

	goodMsg := createGossipMsg()
	err = stream.Send(goodMsg.Envelope)
	require.NoError(t, err)
	select {
	case <-time.After(time.Minute):
		require.Fail(t, "Didn't receive message within a timely manner")
		return
	case received := <-inc:
		require.Equal(t, goodMsg.Envelope.Payload, received.GetSourceEnvelope().Payload)
	}

	// Next, we corrupt a message and send it until the stream is closed forcefully from the remote peer
	deadline := time.Now().Add(30 * time.Second)
	for {
		badMsg := createGossipMsg()
		badMsg.Envelope.Payload = []byte{1}
		if err = stream.Send(badMsg.Envelope); err != nil {
			require.Equal(t, io.EOF, err)
			return
		}
		if time.Now().After(deadline) {
			require.Fail(t, "Didn't close stream within a timely manner")
			return
		}
	}
}
// establishSession dials the comm instance on port over mutual TLS and sends
// a connection message as the peer with PKI-ID and identity {1, 2, 3}. It then
// waits for the remote side's connection message. On success it returns the
// open stream; a failure to send or receive the handshake is returned as an
// error.
func establishSession(t *testing.T, port int) (proto.Gossip_GossipStreamClient, error) {
	cert := GenerateCertificatesOrPanic()
	secureOpts := grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
		InsecureSkipVerify: true,
		Certificates:       []tls.Certificate{cert},
	}))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
	conn, err := grpc.DialContext(ctx, endpoint, secureOpts, grpc.WithBlock())
	require.NoError(t, err, "%v", err)
	cl := proto.NewGossipClient(conn)
	stream, err := cl.GossipStream(context.Background())
	require.NoError(t, err, "%v", err)

	clientCertHash := certHashFromRawCert(cert.Certificate[0])
	pkiID := common.PKIidType([]byte{1, 2, 3})
	c := &commImpl{}
	msg, err := c.createConnectionMsg(pkiID, clientCertHash, []byte{1, 2, 3}, func(msg []byte) ([]byte, error) {
		mac := hmac.New(sha256.New, hmacKey)
		mac.Write(msg)
		return mac.Sum(nil), nil
	}, false)
	require.NoError(t, err, "%v", err)

	// Send your own connection message
	if err := stream.Send(msg.Envelope); err != nil {
		return nil, err
	}
	// Wait for connection message from the other side
	envelope, err := stream.Recv()
	if err != nil {
		return nil, err
	}
	require.NotNil(t, envelope)
	return stream, nil
}
// createGossipMsg returns a NoopSign-signed gossip message tagged EMPTY. It
// carries an empty data message and a random nonce. NoopSign's error is
// discarded; callers receive whatever message it returns.
func createGossipMsg() *protoext.SignedGossipMessage {
	content := &proto.GossipMessage_DataMsg{DataMsg: &proto.DataMessage{}}
	gm := &proto.GossipMessage{
		Tag:     proto.GossipMessage_EMPTY,
		Nonce:   uint64(rand.Int()),
		Content: content,
	}
	signed, _ := protoext.NoopSign(gm)
	return signed
}
// remotePeer describes the local peer listening on port. Its endpoint is
// "127.0.0.1:<port>", and its PKI-ID is the endpoint's bytes.
func remotePeer(port int) *RemotePeer {
	endpoint := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
	peer := &RemotePeer{Endpoint: endpoint}
	peer.PKIID = []byte(endpoint)
	return peer
}
// waitForMessages reads from msgChan until count values have arrived or
// 10 seconds have elapsed. It then requires that exactly count values were
// received, failing with errMsg otherwise. A non-positive count returns
// immediately instead of waiting out the full timeout.
func waitForMessages(t *testing.T, msgChan chan uint64, count int, errMsg string) {
	received := 0
	// A one-shot timer, stopped on return; the previous ticker was never stopped.
	timeout := time.NewTimer(10 * time.Second)
	defer timeout.Stop()
loop:
	for received < count {
		select {
		case <-msgChan:
			received++
		case <-timeout.C:
			break loop
		}
	}
	require.Equal(t, count, received, errMsg)
}
// TestConcurrentCloseSend stops comm2 while comm1 is sending to it in a tight
// loop from another goroutine. The test passes if this neither panics nor
// deadlocks.
func TestConcurrentCloseSend(t *testing.T) {
	var stopping int32
	comm1, _ := newCommInstance(t, naiveSec)
	// comm1 was previously never stopped. The deferred Stop runs only after
	// the sender goroutine has finished (<-done below).
	defer comm1.Stop()
	comm2, port2 := newCommInstance(t, naiveSec)

	m := comm2.Accept(acceptAll)
	comm1.Send(createGossipMsg(), remotePeer(port2))
	<-m

	ready := make(chan struct{})
	done := make(chan struct{})
	go func() {
		defer close(done)
		comm1.Send(createGossipMsg(), remotePeer(port2))
		close(ready)
		for atomic.LoadInt32(&stopping) == int32(0) {
			comm1.Send(createGossipMsg(), remotePeer(port2))
		}
	}()
	<-ready
	comm2.Stop()
	atomic.StoreInt32(&stopping, int32(1))
	<-done
}