diff --git a/samples/ch25_closures1.lox b/samples/ch25_closures1.lox new file mode 100644 index 0000000..faa4cb0 --- /dev/null +++ b/samples/ch25_closures1.lox @@ -0,0 +1,13 @@ +// this program should print "outer"; without proper closure support, it shows "global". + +var x = "global"; + +fun outer() { + var x = "outer"; + fun inner() { + print x; + } + inner(); +} + +outer(); \ No newline at end of file diff --git a/samples/ch25_closures2.lox b/samples/ch25_closures2.lox new file mode 100644 index 0000000..1e646fe --- /dev/null +++ b/samples/ch25_closures2.lox @@ -0,0 +1,22 @@ +fun makeClosure() { + var local = "local"; + fun closure() { + print local; + } + return closure; +} + +var closure = makeClosure(); +closure(); + +fun makeClosure2(value) { + fun closure() { + print value; + } + return closure; +} + +var doughnut = makeClosure("doughnut"); +var bagel = makeClosure("bagel"); +doughnut(); +bagel(); \ No newline at end of file diff --git a/samples/ch25_closures3.lox b/samples/ch25_closures3.lox new file mode 100644 index 0000000..5e842f7 --- /dev/null +++ b/samples/ch25_closures3.lox @@ -0,0 +1,11 @@ +fun outer() { + var a = 1; + var b = 2; + fun middle() { + var c = 3; + var d = 4; + fun inner() { + print a + c + b + d; + } + } +} diff --git a/samples/ch25_closures4.lox b/samples/ch25_closures4.lox new file mode 100644 index 0000000..dfe743a --- /dev/null +++ b/samples/ch25_closures4.lox @@ -0,0 +1,8 @@ +fun outer() { + var x = "outside"; + fun inner() { + print x; + } + inner(); +} +outer(); \ No newline at end of file diff --git a/samples/ch25_closures5.lox b/samples/ch25_closures5.lox new file mode 100644 index 0000000..a2e20be --- /dev/null +++ b/samples/ch25_closures5.lox @@ -0,0 +1,10 @@ +fun outer() { + var x = "outside"; + fun inner() { + print x; + } + return inner; +} + +var closure = outer(); +closure(); \ No newline at end of file diff --git a/samples/ch25_closures6.lox b/samples/ch25_closures6.lox new file mode 100644 index 0000000..4e23821 --- /dev/null +++ b/samples/ch25_closures6.lox @@ -0,0 +1,22 @@ +var globalSet; +var globalGet; + +fun main() { + var a = "initial"; + + fun set() { + a = "updated"; + } + + fun get() { + print a; + } + + globalSet = set; + globalGet = get; +} + +main(); + +globalSet(); +globalGet(); \ No newline at end of file diff --git a/src/chunk.zig b/src/chunk.zig index 3eda2b8..2ed26b7 100644 --- a/src/chunk.zig +++ b/src/chunk.zig @@ -80,6 +80,7 @@ pub const Chunk = struct { } pub fn dissassemble_instruction(self: Chunk, offset: usize) usize { + var current_offset = offset; debug.print("{d:0>4} ", .{offset}); if (offset > 0 and self.lines[offset] == self.lines[offset - 1]) { @@ -116,6 +117,32 @@ pub const Chunk = struct { @intFromEnum(OpCode.OP_JUMP_IF_FALSE) => return utils.jump_instruction("OP_JUMP_IF_FALSE", 1, self, offset), @intFromEnum(OpCode.OP_LOOP) => return utils.jump_instruction("OP_LOOP", -1, self, offset), @intFromEnum(OpCode.OP_CALL) => return utils.byte_instruction("OP_CALL", self, offset), + @intFromEnum(OpCode.OP_CLOSURE) => { + current_offset += 1; + const constant = self.code[current_offset]; + current_offset += 1; + debug.print("{s:<16} {d:0>4} ", .{ "OP_CLOSURE", constant }); + self.constants.values[constant].print(); + debug.print("\n", .{}); + + const function = self.constants.values[constant].as_obj().as_function(); + for (0..function.upvalue_count) |j| { + _ = j; + const is_local_str = switch (self.code[current_offset]) { + 1 => "local", + else => "upvalue", + }; + current_offset += 1; + const index = self.code[current_offset]; + current_offset += 1; + + debug.print("{d:0>4} | {s:<19} {s} {d}\n", .{ current_offset - 2, "", is_local_str, index }); + } + return current_offset; + }, + @intFromEnum(OpCode.OP_GET_UPVALUE) => return utils.byte_instruction("OP_GET_UPVALUE", self, offset), + @intFromEnum(OpCode.OP_SET_UPVALUE) => return utils.byte_instruction("OP_SET_UPVALUE", self, offset), + @intFromEnum(OpCode.OP_CLOSE_UPVALUE) => return utils.simple_instruction("OP_CLOSE_UPVALUE", offset), else => { debug.print("unknown opcode {d}\n", .{instruction}); return offset + 1; diff --git a/src/compile.zig b/src/compile.zig index 98e604f..8c2d4a2 100644 --- a/src/compile.zig +++ b/src/compile.zig @@ -340,24 +340,27 @@ const Parser = struct { fn named_variable(self: *Parser, token: Token, can_assign: bool) ParsingError!void { var get_op: OpCode = OpCode.OP_GET_LOCAL; var set_op: OpCode = OpCode.OP_SET_LOCAL; - var has_local = true; - var constant = self.resolve_local(token) catch blk: { - has_local = false; - break :blk 0; - }; - - if (!has_local) { - constant = try self.identifier_constant(token); + var arg = self.resolve_local(self.compiler, token); + const upvalue_arg = self.resolve_upvalue(self.compiler, token); + if (arg != -1) { + get_op = OpCode.OP_GET_LOCAL; + set_op = OpCode.OP_SET_LOCAL; + } else if (upvalue_arg != -1) { + get_op = OpCode.OP_GET_UPVALUE; + set_op = OpCode.OP_SET_UPVALUE; + arg = upvalue_arg; + } else { + arg = try self.identifier_constant(token); get_op = OpCode.OP_GET_GLOBAL; set_op = OpCode.OP_SET_GLOBAL; } if (can_assign and self.match(TokenType.EQUAL)) { try self.expression(); - try self.emit_bytes(@intFromEnum(set_op), constant); + try self.emit_bytes(@intFromEnum(set_op), @intCast(arg)); } else { - try self.emit_bytes(@intFromEnum(get_op), constant); + try self.emit_bytes(@intFromEnum(get_op), @intCast(arg)); } } @@ -538,7 +541,11 @@ const Parser = struct { self.compiler.scope_depth -= 1; while (self.compiler.local_count > 0 and self.compiler.locals[self.compiler.local_count - 1].depth.? > self.compiler.scope_depth) { - try self.emit_byte(@intFromEnum(OpCode.OP_POP)); + if (self.compiler.locals[self.compiler.local_count - 1].is_captured) { + try self.emit_byte(@intFromEnum(OpCode.OP_CLOSE_UPVALUE)); + } else { + try self.emit_byte(@intFromEnum(OpCode.OP_POP)); + } self.compiler.local_count -= 1; } } @@ -554,17 +561,18 @@ const Parser = struct { local.name = token; local.depth = null; + local.is_captured = false; } - fn resolve_local(self: *Parser, name: Token) !u8 { - if (self.compiler.local_count == 0) { - return ParsingError.NotFound; + fn resolve_local(self: *Parser, compiler: *Compiler, name: Token) isize { + if (compiler.local_count == 0) { + return -1; } - var idx: u8 = @intCast(self.compiler.local_count - 1); + var idx: u8 = @intCast(compiler.local_count - 1); while (idx >= 0) { - const local = &self.compiler.locals[idx]; + const local = &compiler.locals[idx]; if (identifiers_equals(local.name, name)) { if (local.depth == null) { @@ -579,7 +587,47 @@ const Parser = struct { idx -= 1; } - return ParsingError.NotFound; + return -1; + } + + fn resolve_upvalue(self: *Parser, compiler: *Compiler, name: Token) isize { + if (compiler.enclosing == null) { + return -1; + } + + const local = self.resolve_local(compiler.enclosing.?, name); + if (local != -1) { + compiler.enclosing.?.locals[@intCast(local)].is_captured = true; + return @intCast(self.add_upvalue(compiler, @intCast(local), true)); + } + + const upvalue = self.resolve_upvalue(compiler.enclosing.?, name); + if (upvalue != -1) { + return @intCast(self.add_upvalue(compiler, @intCast(upvalue), false)); + } + + return -1; + } + + fn add_upvalue(self: *Parser, compiler: *Compiler, index: u8, is_local: bool) usize { + const upvalue_count = compiler.function.upvalue_count; + + for (0..upvalue_count) |i| { + const upvalue: *Upvalue = &compiler.upvalues[i]; + if (upvalue.index == index and upvalue.is_local == is_local) { + return i; + } + } + + if (upvalue_count == constants.UINT8_COUNT) { + self.error_msg("Too many closure variables in function."); + return 0; + } + + compiler.upvalues[upvalue_count].is_local = is_local; + compiler.upvalues[upvalue_count].index = index; + compiler.function.upvalue_count += 1; + return compiler.function.upvalue_count - 1; } fn if_statement(self: *Parser) !void { @@ -759,7 +807,16 @@ const Parser = struct { const obj_function = try self.end_parser(); const constant = try self.make_constant(Value.obj_val(&obj_function.obj)); - try self.emit_bytes(@intFromEnum(OpCode.OP_CONSTANT), constant); + try self.emit_bytes(@intFromEnum(OpCode.OP_CLOSURE), constant); + + for (0..obj_function.upvalue_count) |i| { + if (compiler.upvalues[i].is_local) { + try self.emit_byte(1); + } else { + try self.emit_byte(0); + } + try self.emit_byte(@intCast(compiler.upvalues[i].index)); + } } fn call(self: *Parser, can_assign: bool) ParsingError!void { @@ -819,6 +876,7 @@ const Compiler = struct { locals: [constants.UINT8_COUNT]Local, local_count: usize, + upvalues: [constants.UINT8_COUNT]Upvalue, scope_depth: usize, fn new(allocator: std.mem.Allocator, enclosing: ?*Compiler, function_type: FunctionType) Compiler { @@ -827,6 +885,7 @@ const Compiler = struct { var compiler = Compiler{ .locals = undefined, .local_count = 0, + .upvalues = undefined, .scope_depth = 0, .function = obj_function, .function_type = function_type, @@ -840,6 +899,7 @@ const Compiler = struct { .length = 0, .line = 0, }; + compiler.locals[0].is_captured = false; compiler.local_count += 1; @@ -855,6 +915,12 @@ const Compiler = struct { const Local = struct { name: Token, depth: ?usize, + is_captured: bool, +}; + +const Upvalue = struct { + index: usize, + is_local: bool, }; pub fn compile(vm: *VM, contents: []const u8) !?*Obj.Function { diff --git a/src/constant.zig b/src/constant.zig index fc6c23d..68a7740 100644 --- a/src/constant.zig +++ b/src/constant.zig @@ -10,7 +10,7 @@ pub const UINT8_COUNT = UINT8_MAX + 1; pub const FRAMES_MAX = 64; pub const STACK_MAX = (FRAMES_MAX * UINT8_MAX); -pub const DEBUG_PRINT_CODE = false; -pub const DEBUG_TRACE_EXECUTION = false; +pub const DEBUG_PRINT_CODE = true; +pub const DEBUG_TRACE_EXECUTION = true; pub const DEBUG_PRINT_INTERNAL_STRINGS = false; pub const DEBUG_PRINT_GLOBALS = false; diff --git a/src/object.zig b/src/object.zig index d5697ed..b540e42 100644 --- a/src/object.zig +++ b/src/object.zig @@ -12,6 +12,8 @@ pub const ObjType = enum { String, Function, Native, + Closure, + Upvalue, }; pub const NativeFn = *const fn (vm: *VM, arg_count: usize, args: []Value) Value; @@ -49,6 +51,7 @@ pub const Obj = struct { pub const Function = struct { obj: Obj, arity: usize, + upvalue_count: usize, chunk: *Chunk, name: ?*Obj.String, @@ -61,6 +64,7 @@ pub const Obj = struct { const function_obj = allocator.create(Function) catch unreachable; function_obj.obj = obj; function_obj.arity = 0; + function_obj.upvalue_count = 0; function_obj.chunk = Chunk.new(allocator); function_obj.name = null; @@ -95,6 +99,64 @@ pub const Obj = struct { } }; + pub const Closure = struct { + obj: Obj, + function: *Obj.Function, + upvalues: []?*Obj.Upvalue, + upvalue_count: usize, + + pub fn new(allocator: std.mem.Allocator, function: *Obj.Function) *Closure { + const obj = Obj{ + .kind = ObjType.Closure, + .allocator = allocator, + }; + + const closure_obj = allocator.create(Closure) catch unreachable; + closure_obj.obj = obj; + closure_obj.function = function; + closure_obj.upvalue_count = function.upvalue_count; + + closure_obj.upvalues = allocator.alloc(?*Obj.Upvalue, function.upvalue_count) catch unreachable; + + for (0..function.upvalue_count) |i| { + closure_obj.upvalues[i] = null; + } + + return closure_obj; + } + + pub fn destroy(self: *Closure) void { + self.obj.allocator.free(self.upvalues); + self.obj.allocator.destroy(self); + } + }; + + pub const Upvalue = struct { + obj: Obj, + location: *Value, + next: ?*Obj.Upvalue, + closed: Value, + + pub fn new(allocator: std.mem.Allocator, slot: *Value) *Upvalue { + const obj = Obj{ + .kind = ObjType.Upvalue, + .allocator = allocator, + }; + + const upvalue_obj = allocator.create(Upvalue) catch unreachable; + upvalue_obj.obj = obj; + upvalue_obj.location = slot; + upvalue_obj.next = null; + upvalue_obj.closed = Value.nil_val(); + + return upvalue_obj; + } + + pub fn destroy(self: *Upvalue) void { + self.obj.allocator.destroy(self); + } + }; + pub fn is_type(self: *Obj, kind: ObjType) bool { return self.kind == kind; } @@ -111,6 +173,14 @@ pub const Obj = struct { return self.is_type(ObjType.Native); } + pub fn is_closure(self: *Obj) bool { + return self.is_type(ObjType.Closure); + } + + pub fn is_upvalue(self: *Obj) bool { + return self.is_type(ObjType.Upvalue); + } + pub fn print(self: *Obj) void { switch (self.kind) { ObjType.String => { @@ -126,9 +196,15 @@ pub const Obj = struct { } }, ObjType.Native => { - // const obj = self.as_native(); debug.print("", .{}); }, + ObjType.Closure => { + const obj = self.as_closure(); + obj.function.obj.print(); + }, + ObjType.Upvalue => { + debug.print("upvalue", .{}); + }, } } @@ -146,6 +222,14 @@ pub const Obj = struct { const obj: *Native = @fieldParentPtr("obj", self); obj.destroy(); }, + ObjType.Closure => { + const obj: *Closure = @fieldParentPtr("obj", self); + obj.destroy(); + }, + ObjType.Upvalue => { + const obj: *Upvalue = @fieldParentPtr("obj", self); + obj.destroy(); + }, } } @@ -163,4 +247,9 @@ pub const Obj = struct { std.debug.assert(self.kind == ObjType.Native); return @fieldParentPtr("obj", self); } + + pub fn as_closure(self: *Obj) *Closure { + std.debug.assert(self.kind == ObjType.Closure); + return @fieldParentPtr("obj", self); + } }; diff --git a/src/opcode.zig b/src/opcode.zig index 0109591..07bbb5d 100644 --- a/src/opcode.zig +++ b/src/opcode.zig @@ -4,11 +4,13 @@ pub const OpCode = enum(u8) { OP_TRUE, OP_FALSE, OP_POP, - OP_GET_GLOBAL, OP_DEFINE_GLOBAL, + OP_GET_GLOBAL, OP_SET_GLOBAL, OP_GET_LOCAL, OP_SET_LOCAL, + OP_GET_UPVALUE, + OP_SET_UPVALUE, OP_EQUAL, OP_GREATER, OP_LESS, @@ -23,5 +25,7 @@ pub const OpCode = enum(u8) { OP_JUMP_IF_FALSE, OP_LOOP, OP_CALL, + OP_CLOSURE, + OP_CLOSE_UPVALUE, OP_RETURN, }; diff --git a/src/vm.zig b/src/vm.zig index c89b6d9..95eb3b6 100644 --- a/src/vm.zig +++ b/src/vm.zig @@ -26,7 +26,7 @@ pub const InterpretResult = enum { }; pub const CallFrame = struct { - function: *Obj.Function, + closure: *Obj.Closure, ip: usize, // pointer to stack index provided to this frame slots_idx: usize, @@ -43,6 +43,7 @@ pub const VM = struct { globals: Table, frames: [constants.FRAMES_MAX]CallFrame, frame_count: usize, + open_upvalues: ?*Obj.Upvalue, pub fn new(allocator: Allocator) VM { return VM{ @@ -54,6 +55,7 @@ pub const VM = struct { .globals = Table.new(allocator), .frames = undefined, .frame_count = 0, + .open_upvalues = null, }; } @@ -75,7 +77,7 @@ pub const VM = struct { } inline fn current_chunk(self: *VM) *Chunk { - return self.frames[self.frame_count - 1].function.chunk; + return self.frames[self.frame_count - 1].closure.function.chunk; } inline fn current_frame(self: *VM) *CallFrame { @@ -90,7 +92,10 @@ pub const VM = struct { defer function.?.destroy(); _ = try self.push(Value.obj_val(&function.?.obj)); - _ = self.call(function.?, 0); + const closure: *Obj.Closure = Obj.Closure.new(self.allocator, function.?); + _ = self.pop(); + _ = try self.push(Value.obj_val(&closure.obj)); + _ = self.call(closure, 0); return try self.run(); } @@ -149,6 +154,7 @@ pub const VM = struct { }, @intFromEnum(OpCode.OP_RETURN) => { const result = self.pop(); + self.close_upvalues(&self.stack[self.current_frame().slots_idx]); self.frame_count -= 1; if (self.frame_count == 0) { _ = self.pop(); @@ -221,6 +227,33 @@ pub const VM = struct { return InterpretResult.RUNTIME_ERROR; } }, + @intFromEnum(OpCode.OP_CLOSURE) => { + const function = self.read_constant().as_obj().as_function(); + const closure = Obj.Closure.new(self.allocator, function); + _ = try self.push(Value.obj_val(&closure.obj)); + for (0..closure.upvalue_count) |i| { + const is_local = self.read_byte(); + const index = self.read_byte(); + if (is_local == 1) { + const value_idx = self.current_frame().slots_idx + index; + closure.upvalues[i] = self.capture_upvalue(&self.stack[value_idx]); + } else { + closure.upvalues[i] = self.current_frame().closure.upvalues[index]; + } + } + }, + @intFromEnum(OpCode.OP_GET_UPVALUE) => { + const slot = self.read_byte(); + try self.push(self.current_frame().closure.upvalues[slot].?.location.*); + }, + @intFromEnum(OpCode.OP_SET_UPVALUE) => { + const slot = self.read_byte(); + self.current_frame().closure.upvalues[slot].?.location = @constCast(&self.peek(0)); + }, + @intFromEnum(OpCode.OP_CLOSE_UPVALUE) => { + self.close_upvalues(&self.stack[self.stack_top - 1]); + _ = self.pop(); + }, else => { debug.print("Invalid instruction: {d}\n", .{instruction}); return InterpretResult.RUNTIME_ERROR; @@ -314,15 +347,15 @@ pub const VM = struct { while (true) { const frame = self.frames[frame_idx]; - const function = frame.function; + const closure = frame.closure; const instruction = frame.ip; - debug.print("[line {d}] in ", .{function.chunk.lines[instruction]}); + debug.print("[line {d}] in ", .{closure.function.chunk.lines[instruction]}); - if (function.name == null) { + if (closure.function.name == null) { debug.print("script\n", .{}); } else { - debug.print("{s}()\n", .{function.name.?.chars}); + debug.print("{s}()\n", .{closure.function.name.?.chars}); } if (frame_idx == 0) { @@ -385,7 +418,7 @@ pub const VM = struct { if (callee.is_obj()) { switch (callee.as_obj().kind) { ObjType.Function => { - return self.call(callee.as_obj().as_function(), arg_count); + return self.call(callee.as_obj().as_closure(), arg_count); }, ObjType.Native => { const native_obj: *Obj.Native = callee.as_obj().as_native(); @@ -398,6 +431,9 @@ pub const VM = struct { _ = try self.push(value); return true; }, + ObjType.Closure => { + return self.call(callee.as_obj().as_closure(), arg_count); + }, else => {}, } } @@ -405,8 +441,8 @@ pub const VM = struct { return false; } - pub fn call(self: *VM, function: *Obj.Function, arg_count: usize) bool { - if (arg_count != function.arity) { + pub fn call(self: *VM, closure: *Obj.Closure, arg_count: usize) bool { + if (arg_count != closure.function.arity) { self.runtime_error("Invalid argument count."); // runtimeError("Expected %d arguments but got %d.", function->arity, argCount); return false; @@ -420,7 +456,7 @@ pub const VM = struct { const frame = &self.frames[self.frame_count]; self.frame_count += 1; - frame.function = function; + frame.closure = closure; frame.ip = 0; frame.slots_idx = self.stack_top - arg_count - 1; @@ -436,4 +472,39 @@ pub const VM = struct { _ = self.pop(); _ = self.pop(); } + + fn capture_upvalue(self: *VM, local: *Value) *Obj.Upvalue { + var prev_upvalue: ?*Obj.Upvalue = null; + var upvalue: ?*Obj.Upvalue = self.open_upvalues; + + while (upvalue != null and @intFromPtr(upvalue.?.location) > @intFromPtr(local)) { + prev_upvalue = upvalue; + upvalue = upvalue.?.next; + } + + if (upvalue != null and upvalue.?.location == local) { + return upvalue.?; + } + + const created_upvalue = Obj.Upvalue.new(self.allocator, local); + created_upvalue.next = upvalue; + + if (prev_upvalue == null) { + self.open_upvalues = created_upvalue; + } else { + prev_upvalue.?.next = created_upvalue; + } + + return created_upvalue; + } + + fn close_upvalues(self: *VM, last: *Value) void { + while (self.open_upvalues != null and @intFromPtr(self.open_upvalues.?.location) >= @intFromPtr(last)) { + const upvalue = self.open_upvalues.?; + + upvalue.closed = upvalue.location.*; + upvalue.location = &upvalue.closed; + self.open_upvalues = upvalue.next; + } + } };