188 lines
5.2 KiB
Go
188 lines
5.2 KiB
Go
/*
|
|
Copyright IBM Corp. All Rights Reserved.
|
|
|
|
SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
package grpclogging
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
"go.uber.org/zap/zapcore"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/peer"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// Leveler returns a zap level to use when logging from a grpc interceptor.
|
|
type Leveler interface {
|
|
Level(ctx context.Context, fullMethod string) zapcore.Level
|
|
}
|
|
|
|
// PayloadLeveler gets the level to use when logging grpc message payloads.
|
|
type PayloadLeveler interface {
|
|
PayloadLevel(ctx context.Context, fullMethod string) zapcore.Level
|
|
}
|
|
|
|
//go:generate counterfeiter -o fakes/leveler.go --fake-name Leveler . LevelerFunc
|
|
|
|
type LevelerFunc func(ctx context.Context, fullMethod string) zapcore.Level
|
|
|
|
func (l LevelerFunc) Level(ctx context.Context, fullMethod string) zapcore.Level {
|
|
return l(ctx, fullMethod)
|
|
}
|
|
|
|
func (l LevelerFunc) PayloadLevel(ctx context.Context, fullMethod string) zapcore.Level {
|
|
return l(ctx, fullMethod)
|
|
}
|
|
|
|
// DefaultPayloadLevel is default level to use when logging payloads
|
|
const DefaultPayloadLevel = zapcore.Level(zapcore.DebugLevel - 1)
|
|
|
|
type options struct {
|
|
Leveler
|
|
PayloadLeveler
|
|
}
|
|
|
|
// Option customizes the options used by the server interceptors. Options are
// created with WithLeveler and WithPayloadLeveler.
type Option func(o *options)
|
|
|
|
func WithLeveler(l Leveler) Option {
|
|
return func(o *options) { o.Leveler = l }
|
|
}
|
|
|
|
func WithPayloadLeveler(l PayloadLeveler) Option {
|
|
return func(o *options) { o.PayloadLeveler = l }
|
|
}
|
|
|
|
func applyOptions(opts ...Option) *options {
|
|
o := &options{
|
|
Leveler: LevelerFunc(func(context.Context, string) zapcore.Level { return zapcore.InfoLevel }),
|
|
PayloadLeveler: LevelerFunc(func(context.Context, string) zapcore.Level { return DefaultPayloadLevel }),
|
|
}
|
|
for _, opt := range opts {
|
|
opt(o)
|
|
}
|
|
return o
|
|
}
|
|
|
|
// Levelers will be required and should be provided with the full method info
|
|
|
|
func UnaryServerInterceptor(logger *zap.Logger, opts ...Option) grpc.UnaryServerInterceptor {
|
|
o := applyOptions(opts...)
|
|
|
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
|
logger := logger
|
|
startTime := time.Now()
|
|
|
|
fields := getFields(ctx, info.FullMethod)
|
|
logger = logger.With(fields...)
|
|
ctx = WithFields(ctx, fields)
|
|
|
|
payloadLogger := logger.Named("payload")
|
|
payloadLevel := o.PayloadLevel(ctx, info.FullMethod)
|
|
if ce := payloadLogger.Check(payloadLevel, "received unary request"); ce != nil {
|
|
ce.Write(ProtoMessage("message", req))
|
|
}
|
|
|
|
resp, err := handler(ctx, req)
|
|
|
|
if ce := payloadLogger.Check(payloadLevel, "sending unary response"); ce != nil && err == nil {
|
|
ce.Write(ProtoMessage("message", resp))
|
|
}
|
|
|
|
if ce := logger.Check(o.Level(ctx, info.FullMethod), "unary call completed"); ce != nil {
|
|
st, _ := status.FromError(err)
|
|
ce.Write(
|
|
Error(err),
|
|
zap.Stringer("grpc.code", st.Code()),
|
|
zap.Duration("grpc.call_duration", time.Since(startTime)),
|
|
)
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
}
|
|
|
|
func StreamServerInterceptor(logger *zap.Logger, opts ...Option) grpc.StreamServerInterceptor {
|
|
o := applyOptions(opts...)
|
|
|
|
return func(service interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
logger := logger
|
|
ctx := stream.Context()
|
|
startTime := time.Now()
|
|
|
|
fields := getFields(ctx, info.FullMethod)
|
|
logger = logger.With(fields...)
|
|
ctx = WithFields(ctx, fields)
|
|
|
|
wrappedStream := &serverStream{
|
|
ServerStream: stream,
|
|
context: ctx,
|
|
payloadLogger: logger.Named("payload"),
|
|
payloadLevel: o.PayloadLevel(ctx, info.FullMethod),
|
|
}
|
|
|
|
err := handler(service, wrappedStream)
|
|
if ce := logger.Check(o.Level(ctx, info.FullMethod), "streaming call completed"); ce != nil {
|
|
st, _ := status.FromError(err)
|
|
ce.Write(
|
|
Error(err),
|
|
zap.Stringer("grpc.code", st.Code()),
|
|
zap.Duration("grpc.call_duration", time.Since(startTime)),
|
|
)
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
func getFields(ctx context.Context, method string) []zapcore.Field {
|
|
var fields []zap.Field
|
|
if parts := strings.Split(method, "/"); len(parts) == 3 {
|
|
fields = append(fields, zap.String("grpc.service", parts[1]), zap.String("grpc.method", parts[2]))
|
|
}
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
fields = append(fields, zap.Time("grpc.request_deadline", deadline))
|
|
}
|
|
if p, ok := peer.FromContext(ctx); ok {
|
|
fields = append(fields, zap.String("grpc.peer_address", p.Addr.String()))
|
|
if ti, ok := p.AuthInfo.(credentials.TLSInfo); ok {
|
|
if len(ti.State.PeerCertificates) > 0 {
|
|
cert := ti.State.PeerCertificates[0]
|
|
fields = append(fields, zap.String("grpc.peer_subject", cert.Subject.String()))
|
|
}
|
|
}
|
|
}
|
|
return fields
|
|
}
|
|
|
|
type serverStream struct {
|
|
grpc.ServerStream
|
|
context context.Context
|
|
payloadLogger *zap.Logger
|
|
payloadLevel zapcore.Level
|
|
}
|
|
|
|
func (ss *serverStream) Context() context.Context {
|
|
return ss.context
|
|
}
|
|
|
|
func (ss *serverStream) SendMsg(msg interface{}) error {
|
|
if ce := ss.payloadLogger.Check(ss.payloadLevel, "sending stream message"); ce != nil {
|
|
ce.Write(ProtoMessage("message", msg))
|
|
}
|
|
return ss.ServerStream.SendMsg(msg)
|
|
}
|
|
|
|
func (ss *serverStream) RecvMsg(msg interface{}) error {
|
|
err := ss.ServerStream.RecvMsg(msg)
|
|
if ce := ss.payloadLogger.Check(ss.payloadLevel, "received stream message"); ce != nil {
|
|
ce.Write(ProtoMessage("message", msg))
|
|
}
|
|
return err
|
|
}
|