Compare commits

...

2 Commits

Author SHA1 Message Date
53d5cca124 implementing closures (ch25) 2024-08-29 12:58:54 +02:00
7ed6cf6dcc implementing more native functions 2024-08-29 10:53:52 +02:00
14 changed files with 450 additions and 41 deletions

View File

@ -0,0 +1,13 @@
// this program should print "outer"; without proper closure support, it shows "global".
// Scoping check: inner() reads `x` while a global and an enclosing local share that name.
var x = "global";
fun outer() {
// Shadows the global `x` declared above.
var x = "outer";
fun inner() {
// `x` is not declared inside inner(); it is declared in the enclosing outer().
print x;
}
inner();
}
outer();

View File

@ -0,0 +1,22 @@
// Closures that outlive the function that created them.
// Expected output, one value per line: local, doughnut, bagel.
fun makeClosure() {
  var local = "local";
  fun closure() {
    print local;
  }
  return closure;
}
var closure = makeClosure();
closure();

// Same idea, but the captured variable is a parameter, so every call to
// makeClosure2 yields a closure over its own `value`.
fun makeClosure2(value) {
  fun closure() {
    print value;
  }
  return closure;
}
// makeClosure() declares no parameters, so calling it with one argument is an
// arity error; the one-parameter factory is makeClosure2.
var doughnut = makeClosure2("doughnut");
var bagel = makeClosure2("bagel");
doughnut();
bagel();

View File

@ -0,0 +1,11 @@
// Nested capture sample: inner() reads a and b, declared in outer(), and
// c and d, declared in middle().
// No function is called in this file, so no `print` statement executes.
fun outer() {
var a = 1;
var b = 2;
fun middle() {
var c = 3;
var d = 4;
fun inner() {
// Operands alternate between the two enclosing functions: a (outer), c (middle), b (outer), d (middle).
print a + c + b + d;
}
}
}

View File

@ -0,0 +1,8 @@
// inner() is called while outer() is still running, so `x` is read while
// outer()'s local is still alive. Should print "outside".
fun outer() {
var x = "outside";
fun inner() {
print x;
}
inner();
}
outer();

View File

@ -0,0 +1,10 @@
// inner() is returned and only called after outer() has finished, so `x`
// must remain readable after outer() returns. Should print "outside".
fun outer() {
var x = "outside";
fun inner() {
print x;
}
return inner;
}
var closure = outer();
closure();

View File

@ -0,0 +1,22 @@
// set() and get() both capture the same variable `a` from main().
// They are called only after main() has returned; the assignment made by
// set() must be visible to get(), so this should print "updated".
var globalSet;
var globalGet;
fun main() {
var a = "initial";
fun set() {
a = "updated";
}
fun get() {
print a;
}
// Export both closures through globals so they can be called after main() returns.
globalSet = set;
globalGet = get;
}
main();
globalSet();
globalGet();

3
samples/native_power.lox Normal file
View File

@ -0,0 +1,3 @@
// Exercises the native functions str2num() and power().
// 3 raised to the power 3: should print 27.
print power(str2num("3"),str2num("3"));
// The converted values are numbers, so `+` adds them: should print 7.
print str2num("3") + str2num("4");

View File

@ -80,6 +80,7 @@ pub const Chunk = struct {
}
pub fn dissassemble_instruction(self: Chunk, offset: usize) usize {
var current_offset = offset;
debug.print("{d:0>4} ", .{offset});
if (offset > 0 and self.lines[offset] == self.lines[offset - 1]) {
@ -116,6 +117,32 @@ pub const Chunk = struct {
@intFromEnum(OpCode.OP_JUMP_IF_FALSE) => return utils.jump_instruction("OP_JUMP_IF_FALSE", 1, self, offset),
@intFromEnum(OpCode.OP_LOOP) => return utils.jump_instruction("OP_LOOP", -1, self, offset),
@intFromEnum(OpCode.OP_CALL) => return utils.byte_instruction("OP_CALL", self, offset),
@intFromEnum(OpCode.OP_CLOSURE) => {
current_offset += 1;
const constant = self.code[current_offset];
current_offset += 1;
debug.print("{s:<16} {d:0>4} ", .{ "OP_CLOSURE", constant });
self.constants.values[constant].print();
debug.print("\n", .{});
const function = self.constants.values[constant].as_obj().as_function();
for (0..function.upvalue_count) |j| {
_ = j;
const is_local_str = switch (self.code[current_offset]) {
1 => "local",
else => "upvalue",
};
current_offset += 1;
const index = self.code[current_offset];
current_offset += 1;
debug.print("{d:0>4} | {s:<19} {s} {d}\n", .{ current_offset - 2, "", is_local_str, index });
}
return current_offset;
},
@intFromEnum(OpCode.OP_GET_UPVALUE) => return utils.byte_instruction("OP_GET_UPVALUE", self, offset),
@intFromEnum(OpCode.OP_SET_UPVALUE) => return utils.byte_instruction("OP_SET_UPVALUE", self, offset),
@intFromEnum(OpCode.OP_CLOSE_UPVALUE) => return utils.simple_instruction("OP_CLOSE_UPVALUE", offset),
else => {
debug.print("unknown opcode {d}\n", .{instruction});
return offset + 1;

View File

@ -340,24 +340,27 @@ const Parser = struct {
fn named_variable(self: *Parser, token: Token, can_assign: bool) ParsingError!void {
var get_op: OpCode = OpCode.OP_GET_LOCAL;
var set_op: OpCode = OpCode.OP_SET_LOCAL;
var has_local = true;
var constant = self.resolve_local(token) catch blk: {
has_local = false;
break :blk 0;
};
if (!has_local) {
constant = try self.identifier_constant(token);
var arg = self.resolve_local(self.compiler, token);
const upvalue_arg = self.resolve_upvalue(self.compiler, token);
if (arg != -1) {
get_op = OpCode.OP_GET_LOCAL;
set_op = OpCode.OP_SET_LOCAL;
} else if (upvalue_arg != -1) {
get_op = OpCode.OP_GET_UPVALUE;
set_op = OpCode.OP_SET_UPVALUE;
arg = upvalue_arg;
} else {
arg = try self.identifier_constant(token);
get_op = OpCode.OP_GET_GLOBAL;
set_op = OpCode.OP_SET_GLOBAL;
}
if (can_assign and self.match(TokenType.EQUAL)) {
try self.expression();
try self.emit_bytes(@intFromEnum(set_op), constant);
try self.emit_bytes(@intFromEnum(set_op), @intCast(arg));
} else {
try self.emit_bytes(@intFromEnum(get_op), constant);
try self.emit_bytes(@intFromEnum(get_op), @intCast(arg));
}
}
@ -538,7 +541,11 @@ const Parser = struct {
self.compiler.scope_depth -= 1;
while (self.compiler.local_count > 0 and self.compiler.locals[self.compiler.local_count - 1].depth.? > self.compiler.scope_depth) {
try self.emit_byte(@intFromEnum(OpCode.OP_POP));
if (self.compiler.locals[self.compiler.local_count - 1].is_captured) {
try self.emit_byte(@intFromEnum(OpCode.OP_CLOSE_UPVALUE));
} else {
try self.emit_byte(@intFromEnum(OpCode.OP_POP));
}
self.compiler.local_count -= 1;
}
}
@ -554,17 +561,18 @@ const Parser = struct {
local.name = token;
local.depth = null;
local.is_captured = false;
}
fn resolve_local(self: *Parser, name: Token) !u8 {
if (self.compiler.local_count == 0) {
return ParsingError.NotFound;
fn resolve_local(self: *Parser, compiler: *Compiler, name: Token) isize {
if (compiler.local_count == 0) {
return -1;
}
var idx: u8 = @intCast(self.compiler.local_count - 1);
var idx: u8 = @intCast(compiler.local_count - 1);
while (idx >= 0) {
const local = &self.compiler.locals[idx];
const local = &compiler.locals[idx];
if (identifiers_equals(local.name, name)) {
if (local.depth == null) {
@ -579,7 +587,47 @@ const Parser = struct {
idx -= 1;
}
return ParsingError.NotFound;
return -1;
}
/// Resolves `name` as an upvalue of the function being compiled by `compiler`.
///
/// Returns the index of the upvalue inside `compiler.upvalues`, or -1 when
/// `compiler` has no enclosing compiler or no enclosing function declares
/// `name` as a local.
fn resolve_upvalue(self: *Parser, compiler: *Compiler, name: Token) isize {
    const enclosing = compiler.enclosing orelse return -1;

    // Case 1: the variable is a local of the directly enclosing function.
    // Mark it as captured so that end-of-scope code closes it instead of
    // simply popping it.
    const local_slot = self.resolve_local(enclosing, name);
    if (local_slot != -1) {
        enclosing.locals[@intCast(local_slot)].is_captured = true;
        return @intCast(self.add_upvalue(compiler, @intCast(local_slot), true));
    }

    // Case 2: the variable lives further out; resolve it as an upvalue of the
    // enclosing function and chain this function's upvalue to that one.
    const outer_slot = self.resolve_upvalue(enclosing, name);
    if (outer_slot == -1) {
        return -1;
    }
    return @intCast(self.add_upvalue(compiler, @intCast(outer_slot), false));
}
/// Registers an upvalue `(index, is_local)` on `compiler` and returns its
/// position in `compiler.upvalues`.
///
/// An already-registered identical pair is reused rather than duplicated.
/// When the table is full an error is reported through `error_msg` and 0 is
/// returned.
fn add_upvalue(self: *Parser, compiler: *Compiler, index: u8, is_local: bool) usize {
    const count = compiler.function.upvalue_count;

    // Reuse an existing entry that refers to the same slot.
    for (compiler.upvalues[0..count], 0..) |existing, slot| {
        if (existing.is_local == is_local and existing.index == index) {
            return slot;
        }
    }

    if (count == constants.UINT8_COUNT) {
        self.error_msg("Too many closure variables in function.");
        return 0;
    }

    compiler.upvalues[count] = Upvalue{ .index = index, .is_local = is_local };
    compiler.function.upvalue_count = count + 1;
    return count;
}
fn if_statement(self: *Parser) !void {
@ -759,7 +807,16 @@ const Parser = struct {
const obj_function = try self.end_parser();
const constant = try self.make_constant(Value.obj_val(&obj_function.obj));
try self.emit_bytes(@intFromEnum(OpCode.OP_CONSTANT), constant);
try self.emit_bytes(@intFromEnum(OpCode.OP_CLOSURE), constant);
for (0..obj_function.upvalue_count) |i| {
if (compiler.upvalues[i].is_local) {
try self.emit_byte(1);
} else {
try self.emit_byte(0);
}
try self.emit_byte(@intCast(compiler.upvalues[i].index));
}
}
fn call(self: *Parser, can_assign: bool) ParsingError!void {
@ -819,6 +876,7 @@ const Compiler = struct {
locals: [constants.UINT8_COUNT]Local,
local_count: usize,
upvalues: [constants.UINT8_COUNT]Upvalue,
scope_depth: usize,
fn new(allocator: std.mem.Allocator, enclosing: ?*Compiler, function_type: FunctionType) Compiler {
@ -827,6 +885,7 @@ const Compiler = struct {
var compiler = Compiler{
.locals = undefined,
.local_count = 0,
.upvalues = undefined,
.scope_depth = 0,
.function = obj_function,
.function_type = function_type,
@ -840,6 +899,7 @@ const Compiler = struct {
.length = 0,
.line = 0,
};
compiler.locals[0].is_captured = false;
compiler.local_count += 1;
@ -855,6 +915,12 @@ const Compiler = struct {
// Compile-time record of one local variable slot of the function being compiled.
const Local = struct {
// Token holding the variable's name; compared against identifiers during resolution.
name: Token,
// Scope depth of the declaration; set to null when the local is first added.
depth: ?usize,
// Set to true when a nested function resolves this local as an upvalue; the
// scope-exit code then emits OP_CLOSE_UPVALUE for it instead of OP_POP.
is_captured: bool,
};
// Compile-time description of one upvalue of a function; emitted as a pair of
// operand bytes (is_local flag, then index) after OP_CLOSURE.
const Upvalue = struct {
// Slot number being captured: a local slot of the enclosing function when
// `is_local` is true, otherwise an upvalue index of the enclosing function.
index: usize,
// True when the captured variable is a local of the directly enclosing function.
is_local: bool,
};
pub fn compile(vm: *VM, contents: []const u8) !?*Obj.Function {

View File

@ -10,7 +10,7 @@ pub const UINT8_COUNT = UINT8_MAX + 1;
pub const FRAMES_MAX = 64;
pub const STACK_MAX = (FRAMES_MAX * UINT8_MAX);
pub const DEBUG_PRINT_CODE = false;
pub const DEBUG_TRACE_EXECUTION = false;
pub const DEBUG_PRINT_CODE = true;
pub const DEBUG_TRACE_EXECUTION = true;
pub const DEBUG_PRINT_INTERNAL_STRINGS = false;
pub const DEBUG_PRINT_GLOBALS = false;

64
src/native.zig Normal file
View File

@ -0,0 +1,64 @@
const std = @import("std");
const Obj = @import("./object.zig").Obj;
const Value = @import("./values.zig").Value;
const VM = @import("./vm.zig").VM;
/// Native `clock()`: returns the current `std.time.milliTimestamp()` value as
/// a number. All parameters are ignored.
pub fn clock(vm: *VM, arg_count: usize, args: []Value) Value {
    _ = vm;
    _ = arg_count;
    _ = args;

    const now_ms: f64 = @floatFromInt(std.time.milliTimestamp());
    return Value.number_val(now_ms);
}
/// Native `power(base, exponent)`: returns `base` raised to `exponent`.
///
/// Prints a diagnostic and returns nil when the call does not have exactly
/// two arguments or when either argument is not a number. `vm` is unused.
pub fn power(vm: *VM, arg_count: usize, args: []Value) Value {
    _ = vm;

    if (arg_count != 2) {
        std.debug.print("power() is expecting 2 arguments.\n", .{});
        return Value.nil_val();
    }

    // Both operands must be validated: the exponent (args[1]) is read with
    // as_number() below just like the base.
    if (!args[0].is_number() or !args[1].is_number()) {
        std.debug.print("args must be numbers.\n", .{});
        return Value.nil_val();
    }

    const result_f64: f64 = std.math.pow(f64, args[0].as_number(), args[1].as_number());
    return Value.number_val(result_f64);
}
/// Native `str2num(text)`: parses a string as a floating point number.
///
/// Prints a diagnostic and returns nil when the call does not have exactly
/// one string argument or when the text cannot be parsed. `vm` is unused.
pub fn str2num(vm: *VM, arg_count: usize, args: []Value) Value {
    _ = vm;

    const valid_call = arg_count == 1 and args[0].is_string();
    if (!valid_call) {
        std.debug.print("str2num() is expecting 1 string argument.\n", .{});
        return Value.nil_val();
    }

    if (std.fmt.parseFloat(f64, args[0].as_cstring())) |parsed| {
        return Value.number_val(parsed);
    } else |_| {
        std.debug.print("invalid string for number.\n", .{});
        return Value.nil_val();
    }
}
/// Native `num2str(number)`: formats a number with `{d}` and wraps the text in
/// a new string object allocated from `vm.allocator`.
///
/// Prints a diagnostic and returns nil when the call does not have exactly
/// one number argument or when formatting fails.
pub fn num2str(vm: *VM, arg_count: usize, args: []Value) Value {
    const valid_call = arg_count == 1 and args[0].is_number();
    if (!valid_call) {
        std.debug.print("num2str() is expecting 1 number argument.\n", .{});
        return Value.nil_val();
    }

    if (std.fmt.allocPrint(vm.allocator, "{d}", .{args[0].as_number()})) |text| {
        // NOTE(review): `text` is heap-allocated here and never freed in this
        // function; presumably Obj.String.new takes ownership or copies it —
        // confirm against Obj.String.
        return Value.obj_val(Obj.String.new(vm.allocator, text));
    } else |_| {
        std.debug.print("unable to convert number to string.\n", .{});
        return Value.nil_val();
    }
}

View File

@ -4,6 +4,7 @@ const Allocator = std.mem.Allocator;
const Chunk = @import("./chunk.zig").Chunk;
const Value = @import("./values.zig").Value;
const VM = @import("./vm.zig").VM;
const compute_hash = @import("./utils.zig").compute_hash;
@ -11,9 +12,11 @@ pub const ObjType = enum {
String,
Function,
Native,
Closure,
Upvalue,
};
pub const NativeFn = *const fn (arg_count: usize, args: []Value) Value;
pub const NativeFn = *const fn (vm: *VM, arg_count: usize, args: []Value) Value;
pub const Obj = struct {
kind: ObjType,
@ -48,6 +51,7 @@ pub const Obj = struct {
pub const Function = struct {
obj: Obj,
arity: usize,
upvalue_count: usize,
chunk: *Chunk,
name: ?*Obj.String,
@ -60,6 +64,7 @@ pub const Obj = struct {
const function_obj = allocator.create(Function) catch unreachable;
function_obj.obj = obj;
function_obj.arity = 0;
function_obj.upvalue_count = 0;
function_obj.chunk = Chunk.new(allocator);
function_obj.name = null;
@ -94,6 +99,64 @@ pub const Obj = struct {
}
};
/// Runtime closure object: a function together with the upvalues it captured.
pub const Closure = struct {
    obj: Obj,
    function: *Obj.Function,
    /// One entry per upvalue of `function`; every entry starts out as null.
    upvalues: []?*Obj.Upvalue,
    upvalue_count: usize,

    /// Allocates a closure for `function` with `function.upvalue_count` empty
    /// upvalue slots. Allocation failure is treated as unreachable.
    pub fn new(allocator: std.mem.Allocator, function: *Obj.Function) *Closure {
        const closure = allocator.create(Closure) catch unreachable;
        const slots = allocator.alloc(?*Obj.Upvalue, function.upvalue_count) catch unreachable;
        @memset(slots, null);

        closure.* = Closure{
            .obj = Obj{
                .kind = ObjType.Closure,
                .allocator = allocator,
            },
            .function = function,
            .upvalues = slots,
            .upvalue_count = function.upvalue_count,
        };
        return closure;
    }

    /// Frees the upvalue slot array and the closure itself. The wrapped
    /// function and the upvalue objects the slots point to are not freed here.
    pub fn destroy(self: *Closure) void {
        const allocator = self.obj.allocator;
        allocator.free(self.upvalues);
        allocator.destroy(self);
    }
};
/// Runtime upvalue object: an indirection to a captured variable.
pub const Upvalue = struct {
    obj: Obj,
    /// Where the captured value currently lives.
    location: *Value,
    /// Link to the next upvalue in a singly linked list.
    next: ?*Obj.Upvalue,
    /// Storage owned by the upvalue itself; initialised to nil.
    closed: Value,

    /// Allocates an upvalue that points at `slot`, with no successor and a
    /// nil `closed` value. Allocation failure is treated as unreachable.
    pub fn new(allocator: std.mem.Allocator, slot: *Value) *Upvalue {
        const upvalue = allocator.create(Upvalue) catch unreachable;
        upvalue.* = Upvalue{
            .obj = Obj{
                .kind = ObjType.Upvalue,
                .allocator = allocator,
            },
            .location = slot,
            .next = null,
            .closed = Value.nil_val(),
        };
        return upvalue;
    }

    /// Releases the upvalue object itself.
    pub fn destroy(self: *Upvalue) void {
        const allocator = self.obj.allocator;
        allocator.destroy(self);
    }
};
/// Reports whether this object's `kind` tag equals `kind`.
pub fn is_type(self: *Obj, kind: ObjType) bool {
    const matches = kind == self.kind;
    return matches;
}
@ -110,6 +173,14 @@ pub const Obj = struct {
return self.is_type(ObjType.Native);
}
/// Reports whether this object is tagged as a closure.
pub fn is_closure(self: *Obj) bool {
    return self.kind == ObjType.Closure;
}
/// Reports whether this object is tagged as an upvalue.
pub fn is_upvalue(self: *Obj) bool {
    return self.kind == ObjType.Upvalue;
}
pub fn print(self: *Obj) void {
switch (self.kind) {
ObjType.String => {
@ -125,9 +196,15 @@ pub const Obj = struct {
}
},
ObjType.Native => {
// const obj = self.as_native();
debug.print("<native fn>", .{});
},
ObjType.Closure => {
const obj = self.as_closure();
obj.function.obj.print();
},
ObjType.Upvalue => {
debug.print("upvalue", .{});
},
}
}
@ -145,6 +222,14 @@ pub const Obj = struct {
const obj: *Native = @fieldParentPtr("obj", self);
obj.destroy();
},
ObjType.Closure => {
const obj: *Closure = @fieldParentPtr("obj", self);
obj.destroy();
},
ObjType.Upvalue => {
const obj: *Upvalue = @fieldParentPtr("obj", self);
obj.destroy();
},
}
}
@ -162,4 +247,9 @@ pub const Obj = struct {
std.debug.assert(self.kind == ObjType.Native);
return @fieldParentPtr("obj", self);
}
/// Recovers the `Closure` that embeds this `Obj` header.
/// Asserts that the object is tagged as a closure.
pub fn as_closure(self: *Obj) *Closure {
    std.debug.assert(ObjType.Closure == self.kind);
    const closure: *Closure = @fieldParentPtr("obj", self);
    return closure;
}
};

View File

@ -4,11 +4,13 @@ pub const OpCode = enum(u8) {
OP_TRUE,
OP_FALSE,
OP_POP,
OP_GET_GLOBAL,
OP_DEFINE_GLOBAL,
OP_GET_GLOBAL,
OP_SET_GLOBAL,
OP_GET_LOCAL,
OP_SET_LOCAL,
OP_GET_UPVALUE,
OP_SET_UPVALUE,
OP_EQUAL,
OP_GREATER,
OP_LESS,
@ -23,5 +25,7 @@ pub const OpCode = enum(u8) {
OP_JUMP_IF_FALSE,
OP_LOOP,
OP_CALL,
OP_CLOSURE,
OP_CLOSE_UPVALUE,
OP_RETURN,
};

View File

@ -12,6 +12,8 @@ const ObjType = @import("./object.zig").ObjType;
const NativeFn = @import("./object.zig").NativeFn;
const Table = @import("./table.zig").Table;
const natives = @import("./native.zig");
const compile = @import("./compile.zig").compile;
const compute_hash = @import("./utils.zig").compute_hash;
@ -24,7 +26,7 @@ pub const InterpretResult = enum {
};
pub const CallFrame = struct {
function: *Obj.Function,
closure: *Obj.Closure,
ip: usize,
// pointer to stack index provided to this frame
slots_idx: usize,
@ -41,6 +43,7 @@ pub const VM = struct {
globals: Table,
frames: [constants.FRAMES_MAX]CallFrame,
frame_count: usize,
open_upvalues: ?*Obj.Upvalue,
pub fn new(allocator: Allocator) VM {
return VM{
@ -52,11 +55,14 @@ pub const VM = struct {
.globals = Table.new(allocator),
.frames = undefined,
.frame_count = 0,
.open_upvalues = null,
};
}
pub fn init_vm(self: *VM) void {
self.define_native("clock", clock_native);
self.define_native("clock", natives.clock);
self.define_native("power", natives.power);
self.define_native("str2num", natives.str2num);
}
pub fn destroy(self: *VM) void {
@ -71,7 +77,7 @@ pub const VM = struct {
}
inline fn current_chunk(self: *VM) *Chunk {
return self.frames[self.frame_count - 1].function.chunk;
return self.frames[self.frame_count - 1].closure.function.chunk;
}
inline fn current_frame(self: *VM) *CallFrame {
@ -86,7 +92,10 @@ pub const VM = struct {
defer function.?.destroy();
_ = try self.push(Value.obj_val(&function.?.obj));
_ = self.call(function.?, 0);
const closure: *Obj.Closure = Obj.Closure.new(self.allocator, function.?);
_ = self.pop();
_ = try self.push(Value.obj_val(&closure.obj));
_ = self.call(closure, 0);
return try self.run();
}
@ -145,6 +154,7 @@ pub const VM = struct {
},
@intFromEnum(OpCode.OP_RETURN) => {
const result = self.pop();
self.close_upvalues(&self.stack[self.current_frame().slots_idx]);
self.frame_count -= 1;
if (self.frame_count == 0) {
_ = self.pop();
@ -217,6 +227,33 @@ pub const VM = struct {
return InterpretResult.RUNTIME_ERROR;
}
},
@intFromEnum(OpCode.OP_CLOSURE) => {
const function = self.read_constant().as_obj().as_function();
const closure = Obj.Closure.new(self.allocator, function);
_ = try self.push(Value.obj_val(&closure.obj));
for (0..closure.upvalue_count) |i| {
const is_local = self.read_byte();
const index = self.read_byte();
if (is_local == 1) {
const value_idx = self.current_frame().slots_idx + index;
closure.upvalues[i] = self.capture_upvalue(&self.stack[value_idx]);
} else {
closure.upvalues[i] = self.current_frame().closure.upvalues[index];
}
}
},
@intFromEnum(OpCode.OP_GET_UPVALUE) => {
const slot = self.read_byte();
try self.push(self.current_frame().closure.upvalues[slot].?.location.*);
},
@intFromEnum(OpCode.OP_SET_UPVALUE) => {
    // Assign *through* the upvalue so that the captured variable itself is
    // updated and every closure sharing this upvalue observes the new value.
    // Rebinding `location` to the address of the temporary returned by
    // peek(0) would leave the captured variable untouched and the upvalue
    // pointing at storage that does not outlive this statement.
    // The assigned value stays on the stack: assignment is an expression.
    const slot = self.read_byte();
    self.current_frame().closure.upvalues[slot].?.location.* = self.peek(0);
},
@intFromEnum(OpCode.OP_CLOSE_UPVALUE) => {
self.close_upvalues(&self.stack[self.stack_top - 1]);
_ = self.pop();
},
else => {
debug.print("Invalid instruction: {d}\n", .{instruction});
return InterpretResult.RUNTIME_ERROR;
@ -310,15 +347,15 @@ pub const VM = struct {
while (true) {
const frame = self.frames[frame_idx];
const function = frame.function;
const closure = frame.closure;
const instruction = frame.ip;
debug.print("[line {d}] in ", .{function.chunk.lines[instruction]});
debug.print("[line {d}] in ", .{closure.function.chunk.lines[instruction]});
if (function.name == null) {
if (closure.function.name == null) {
debug.print("script\n", .{});
} else {
debug.print("{s}()\n", .{function.name.?.chars});
debug.print("{s}()\n", .{closure.function.name.?.chars});
}
if (frame_idx == 0) {
@ -381,18 +418,22 @@ pub const VM = struct {
if (callee.is_obj()) {
switch (callee.as_obj().kind) {
ObjType.Function => {
return self.call(callee.as_obj().as_function(), arg_count);
return self.call(callee.as_obj().as_closure(), arg_count);
},
ObjType.Native => {
const native_obj: *Obj.Native = callee.as_obj().as_native();
const value = native_obj.native(
self,
arg_count,
self.stack[self.current_frame().slots_idx - arg_count .. self.current_frame().slots_idx],
self.stack[self.stack_top - arg_count .. self.stack_top],
);
self.stack_top -= arg_count + 1;
_ = try self.push(value);
return true;
},
ObjType.Closure => {
return self.call(callee.as_obj().as_closure(), arg_count);
},
else => {},
}
}
@ -400,8 +441,8 @@ pub const VM = struct {
return false;
}
pub fn call(self: *VM, function: *Obj.Function, arg_count: usize) bool {
if (arg_count != function.arity) {
pub fn call(self: *VM, closure: *Obj.Closure, arg_count: usize) bool {
if (arg_count != closure.function.arity) {
self.runtime_error("Invalid argument count.");
// runtimeError("Expected %d arguments but got %d.", function->arity, argCount);
return false;
@ -415,7 +456,7 @@ pub const VM = struct {
const frame = &self.frames[self.frame_count];
self.frame_count += 1;
frame.function = function;
frame.closure = closure;
frame.ip = 0;
frame.slots_idx = self.stack_top - arg_count - 1;
@ -432,10 +473,38 @@ pub const VM = struct {
_ = self.pop();
}
pub fn clock_native(arg_count: usize, args: []Value) Value {
const ts = std.time.milliTimestamp();
_ = arg_count;
_ = args;
return Value.number_val(@floatFromInt(ts));
/// Returns the open upvalue that refers to the stack slot `local`, creating
/// and linking a new one when none exists yet.
///
/// `open_upvalues` is kept ordered by descending slot address, so the walk
/// can stop at the first node whose location is not above `local`.
fn capture_upvalue(self: *VM, local: *Value) *Obj.Upvalue {
    const target = @intFromPtr(local);

    var predecessor: ?*Obj.Upvalue = null;
    var cursor: ?*Obj.Upvalue = self.open_upvalues;
    while (cursor) |node| {
        if (@intFromPtr(node.location) <= target) break;
        predecessor = node;
        cursor = node.next;
    }

    // An upvalue for exactly this slot already exists: share it.
    if (cursor) |node| {
        if (node.location == local) {
            return node;
        }
    }

    // Otherwise splice a new upvalue in between `predecessor` and `cursor`.
    const fresh = Obj.Upvalue.new(self.allocator, local);
    fresh.next = cursor;
    if (predecessor) |before| {
        before.next = fresh;
    } else {
        self.open_upvalues = fresh;
    }
    return fresh;
}
/// Closes every open upvalue whose location is at or above `last`: the
/// current value is copied into the upvalue's own `closed` field, `location`
/// is redirected to that field, and the upvalue is unlinked from
/// `open_upvalues`.
fn close_upvalues(self: *VM, last: *Value) void {
    const boundary = @intFromPtr(last);
    while (self.open_upvalues) |upvalue| {
        if (@intFromPtr(upvalue.location) < boundary) break;
        upvalue.closed = upvalue.location.*;
        upvalue.location = &upvalue.closed;
        self.open_upvalues = upvalue.next;
    }
}
};