diff --git a/transform/interface-lowering.go b/transform/interface-lowering.go index 53e248d3a6..6aa3bb1347 100644 --- a/transform/interface-lowering.go +++ b/transform/interface-lowering.go @@ -292,58 +292,40 @@ func (p *lowerInterfacesPass) run() error { methodSet := use.Operand(1).Operand(0) // global variable itf := p.interfaces[methodSet.Name()] - if len(itf.types) == 0 { - // This method call is impossible: no type implements this - // interface. In fact, the previous type assert that got this - // interface value should already have returned false. - // Replace the function pointer with undef (which will then be - // called), indicating to the optimizer this code is unreachable. - use.ReplaceAllUsesWith(llvm.Undef(p.uintptrType)) - use.EraseFromParentAsInstruction() - } else if len(itf.types) == 1 { - // There is only one implementation of the given type. - // Call that function directly. - err := p.replaceInvokeWithCall(use, itf.types[0], signature) - if err != nil { - return err - } - } else { - // There are multiple types implementing this interface, thus there - // are multiple possible functions to call. Delegate calling the - // right function to a special wrapper function. - inttoptrs := getUses(use) - if len(inttoptrs) != 1 || inttoptrs[0].IsAIntToPtrInst().IsNil() { - return errorAt(use, "internal error: expected exactly one inttoptr use of runtime.interfaceMethod") + + // Delegate calling the right function to a special wrapper function. + inttoptrs := getUses(use) + if len(inttoptrs) != 1 || inttoptrs[0].IsAIntToPtrInst().IsNil() { + return errorAt(use, "internal error: expected exactly one inttoptr use of runtime.interfaceMethod") + } + inttoptr := inttoptrs[0] + calls := getUses(inttoptr) + for _, call := range calls { + // Set up parameters for the call. First copy the regular params... + params := make([]llvm.Value, call.OperandsCount()) + paramTypes := make([]llvm.Type, len(params)) + for i := 0; i < len(params)-1; i++ { + params[i] = call.Operand(i) + paramTypes[i] = params[i].Type() } - inttoptr := inttoptrs[0] - calls := getUses(inttoptr) - for _, call := range calls { - // Set up parameters for the call. First copy the regular params... - params := make([]llvm.Value, call.OperandsCount()) - paramTypes := make([]llvm.Type, len(params)) - for i := 0; i < len(params)-1; i++ { - params[i] = call.Operand(i) - paramTypes[i] = params[i].Type() - } - // then add the typecode to the end of the list. - params[len(params)-1] = typecode - paramTypes[len(params)-1] = p.uintptrType - - // Create a function that redirects the call to the destination - // call, after selecting the right concrete type. - redirector := p.getInterfaceMethodFunc(itf, signature, call.Type(), paramTypes) - - // Replace the old lookup/inttoptr/call with the new call. - p.builder.SetInsertPointBefore(call) - retval := p.builder.CreateCall(redirector, append(params, llvm.ConstNull(llvm.PointerType(p.ctx.Int8Type(), 0))), "") - if retval.Type().TypeKind() != llvm.VoidTypeKind { - call.ReplaceAllUsesWith(retval) - } - call.EraseFromParentAsInstruction() + // then add the typecode to the end of the list. + params[len(params)-1] = typecode + paramTypes[len(params)-1] = p.uintptrType + + // Create a function that redirects the call to the destination + // call, after selecting the right concrete type. + redirector := p.getInterfaceMethodFunc(itf, signature, call.Type(), paramTypes) + + // Replace the old lookup/inttoptr/call with the new call. + p.builder.SetInsertPointBefore(call) + retval := p.builder.CreateCall(redirector, append(params, llvm.ConstNull(llvm.PointerType(p.ctx.Int8Type(), 0))), "") + if retval.Type().TypeKind() != llvm.VoidTypeKind { + call.ReplaceAllUsesWith(retval) } - inttoptr.EraseFromParentAsInstruction() - use.EraseFromParentAsInstruction() + call.EraseFromParentAsInstruction() } + inttoptr.EraseFromParentAsInstruction() + use.EraseFromParentAsInstruction() } // Replace all typeasserts on interface types with matches on their concrete @@ -634,10 +616,19 @@ func (p *lowerInterfacesPass) createInterfaceMethodFunc(itf *interfaceInfo, sign // Create entry block. entry := p.ctx.AddBasicBlock(fn, "entry") - // Create default block and make it unreachable (which it is, because all - // possible types are checked). + // Create default block and call runtime.nilPanic. + // The only other possible value remaining is nil for nil interfaces. We + // could panic with a different message here such as "nil interface" but + // that would increase code size and "nil panic" is close enough. Most + // importantly, it avoids undefined behavior when accidentally calling a + // method on a nil interface. defaultBlock := p.ctx.AddBasicBlock(fn, "default") p.builder.SetInsertPointAtEnd(defaultBlock) + nilPanic := p.mod.NamedFunction("runtime.nilPanic") + p.builder.CreateCall(nilPanic, []llvm.Value{ + llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + }, "") p.builder.CreateUnreachable() // Create type switch in entry block. diff --git a/transform/testdata/interface.ll b/transform/testdata/interface.ll index ee059c959a..1e2ce61b6e 100644 --- a/transform/testdata/interface.ll +++ b/transform/testdata/interface.ll @@ -24,6 +24,7 @@ declare void @runtime.printuint8(i8) declare void @runtime.printint32(i32) declare void @runtime.printptr(i32) declare void @runtime.printnl() +declare void @runtime.nilPanic(i8*, i8*) define void @printInterfaces() { call void @printInterface(i32 ptrtoint (%runtime.typeInInterface* @"typeInInterface:reflect/types.type:basic:int" to i32), i8* inttoptr (i32 5 to i8*)) diff --git a/transform/testdata/interface.out.ll b/transform/testdata/interface.out.ll index 9ae8b48867..4bb9ec1bbc 100644 --- a/transform/testdata/interface.out.ll +++ b/transform/testdata/interface.out.ll @@ -25,6 +25,8 @@ declare void @runtime.printptr(i32) declare void @runtime.printnl() +declare void @runtime.nilPanic(i8*, i8*) + define void @printInterfaces() { call void @printInterface(i32 4, i8* inttoptr (i32 5 to i8*)) call void @printInterface(i32 16, i8* inttoptr (i8 120 to i8*)) @@ -47,8 +49,8 @@ typeswitch.notUnmatched: ; preds = %0 br i1 %typeassert.ok, label %typeswitch.Doubler, label %typeswitch.notDoubler typeswitch.Doubler: ; preds = %typeswitch.notUnmatched - %doubler.result = call i32 @"(Number).Double$invoke"(i8* %value, i8* null) - call void @runtime.printint32(i32 %doubler.result) + %1 = call i32 @"(Doubler).Double"(i8* %value, i8* null, i32 %typecode, i8* null) + call void @runtime.printint32(i32 %1) ret void typeswitch.notDoubler: ; preds = %typeswitch.notUnmatched @@ -76,6 +78,21 @@ define i32 @"(Number).Double$invoke"(i8* %receiverPtr, i8* %parentHandle) { ret i32 %ret } +define internal i32 @"(Doubler).Double"(i8* %0, i8* %1, i32 %actualType, i8* %parentHandle) unnamed_addr { +entry: + switch i32 %actualType, label %default [ + i32 68, label %"reflect/types.type:named:Number" + ] + +default: ; preds = %entry + call void @runtime.nilPanic(i8* undef, i8* undef) + unreachable + +"reflect/types.type:named:Number": ; preds = %entry + %2 = call i32 @"(Number).Double$invoke"(i8* %0, i8* %1) + ret i32 %2 +} + define internal i1 @"Doubler$typeassert"(i32 %actualType) unnamed_addr { entry: switch i32 %actualType, label %else [