Skip to content

Commit

Permalink
Exclusively use types.Provider with native types (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones committed Jul 18, 2023
1 parent bad352c commit 215c1af
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 125 deletions.
116 changes: 6 additions & 110 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -141,68 +139,20 @@ func (tp *nativeTypeProvider) FindIdent(typeName string) (ref.Val, bool) {
return tp.baseProvider.FindIdent(typeName)
}

// FindType looks up CEL type-checker type definition by qualified identifier, and if not found
// FindStructType looks up the CEL type definition by qualified identifier, and if not found
// proxies to the composed types.Provider.
func (tp *nativeTypeProvider) FindType(typeName string) (*exprpb.Type, bool) {
func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool) {
if _, found := tp.nativeTypes[typeName]; found {
return decls.NewTypeType(decls.NewObjectType(typeName)), true
return types.NewTypeTypeWithParam(types.NewObjectType(typeName)), true
}
if celType, found := tp.baseProvider.FindStructType(typeName); found {
et, err := types.TypeToExprType(celType)
if err != nil {
return nil, false
}
return et, true
return celType, true
}
return nil, false
}

func (tp *nativeTypeProvider) FieldStructType(typeName string) (*types.Type, bool) {
return tp.baseProvider.FindStructType(typeName)
}

// FindFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed ref.TypeProvider
func (tp *nativeTypeProvider) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) {
t, found := tp.nativeTypes[typeName]
if !found {
cft, found := tp.baseProvider.FindStructFieldType(typeName, fieldName)
if !found {
return nil, false
}
et, err := types.TypeToExprType(cft.Type)
if err != nil {
return nil, false
}
return &ref.FieldType{
Type: et,
IsSet: cft.IsSet,
GetFrom: cft.GetFrom,
}, true
}
refField, isDefined := t.hasField(fieldName)
if !found || !isDefined {
return nil, false
}
exprType, ok := convertToExprType(refField.Type)
if !ok {
return nil, false
}
return &ref.FieldType{
Type: exprType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
return getFieldValue(tp, refField), nil
},
}, true
}

// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
t, found := tp.nativeTypes[typeName]
if !found {
Expand Down Expand Up @@ -347,60 +297,6 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
return nil, false
}

// convertToExprType converts the Golang reflect.Type to a protobuf exprpb.Type.
func convertToExprType(refType reflect.Type) (*exprpb.Type, bool) {
switch refType.Kind() {
case reflect.Bool:
return decls.Bool, true
case reflect.Float32, reflect.Float64:
return decls.Double, true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if refType == durationType {
return decls.Duration, true
}
return decls.Int, true
case reflect.String:
return decls.String, true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return decls.Uint, true
case reflect.Array, reflect.Slice:
refElem := refType.Elem()
if refElem == reflect.TypeOf(byte(0)) {
return decls.Bytes, true
}
elemType, ok := convertToExprType(refElem)
if !ok {
return nil, false
}
return decls.NewListType(elemType), true
case reflect.Map:
keyType, ok := convertToExprType(refType.Key())
if !ok {
return nil, false
}
// Ensure the key type is a int, bool, uint, string
elemType, ok := convertToExprType(refType.Elem())
if !ok {
return nil, false
}
return decls.NewMapType(keyType, elemType), true
case reflect.Struct:
if refType == timestampType {
return decls.Timestamp, true
}
return decls.NewObjectType(
fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
), true
case reflect.Pointer:
if refType.Implements(pbMsgInterfaceType) {
pbMsg := reflect.New(refType.Elem()).Interface().(protoreflect.ProtoMessage)
return decls.NewObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true
}
return convertToExprType(refType.Elem())
}
return nil, false
}

func newNativeObject(adapter types.Adapter, val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(refValue.Type())
if err != nil {
Expand Down
17 changes: 2 additions & 15 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ func TestNativeTypesErrors(t *testing.T) {

func TestNativeTypesConvertToNative(t *testing.T) {
env := testNativeEnv(t, NativeTypes(reflect.TypeOf(TestNestedType{})))
adapter := env.TypeAdapter()
adapter := env.CELTypeAdapter()
conversions := []struct {
in any
out any
Expand Down Expand Up @@ -455,22 +455,9 @@ func TestNativeTypesConvertToNative(t *testing.T) {
}
}

func TestNativeTypesConvertToExprTypeErrors(t *testing.T) {
unsupportedTypes := []reflect.Type{
reflect.TypeOf(make(map[string]chan string)),
reflect.TypeOf(make([]chan int, 0)),
reflect.TypeOf(make(map[chan int]bool, 0)),
}
for _, ut := range unsupportedTypes {
if _, converted := convertToExprType(ut); converted {
t.Errorf("convertToExprType(%v) succeeded when it should have failed", ut)
}
}
}

func TestConvertToTypeErrors(t *testing.T) {
env := testNativeEnv(t, NativeTypes(reflect.TypeOf(TestNestedType{})))
adapter := env.TypeAdapter()
adapter := env.CELTypeAdapter()
conversions := []struct {
in any
out any
Expand Down

0 comments on commit 215c1af

Please sign in to comment.