From a2dd06b1bf0ae855d25a650ee55a5e80fcfd6d51 Mon Sep 17 00:00:00 2001 From: ala89 Date: Thu, 4 Jan 2024 21:53:47 +0100 Subject: [PATCH] Implement closures --- src/analysis.cpp | 4 +++- src/include/memory.h | 4 +++- src/include/types.h | 8 +++++++- src/interpreter.cpp | 6 ++++-- src/memory.cpp | 27 +++++++++++++++++++++++++-- 5 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/analysis.cpp b/src/analysis.cpp index a367361..4efe0e6 100644 --- a/src/analysis.cpp +++ b/src/analysis.cpp @@ -337,7 +337,9 @@ AnalysisResult analyze(Node &ast, Memory &memory) { memory.declare(function_name, type); - memory.update(function_name, node.children[3]); + Function fn = { {}, memory.make_closure() }; + + memory.update(function_name, fn); memory.add_scope(ScopeType::Function, &memory.get(function_name)); for (tuple variable : prototype) { diff --git a/src/include/memory.h b/src/include/memory.h index eaa9062..cae14c7 100644 --- a/src/include/memory.h +++ b/src/include/memory.h @@ -13,7 +13,7 @@ class Memory { public: bool contains(string identifier); bool contains_top(string identifier); - void add_scope(ScopeType type, MemoryVar* fn = NULL, CodePosition entry_pos = { }); + void add_scope(ScopeType type, MemoryVar* fn = nullptr, CodePosition entry_pos = { }); void remove_scope(void); MemoryVar& get(string identifier); @@ -21,6 +21,8 @@ class Memory { void declare(string identifier, Type type); void update(string identifier, EvalResult value); + Closure make_closure(void); + StackTrace get_trace(CodePosition pos); Scope& _debug_top(void); diff --git a/src/include/types.h b/src/include/types.h index 3a19f87..6aeb1fa 100644 --- a/src/include/types.h +++ b/src/include/types.h @@ -8,6 +8,7 @@ #include #include #include +#include using namespace std; /** @@ -201,7 +202,12 @@ struct ParseReturn { /** * Interpreter */ -using EvalResult = variant; +struct MemoryVar; + +using Closure = unordered_map>; +using Function = tuple; + +using EvalResult = variant; enum class ScopeType { Block, Function, For }; diff --git a/src/interpreter.cpp b/src/interpreter.cpp index f82df65..89a68f2 100644 --- a/src/interpreter.cpp +++ b/src/interpreter.cpp @@ -173,7 +173,9 @@ EvalResult eval(Node &ast, Memory &memory) { .data = prototype }); - memory.update(identifier, node.children[3]); + Function fn = { node.children[3], memory.make_closure() }; + + memory.update(identifier, fn); return {}; } break; @@ -208,7 +210,7 @@ EvalResult eval(Node &ast, Memory &memory) { } try { - eval(get(var.value), memory); + eval(get<0>(get(var.value)), memory); // Tmp: no flow control if (get<0>(prototype[0]).type != TypeType::Void) { diff --git a/src/memory.cpp b/src/memory.cpp index 3ab6e9c..9b37351 100644 --- a/src/memory.cpp +++ b/src/memory.cpp @@ -1,6 +1,7 @@ #include "include/memory.h" #include "include/errors.h" #include "include/utils.h" +#include "include/types.h" using namespace std; Memory::Memory(void) { @@ -12,6 +13,11 @@ bool Memory::contains(string identifier) { for (auto rit = scopes.rbegin(); rit != scopes.rend(); ++rit) { Scope& scope = *rit; if (scope.vars.contains(identifier)) return true; + if (scope.type == ScopeType::Function) { + Closure closure = std::get<1>(std::get(scope.fn->value)); + if (closure.contains(identifier)) return true; + break; + } } return false; @@ -39,6 +45,11 @@ MemoryVar& Memory::get(string identifier) { for (auto rit = scopes.rbegin(); rit != scopes.rend(); ++rit) { Scope& scope = *rit; if (scope.vars.contains(identifier)) return scope.vars[identifier]; + if (scope.type == ScopeType::Function) { + Closure closure = std::get<1>(std::get(scope.fn->value)); + if (closure.contains(identifier)) return closure.at(identifier); + break; + } } throw exception(); @@ -73,14 +84,26 @@ void Memory::update(string identifier, EvalResult value) { throw exception(); } +Closure Memory::make_closure(void) { + Closure closure; + + for (auto rit = scopes.rbegin(); rit != scopes.rend(); ++rit) { + Scope& scope = *rit; + for (auto& [identifier, var] : scope.vars) { + if (!closure.contains(identifier)) closure.insert_or_assign(identifier, var); + } + } + + return closure; +} + StackTrace Memory::get_trace(CodePosition pos) { StackTrace trace = {}; for (auto rit = scopes.rbegin(); rit != scopes.rend(); ++rit) { Scope& scope = *rit; if (scope.type == ScopeType::Function) { - MemoryVar fn = *(scope.fn); - StackTraceEntry entry = { fn.identifier, pos }; + StackTraceEntry entry = { scope.fn->identifier, pos }; trace.push_back(entry); pos = scope.entry_pos; }