diff --git a/src/spec-lang/tc/typecheck.ts b/src/spec-lang/tc/typecheck.ts index 3ae022f3..ea50f905 100644 --- a/src/spec-lang/tc/typecheck.ts +++ b/src/spec-lang/tc/typecheck.ts @@ -1683,9 +1683,18 @@ function getFunDefType(fun: FunctionDefinition): FunctionType { function matchArguments( arg: SNode[], argTs: TypeNode[], - callable: FunctionDefinition | FunctionType + callable: FunctionDefinition | VariableDeclaration | FunctionType ) { - const funT = callable instanceof FunctionDefinition ? getFunDefType(callable) : callable; + let funT: FunctionType; + + if (callable instanceof FunctionDefinition) { + funT = getFunDefType(callable); + } else if (callable instanceof VariableDeclaration) { + funT = callable.getterFunType(); + } else { + funT = callable; + } + const isVariadic = funT.parameters.length > 0 && last(funT.parameters) instanceof VariableTypes; // For non-variadic functions the number of arguments must match the number of formal parameters @@ -1960,15 +1969,20 @@ export function tcFunctionCall(expr: SFunctionCall, ctx: STypingCtx, typeEnv: Ty const argTs = args.map((arg) => tc(arg, ctx, typeEnv)); - const matchingFunDefs = calleeT.definitions.filter( - (fun) => fun instanceof FunctionDefinition && matchArguments(args, argTs, fun) - ); - - const matchingVarDefs = calleeT.definitions.filter( - (fun) => fun instanceof VariableDeclaration && expr.args.length === 0 - ); - - const matchingDefs = matchingFunDefs.concat(matchingVarDefs); + // Filter the from the original set of (potentially overloaded) functions with the + // same name just the functions that match the actual argTs in the call + const matchingDefs: Array<[FunctionDefinition | VariableDeclaration, FunctionType]> = + calleeT.definitions + .map( + (def) => + [ + def, + def instanceof FunctionDefinition + ? getFunDefType(def) + : def.getterFunType() + ] as [FunctionDefinition | VariableDeclaration, FunctionType] + ) + .filter(([, funT]) => matchArguments(args, argTs, funT)); if (matchingDefs.length === 0) { throw new SUnresolvedFun( @@ -1983,13 +1997,13 @@ export function tcFunctionCall(expr: SFunctionCall, ctx: STypingCtx, typeEnv: Ty expr ); } else if (matchingDefs.length > 1) { - // This is an internal error - shouldn't be encoutered by normal user operations. + // This is an internal error - shouldn't be encountered by normal user operations. throw new Error( `Multiple functions / public getters match callsite ${expr.pp()}: ${calleeT.pp()}` ); } - const def = matchingDefs[0]; + const [def, funT] = matchingDefs[0]; // Narrow down the set of matching definitions in the callee's type. calleeT.definitions = [def]; @@ -1997,27 +2011,17 @@ export function tcFunctionCall(expr: SFunctionCall, ctx: STypingCtx, typeEnv: Ty callee.defSite = def; } - if (def instanceof FunctionDefinition) { - // param.vType is defined, as you can't put a `var x,` in a function definition. - const retTs = def.vReturnParameters.vParameters.map((param) => astVarToTypeNode(param)); - - if (retTs.length === 1) { - return retTs[0]; - } - - if (retTs.length > 1) { - return new TupleType(retTs); - } + const retTs = funT.returns; - throw new SFunNoReturn(`Function ${def.name} doesn't return a type`, expr); - } else { - if (def.vType instanceof UserDefinedTypeName) { - throw new Error(`NYI public getters for ${def.vType.print()}`); - } + if (retTs.length === 1) { + return retTs[0]; + } - // def.vType is defined, as you can't put a `var x,` in a contract state var definition. - return astVarToTypeNode(def); + if (retTs.length > 1) { + return new TupleType(retTs); } + + throw new SFunNoReturn(`Function ${def.name} doesn't return a type`, expr); } // Builtin function diff --git a/test/unit/tc.spec.ts b/test/unit/tc.spec.ts index a005731d..bbbf0e93 100644 --- a/test/unit/tc.spec.ts +++ b/test/unit/tc.spec.ts @@ -179,6 +179,16 @@ describe("TypeChecker Expression Unit Tests", () => { function idPair(uint x, uint y) public returns (uint, uint) { return (x,y); } + + mapping (address => Boo)[] public pubV; + + struct Loo { + int8 a; + string b; + int[] arr; + } + + Loo[] public pubV2; }`, [ ["uint", ["Foo"], new TypeNameType(new IntType(256, false))], @@ -580,6 +590,19 @@ describe("TypeChecker Expression Unit Tests", () => { "type(IFace).name", ["Foo"], new PointerType(new StringType(), DataLocation.Memory) + ], + [ + "this.pubV(1, address(0))", + ["Foo"], + (units) => new UserDefinedType("Boo", findTypeDef("Boo", units)) + ], + [ + "this.pubV2(1)", + ["Foo"], + new TupleType([ + new IntType(8, true), + new PointerType(new StringType(), DataLocation.Memory) + ]) ] ] ], @@ -744,6 +767,8 @@ contract UserDefinedValueTypes { } function noReturn(uint x) public {} + + mapping (address => Boo)[] public pubV; }`, [ ["int23", ["Foo"]], @@ -809,7 +834,11 @@ contract UserDefinedValueTypes { ["type(int24).min", ["Foo"]], ["type(uint).max", ["Foo"]], ["type(FooEnum).max", ["Foo"]], - ["type(FooEnum).min", ["Foo"]] + ["type(FooEnum).min", ["Foo"]], + ["this.pubV()", ["Foo"]], + ["this.pubV(false)", ["Foo"]], + ["this.pubV(1)", ["Foo"]], + ["this.pubV(1, 2)", ["Foo"]] ] ], [