proxy/args.go
package proxy
import (
"context"
"fmt"
"reflect"
"google.golang.org/grpc"
)
// validateArgs checks that the real function type found is compatible with the
// API function type expected for the given call pattern. Both must be func
// types (reflect panics otherwise), and in both, input 0 is the receiver.
//
// Beware, this is where stuff gets super vague thanks to the magic of
// reflection: for apiMethodPatternStructStream only the argument and result
// counts are checked. For every other pattern both sides must have a
// struct-pointer receiver plus at least one further input. The remaining
// inputs and all outputs are then compared by reflect.Kind via typesMatch,
// which tolerates one extra trailing element on the found side. Receivers are
// never compared with each other; they do not have to be the same type.
func validateArgs(expected, found reflect.Type, pattern apiMethodPattern) error {
	expectedInLen, expectedOutLen := expected.NumIn(), expected.NumOut()
	foundInLen, foundOutLen := found.NumIn(), found.NumOut()

	if pattern == apiMethodPatternStructStream {
		// API: req, stream_server -> error
		// Client: ctx, req, opts -> stream_client, error
		if expectedInLen != 3 || expectedOutLen != 1 || foundInLen != 4 || foundOutLen != 2 {
			return fmt.Errorf("pattern was server-side streaming but real function did not meet that. expected api_in=3,api_out=1,real_in=4,real_out=2, got api_in=%d,api_out=%d,real_in=%d,real_out=%d", expectedInLen, expectedOutLen, foundInLen, foundOutLen)
		}
		// TODO: better more-specific type-checking all round, maybe? Though this should all be used through code-gen
		return nil
	}

	if expectedInLen < 2 || foundInLen < 2 {
		return fmt.Errorf("cannot exclude receiver from argument checks if receiver is the only argument: expected >= 2 input arguments, found %d and %d", expectedInLen, foundInLen)
	}
	if !isStructPtr(expected.In(0)) || !isStructPtr(found.In(0)) {
		return fmt.Errorf("no receiver")
	}
	// Skip index 0: receivers don't have to be the same type.
	if err := typesMatch(funcInTypes(expected)[1:], funcInTypes(found)[1:]); err != nil {
		return err
	}
	return typesMatch(funcOutTypes(expected), funcOutTypes(found))
}

// funcInTypes returns the input (parameter) types of the func type t, in order.
func funcInTypes(t reflect.Type) []reflect.Type {
	in := make([]reflect.Type, t.NumIn())
	for i := range in {
		in[i] = t.In(i)
	}
	return in
}

// funcOutTypes returns the output (result) types of the func type t, in order.
func funcOutTypes(t reflect.Type) []reflect.Type {
	out := make([]reflect.Type, t.NumOut())
	for i := range out {
		out[i] = t.Out(i)
	}
	return out
}
// typesMatch compares two argument (or result) type lists position by position,
// by reflect.Kind only; concrete types are not compared, so e.g. any two
// pointer types match.
//
// found may be the same length as expected, or exactly one element longer.
// This allows for an extra trailing argument on the real function, presumably
// its variadic call options. That extra element is not checked. Any other
// length difference is reported as an error.
func typesMatch(expected, found []reflect.Type) error {
	if n, m := len(expected), len(found); n != m && n != m-1 {
		return fmt.Errorf("argument lengths did not match: expected %d but found %d", n, m)
	}
	for i, want := range expected {
		if got := found[i]; want.Kind() != got.Kind() {
			return fmt.Errorf("arguments mismatch in position %d: %s vs %s", i, want.Kind(), got.Kind())
		}
	}
	return nil
}
// isStructPtr reports whether in is a pointer, or a chain of pointers of any
// depth (*T, **T, ...), that ultimately points at a struct type. A bare struct
// (no pointer at all) does not qualify.
func isStructPtr(in reflect.Type) bool {
	if in.Kind() != reflect.Ptr {
		return false
	}
	for t := in.Elem(); ; t = t.Elem() {
		switch t.Kind() {
		case reflect.Struct:
			return true
		case reflect.Ptr:
			// Another level of indirection; keep unwrapping.
			continue
		default:
			return false
		}
	}
}
// isContext reports whether in implements the context.Context interface.
func isContext(in reflect.Type) bool {
	var ctxPtr *context.Context
	return in.Implements(reflect.TypeOf(ctxPtr).Elem())
}
// isError reports whether in implements the built-in error interface.
func isError(in reflect.Type) bool {
	var errPtr *error
	return in.Implements(reflect.TypeOf(errPtr).Elem())
}
func isOutStream(in reflect.Type) bool {
sendMethod, exists := in.MethodByName("Send")
if !exists {
return false
}
send := sendMethod.Type
return in.Implements(reflect.TypeOf((*grpc.ServerStream)(nil)).Elem()) && send.NumIn() == 1 && send.NumOut() == 1 && isStructPtr(send.In(0)) && isError(send.Out(0))
}
func isInStream(in reflect.Type) bool {
recvMethod, exists := in.MethodByName("Recv")
if !exists {
return false
}
recv := recvMethod.Type
return in.Implements(reflect.TypeOf((*grpc.ServerStream)(nil)).Elem()) && recv.NumIn() == 0 && recv.NumOut() == 2 && isStructPtr(recv.Out(0)) && isError(recv.Out(1))
}
// hasSendAndClose reports whether in has a method SendAndClose(x) error,
// where x is a struct pointer. SendAndClose only applies to StreamStruct
// patterns; getPattern uses it to tell StreamStruct apart from StreamStream.
//
// NOTE(review): reflect only omits the receiver from Method.Type for interface
// types, so the one-argument check effectively expects in to be an interface.
func hasSendAndClose(in reflect.Type) bool {
	method, found := in.MethodByName("SendAndClose")
	if !found {
		return false
	}
	sig := method.Type
	if sig.NumIn() != 1 || sig.NumOut() != 1 {
		return false
	}
	return isStructPtr(sig.In(0)) && isError(sig.Out(0))
}
// getPattern classifies the method type args into one of the known API method
// patterns. It returns apiMethodPatternUnknown if args matches none of them.
// Input 0 must be a struct-pointer receiver; the recognised shapes are:
//
//   - StructStruct: recv, ctx, *Req -> *Resp, error
//   - StructStream: recv, *Req, out-stream -> error
//   - StreamStruct: recv, in-stream that has SendAndClose -> error
//   - StreamStream: recv, in-stream that is also an out-stream -> error
//
// "in-stream" and "out-stream" mean types accepted by isInStream and
// isOutStream respectively.
func getPattern(args reflect.Type) (pattern apiMethodPattern) {
	// Start from Unknown explicitly, so that an unmatched signature never falls
	// back to whatever pattern the zero value of apiMethodPattern happens to be.
	pattern = apiMethodPatternUnknown
	defer func() {
		if r := recover(); r != nil {
			// Panic means something wasn't expected, which means this isn't a known pattern
			pattern = apiMethodPatternUnknown
		}
	}()
	// Every known pattern has a receiver, at least one further input, and at
	// least one output. Checking that up front (plus the func kind, since
	// NumIn/NumOut panic otherwise) keeps the recover above as a backstop
	// rather than the normal path for malformed input.
	if args == nil || args.Kind() != reflect.Func || args.NumIn() < 2 || args.NumOut() < 1 {
		return apiMethodPatternUnknown
	}
	if !isStructPtr(args.In(0)) {
		return apiMethodPatternUnknown
	}

	// Pointer receiver checked; filter by the first real input argument type.
	first := args.In(1)
	switch {
	case isContext(first):
		// An explicit context can only be StructStruct; confirm the rest.
		if args.NumIn() == 3 && isStructPtr(args.In(2)) && args.NumOut() == 2 && isStructPtr(args.Out(0)) && isError(args.Out(1)) {
			pattern = apiMethodPatternStructStruct
		}
	case isStructPtr(first):
		// A plain request struct can only be StructStream.
		if args.NumIn() == 3 && isOutStream(args.In(2)) && args.NumOut() == 1 && isError(args.Out(0)) {
			pattern = apiMethodPatternStructStream
		}
	case isInStream(first):
		// Either StreamStruct or StreamStream, told apart by the stream's methods.
		if args.NumIn() == 2 && args.NumOut() == 1 && isError(args.Out(0)) {
			switch {
			case hasSendAndClose(first):
				pattern = apiMethodPatternStreamStruct
			case isOutStream(first):
				pattern = apiMethodPatternStreamStream
			}
		}
	}
	return pattern
}