diff --git a/src/analysis.cpp b/src/analysis.cpp index bee7c32..1f906bb 100644 --- a/src/analysis.cpp +++ b/src/analysis.cpp @@ -277,17 +277,25 @@ AnalysisResult analyze(Node &ast, Memory &memory) { throw RuntimeError(ErrorType::NestedFunction, token.pos, {}); } catch (const InternalError& _) {} + FunctionPrototype prototype = make_fn_prototype(node); + Type type = { + .type=TypeType::Function, + .data=prototype + }; + if (memory.contains(function_name)) { // Il existe une variable non fonction à ce nom if (memory.get(function_name).type.type != TypeType::Function) { throw RuntimeError(ErrorType::IncompatibleRedefinition, token.pos, function_name); } + + Type old_type = memory.get(function_name).type; + if (!equal_types(type, old_type)) { + throw TypeError(ErrorType::IncompatibleDefinition, token.pos, function_name); + } } - memory.declare(function_name, { - .type = TypeType::Function, - .data = make_fn_prototype(node) - }, return_type_token.pos); + memory.declare(function_name, type, return_type_token.pos); return {}; } break; @@ -310,6 +318,10 @@ AnalysisResult analyze(Node &ast, Memory &memory) { throw RuntimeError(ErrorType::FunctionRedefinition, token.pos, function_name); FunctionPrototype prototype = make_fn_prototype(node); + Type type = { + .type=TypeType::Function, + .data=prototype + }; if (memory.contains(function_name)) { // Il existe une variable non fonction à ce nom @@ -317,13 +329,13 @@ AnalysisResult analyze(Node &ast, Memory &memory) { throw RuntimeError(ErrorType::IncompatibleRedefinition, token.pos, function_name); } - //TODO: Vérification de la cohérence des arguments avec ceux du prototype + Type old_type = memory.get(function_name).type; + if (!equal_types(type, old_type)) { + throw TypeError(ErrorType::IncompatibleDefinition, token.pos, function_name); + } } - memory.declare(function_name, { - .type=TypeType::Function, - .data=prototype - }, return_type_token.pos); + memory.declare(function_name, type, return_type_token.pos); memory.update(function_name, node.children[3]); @@ -339,6 +351,24 @@ AnalysisResult analyze(Node &ast, Memory &memory) { return {}; } break; + case NodeType::FunctionCall: { + Token token = get(node.children[0]); + string function_name = get(token.data); + + if (!memory.contains(function_name)) + throw RuntimeError(ErrorType::UnknownIdentifier, token.pos, function_name); + + + if (memory.contains(function_name) && !memory.get(function_name).initialized) + throw RuntimeError(ErrorType::UninitializedIdentifier, token.pos, function_name); + + MemoryVar function = memory.get(function_name); + + if (get(node.children[1]).children.size() != get(function.type.data).size()-1) + throw RuntimeError(ErrorType::UnexpectedArgumentCount, token.pos, int(get(function.type.data).size())-1); + + return get<0>(get(function.type.data).at(0)); + } break; default: return {}; } diff --git a/src/errors.cpp b/src/errors.cpp index 54bd83c..f08a03d 100644 --- a/src/errors.cpp +++ b/src/errors.cpp @@ -27,6 +27,8 @@ string UserError::get_message(void) const { case ErrorType::NestedFunction: return "Function definition is not allowed here"; case ErrorType::FunctionRedefinition: return "Redefinition of '"+get(this->data)+"'"; case ErrorType::IncompatibleRedefinition: return "Redefinition of '"+get(this->data)+"' as different kind of symbol"; + case ErrorType::IncompatibleDefinition: return "Declaration of '"+get(this->data)+"' is incompatible with previous prototype"; + case ErrorType::UnexpectedArgumentCount: return "Expected "+to_string(get(this->data))+" arguments to function call"; case ErrorType::ExpectedArithmeticType: return "Expression must have arithmetic type"; case ErrorType::UnknownIdentifier: return "Unknown identifier '"+get(this->data)+"'"; case ErrorType::AlreadyDeclaredIdentifier: return "Already declared identifier '"+get(this->data)+"'"; diff --git a/src/include/errors.h b/src/include/errors.h index 91438cc..fe673c1 100644 --- a/src/include/errors.h +++ b/src/include/errors.h @@ -35,7 +35,9 @@ enum class ErrorType { ExpectedArithmeticType, NestedFunction, FunctionRedefinition, - IncompatibleRedefinition, + IncompatibleRedefinition, // redefinition of a variable as a function + IncompatibleDefinition, // function declaration incompatible with prototype + UnexpectedArgumentCount, // Runtime UnknownIdentifier, @@ -48,7 +50,7 @@ enum class ErrorType { ControlReachesEndOfNonVoidFn }; -using ErrorData = variant; +using ErrorData = variant; class UserError : public exception { public: diff --git a/src/include/utils.h b/src/include/utils.h index 91e07e6..6e9a1f3 100644 --- a/src/include/utils.h +++ b/src/include/utils.h @@ -34,4 +34,9 @@ vector split_string(const string& input, char delimiter); */ string _debug_get_type_type_name(TypeType type); +/** + * Check if two types are equal +*/ +bool equal_types(Type type1, Type type2); + #endif \ No newline at end of file diff --git a/src/utils.cpp b/src/utils.cpp index 07dd736..a169ef2 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -76,4 +76,39 @@ string _debug_get_type_type_name(TypeType type) { case TypeType::Function: return "FUNCTION"; default: return "Unknown"; } +} + +bool equal_types(Type type1, Type type2) { + if (type1.type != type2.type) + return false; + + switch (type1.type) { + case TypeType::Int: + case TypeType::Double: + case TypeType::Void: + return true; + break; + case TypeType::Function: { + if (holds_alternative(type1.data) != holds_alternative(type2.data)) + return false; + + if (holds_alternative(type1.data)) + return true; + + FunctionPrototype args1 = get(type1.data); + FunctionPrototype args2 = get(type2.data); + + if (args1.size() != args2.size()) + return false; + + for (int i=0; i < (int)args1.size(); i++) { + if (!equal_types(get<0>(args1.at(i)), get<0>(args2.at(i)))) + return false; + } + return true; + } + default: + return false; + break; + } } \ No newline at end of file