diff --git a/src/analysis.cpp b/src/analysis.cpp index b2bebc9..687991b 100644 --- a/src/analysis.cpp +++ b/src/analysis.cpp @@ -264,6 +264,55 @@ AnalysisResult analyze(Node &ast, Memory &memory) { analyze(node.children[0], memory); return analyze(node.children[1], memory); } + case NodeType::FunctionPrototype: { + Token return_type_token = get(node.children[0]); + string return_type_string = get(return_type_token.data); + Type return_type = try_string_to_type(return_type_string, return_type_token.pos); + + Token token = get(node.children[1]); + string function_name = get(token.data); + + try { + memory.get_function_scope(); + throw RuntimeError(ErrorType::NestedFunction, token.pos, {}); + } catch (const InternalError& _) {} + + memory.declare(function_name, { + .type = TypeType::Function, + .data = make_fn_prototype(node) + }, return_type_token.pos); + + return {}; + } break; + case NodeType::FunctionDeclaration: { + Token return_type_token = get(node.children[0]); + string return_type_string = get(return_type_token.data); + Type return_type = try_string_to_type(return_type_string, return_type_token.pos); + + Token token = get(node.children[1]); + string function_name = get(token.data); + + try { + memory.get_function_scope(); + throw RuntimeError(ErrorType::NestedFunction, token.pos, {}); + } catch (const InternalError& _) {} + + if (memory.contains(function_name) && memory.get(function_name).initialized) + throw RuntimeError(ErrorType::FunctionRedefinition, token.pos, function_name); + + memory.declare(function_name, { + .type = TypeType::Function, + .data = make_fn_prototype(node) + }, return_type_token.pos); + + memory.update(function_name, node.children[3]); + + memory.add_scope(ScopeType::Function, token.pos, &memory.get(function_name)); + analyze(node.children[3], memory); + memory.remove_scope(); + + return {}; + } break; default: return {}; } diff --git a/src/errors.cpp b/src/errors.cpp index ba55fd2..e0a1f22 100644 --- a/src/errors.cpp +++ b/src/errors.cpp @@ -25,6 +25,7 @@ string UserError::get_message(void) const { case ErrorType::TypesNotComparable: return "Types not comparable"; case ErrorType::ExpectedIntegralType: return "Expression must have integral type"; case ErrorType::NestedFunction: return "Function definition is not allowed here"; + case ErrorType::FunctionRedefinition: return "Redefinition of '"+get(this->data)+"'"; case ErrorType::ExpectedArithmeticType: return "Expression must have arithmetic type"; case ErrorType::UnknownIdentifier: return "Unknown identifier '"+get(this->data)+"'"; case ErrorType::AlreadyDeclaredIdentifier: return "Already declared identifier '"+get(this->data)+"'"; diff --git a/src/include/errors.h b/src/include/errors.h index 4f0f54e..b30abdb 100644 --- a/src/include/errors.h +++ b/src/include/errors.h @@ -34,6 +34,7 @@ enum class ErrorType { ExpectedIntegralType, ExpectedArithmeticType, NestedFunction, + FunctionRedefinition, // Runtime UnknownIdentifier,