From 3daa675f8d7416b51261ff49082aad51f9c57d22 Mon Sep 17 00:00:00 2001 From: Patrick MARIE Date: Wed, 28 Aug 2024 16:25:01 +0200 Subject: [PATCH] implementing calls and functions (ch24) --- samples/ch24_function.lox | 4 + samples/ch24_function_call.lox | 9 +++ samples/ch24_invalid_arg_count.lox | 7 ++ samples/ch24_multiple_calls.lox | 3 + src/chunk.zig | 37 ++++++--- src/compile.zig | 124 +++++++++++++++++++++++++++-- src/constant.zig | 4 +- src/main.zig | 10 +-- src/object.zig | 2 +- src/opcode.zig | 1 + src/vm.zig | 97 +++++++++++++++++----- 11 files changed, 252 insertions(+), 46 deletions(-) create mode 100644 samples/ch24_function.lox create mode 100644 samples/ch24_function_call.lox create mode 100644 samples/ch24_invalid_arg_count.lox create mode 100644 samples/ch24_multiple_calls.lox diff --git a/samples/ch24_function.lox b/samples/ch24_function.lox new file mode 100644 index 0000000..901b55d --- /dev/null +++ b/samples/ch24_function.lox @@ -0,0 +1,4 @@ +fun areWeHavingItYet() { + print "Yes we are!"; +} +print areWeHavingItYet; diff --git a/samples/ch24_function_call.lox b/samples/ch24_function_call.lox new file mode 100644 index 0000000..50fc40d --- /dev/null +++ b/samples/ch24_function_call.lox @@ -0,0 +1,9 @@ +fun helloWorld(var_str, var_int) { + var b = 1; + print var_str + " blah"; + + return b + var_int; +} + +var c = helloWorld("a", 42); +print c; diff --git a/samples/ch24_invalid_arg_count.lox b/samples/ch24_invalid_arg_count.lox new file mode 100644 index 0000000..bda98c8 --- /dev/null +++ b/samples/ch24_invalid_arg_count.lox @@ -0,0 +1,7 @@ +fun a() { b(); } +fun b() { c(); } +fun c() { + c("too", "many"); +} + +a(); diff --git a/samples/ch24_multiple_calls.lox b/samples/ch24_multiple_calls.lox new file mode 100644 index 0000000..2838005 --- /dev/null +++ b/samples/ch24_multiple_calls.lox @@ -0,0 +1,3 @@ +fun a(a, b) { var c = a + b; return c; } +print a(2, 4); +print a(4, 6); diff --git a/src/chunk.zig b/src/chunk.zig index fcdc648..3eda2b8 100644 --- a/src/chunk.zig +++ b/src/chunk.zig @@ -10,22 +10,25 @@ const grow_capacity = @import("./utils.zig").grow_capacity; const utils = @import("./utils.zig"); pub const Chunk = struct { + allocator: Allocator, + count: usize, capacity: usize, code: []u8, lines: []usize, constants: ValueArray, - allocator: Allocator, - pub fn new(allocator: Allocator) Chunk { - return Chunk{ - .count = 0, - .capacity = 0, - .code = &.{}, - .lines = &.{}, - .constants = ValueArray.new(allocator), - .allocator = allocator, - }; + pub fn new(allocator: Allocator) *Chunk { + var chunk: *Chunk = allocator.create(Chunk) catch unreachable; + + chunk.allocator = allocator; + chunk.count = 0; + chunk.capacity = 0; + chunk.code = &.{}; + chunk.lines = &.{}; + chunk.constants = ValueArray.new(allocator); + + return chunk; } pub fn destroy(self: *Chunk) void { @@ -35,6 +38,8 @@ pub const Chunk = struct { self.allocator.free(self.code); self.allocator.free(self.lines); } + + self.allocator.destroy(self); } pub fn write(self: *Chunk, byte: u8, line: usize) !void { @@ -50,8 +55,17 @@ pub const Chunk = struct { self.count += 1; } - pub fn dump(self: Chunk) void { + pub fn dump(self: *Chunk) void { + debug.print("== chunk dump of {*} ==\n", .{self}); debug.print("{any}\n", .{self}); + + for (0..self.constants.count) |idx| { + debug.print("constant {d}: {*} ", .{ idx, &self.constants.values[idx] }); + self.constants.values[idx].print(); + debug.print("\n", .{}); + } + + debug.print("== end of chunk dump \n\n", .{}); } pub fn dissassemble(self: Chunk, name: []const u8) void { @@ -101,6 +115,7 @@ pub const Chunk = struct { @intFromEnum(OpCode.OP_JUMP) => return utils.jump_instruction("OP_JUMP", 1, self, offset), @intFromEnum(OpCode.OP_JUMP_IF_FALSE) => return utils.jump_instruction("OP_JUMP_IF_FALSE", 1, self, offset), @intFromEnum(OpCode.OP_LOOP) => return utils.jump_instruction("OP_LOOP", -1, self, offset), + @intFromEnum(OpCode.OP_CALL) => return utils.byte_instruction("OP_CALL", self, offset), else => { debug.print("unknown opcode {d}\n", .{instruction}); return offset + 1; diff --git a/src/compile.zig b/src/compile.zig index 3f838c5..ae0c0b1 100644 --- a/src/compile.zig +++ b/src/compile.zig @@ -61,7 +61,7 @@ const Parser = struct { } inline fn current_chunk(self: *Parser) *Chunk { - return &self.compiler.function.chunk; + return self.compiler.function.chunk; } fn advance(self: *Parser) void { @@ -132,17 +132,24 @@ const Parser = struct { } fn emit_return(self: *Parser) ParsingError!void { + try self.emit_byte(@intFromEnum(OpCode.OP_NIL)); try self.emit_byte(@intFromEnum(OpCode.OP_RETURN)); } fn end_parser(self: *Parser) !*Obj.Function { + try self.emit_return(); + if (!self.had_error and constants.DEBUG_PRINT_CODE) { self.current_chunk().dissassemble("code"); } - try self.emit_return(); + const function_obj = self.compiler.function; - return self.compiler.function; + if (self.compiler.enclosing != null) { + self.compiler = self.compiler.enclosing.?; + } + + return function_obj; } fn number(self: *Parser, can_assign: bool) ParsingError!void { @@ -235,7 +242,7 @@ const Parser = struct { fn get_rule(operator_type: TokenType) ParserRule { return switch (operator_type) { - TokenType.LEFT_PAREN => ParserRule{ .prefix = grouping, .infix = null, .precedence = Precedence.None }, + TokenType.LEFT_PAREN => ParserRule{ .prefix = grouping, .infix = call, .precedence = Precedence.Call }, TokenType.RIGHT_PAREN => ParserRule{ .prefix = null, .infix = null, .precedence = Precedence.None }, TokenType.LEFT_BRACE => ParserRule{ .prefix = null, .infix = null, .precedence = Precedence.None }, TokenType.RIGHT_BRACE => ParserRule{ .prefix = null, .infix = null, .precedence = Precedence.None }, @@ -353,7 +360,9 @@ const Parser = struct { } fn declaration(self: *Parser) ParsingError!void { - if (self.match(TokenType.VAR)) { + if (self.match(TokenType.FUN)) { + try self.fun_declaration(); + } else if (self.match(TokenType.VAR)) { try self.var_declaration(); } else { try self.statement(); @@ -371,6 +380,8 @@ const Parser = struct { try self.for_statement(); } else if (self.match(TokenType.IF)) { try self.if_statement(); + } else if (self.match(TokenType.RETURN)) { + try self.return_statement(); } else if (self.match(TokenType.WHILE)) { try self.while_statement(); } else if (self.match(TokenType.LEFT_BRACE)) { @@ -470,6 +481,9 @@ const Parser = struct { } fn mark_initialized(self: *Parser) void { + if (self.compiler.scope_depth == 0) { + return; + } self.compiler.locals[self.compiler.local_count - 1].depth = self.compiler.scope_depth; } @@ -701,6 +715,93 @@ const Parser = struct { try self.end_scope(); } + + fn fun_declaration(self: *Parser) ParsingError!void { + const global: u8 = try self.parse_variable("Expect function name."); + self.mark_initialized(); + try self.function(FunctionType.Function); + try self.define_variable(global); + } + + fn function(self: *Parser, function_type: FunctionType) ParsingError!void { + var compiler = Compiler.new(self.vm.allocator, self.compiler, function_type); + + self.compiler = &compiler; + if (function_type != FunctionType.Script) { + self.compiler.function.name = self.vm.copy_string(self.previous.?.start[0..self.previous.?.length]); + } + + self.begin_scope(); + + self.consume(TokenType.LEFT_PAREN, "Expect '(' after function name."); + if (!self.check(TokenType.RIGHT_PAREN)) { + while (true) { + self.compiler.function.arity += 1; + if (self.compiler.function.arity > 255) { + self.error_at_current("Can't have more than 255 parameters."); + } + + const constant = try self.parse_variable("Expect parameter name."); + try self.define_variable(constant); + if (!self.match(TokenType.COMMA)) { + break; + } + } + } + + self.consume(TokenType.RIGHT_PAREN, "Expect ')' after parameters."); + self.consume(TokenType.LEFT_BRACE, "Expect '{' before function body."); + + try self.block(); + + const obj_function = try self.end_parser(); + + const constant = try self.make_constant(Value.obj_val(&obj_function.obj)); + try self.emit_bytes(@intFromEnum(OpCode.OP_CONSTANT), constant); + } + + fn call(self: *Parser, can_assign: bool) ParsingError!void { + _ = can_assign; + + const arg_count = try self.argument_list(); + try self.emit_bytes(@intFromEnum(OpCode.OP_CALL), @intCast(arg_count)); + } + + fn argument_list(self: *Parser) ParsingError!usize { + var arg_count: usize = 0; + + if (!self.check(TokenType.RIGHT_PAREN)) { + while (true) { + try self.expression(); + if (arg_count == 16) { + self.error_msg("Can't have more than 16 arguments."); + } + arg_count += 1; + + if (!self.match(TokenType.COMMA)) { + break; + } + } + } + + self.consume(TokenType.RIGHT_PAREN, "Expect ')' after arguments."); + + return arg_count; + } + + fn return_statement(self: *Parser) ParsingError!void { + if (self.compiler.function_type == FunctionType.Script) { + self.error_msg("Can't return from top-level code."); + } + + if (self.match(TokenType.SEMICOLON)) { + try self.emit_return(); + } else { + try self.expression(); + self.consume(TokenType.SEMICOLON, "Expect ';' after return value."); + try self.emit_byte(@intFromEnum(OpCode.OP_RETURN)); + } + } }; const FunctionType = enum { @@ -709,6 +810,8 @@ const FunctionType = enum { }; const Compiler = struct { + enclosing: ?*Compiler, + function: *Obj.Function, function_type: FunctionType, @@ -716,7 +819,7 @@ const Compiler = struct { local_count: usize, scope_depth: usize, - fn new(allocator: std.mem.Allocator, function_type: FunctionType) Compiler { + fn new(allocator: std.mem.Allocator, enclosing: ?*Compiler, function_type: FunctionType) Compiler { const obj_function = Obj.Function.new(allocator); var compiler = Compiler{ @@ -725,6 +828,7 @@ const Compiler = struct { .scope_depth = 0, .function = obj_function, .function_type = function_type, + .enclosing = enclosing, }; compiler.locals[0].depth = 0; @@ -739,6 +843,11 @@ const Compiler = struct { return compiler; } + + fn destroy(self: *Compiler) void { + // do not destroy function here! it is used after compiler life. + _ = self; + } }; const Local = struct { @@ -747,8 +856,7 @@ const Local = struct { }; pub fn compile(vm: *VM, contents: []const u8) !?*Obj.Function { - var compiler = Compiler.new(vm.allocator, FunctionType.Script); - + var compiler = Compiler.new(vm.allocator, null, FunctionType.Script); var scanner = Scanner.init(contents); var parser = Parser.new(vm, &compiler, &scanner); diff --git a/src/constant.zig b/src/constant.zig index 68a7740..fc6c23d 100644 --- a/src/constant.zig +++ b/src/constant.zig @@ -10,7 +10,7 @@ pub const UINT8_COUNT = UINT8_MAX + 1; pub const FRAMES_MAX = 64; pub const STACK_MAX = (FRAMES_MAX * UINT8_MAX); -pub const DEBUG_PRINT_CODE = true; -pub const DEBUG_TRACE_EXECUTION = true; +pub const DEBUG_PRINT_CODE = false; +pub const DEBUG_TRACE_EXECUTION = false; pub const DEBUG_PRINT_INTERNAL_STRINGS = false; pub const DEBUG_PRINT_GLOBALS = false; diff --git a/src/main.zig b/src/main.zig index b64a07b..6b28e2f 100644 --- a/src/main.zig +++ b/src/main.zig @@ -12,7 +12,7 @@ const InterpretResult = @import("./vm.zig").InterpretResult; // XXX imported to run tests. const Table = @import("./table.zig"); -pub fn repl(allocator: Allocator, vm: *VM) !void { +pub fn repl(vm: *VM) !void { var line: [1024]u8 = undefined; const stdin = std.io.getStdIn().reader(); @@ -34,7 +34,7 @@ pub fn repl(allocator: Allocator, vm: *VM) !void { break; } - _ = try vm.interpret(allocator, &line); + _ = try vm.interpret(&line); } } @@ -45,7 +45,7 @@ pub fn run_file(allocator: Allocator, vm: *VM, filepath: []const u8) !void { const file_content = try file.readToEndAlloc(allocator, 1024 * 1024); defer allocator.free(file_content); - const result = try vm.interpret(allocator, file_content); + const result = try vm.interpret(file_content); switch (result) { InterpretResult.COMPILE_ERROR => std.process.exit(65), @@ -55,7 +55,7 @@ pub fn run_file(allocator: Allocator, vm: *VM, filepath: []const u8) !void { } pub fn main() !void { - var gpa = std.heap.GeneralPurposeAllocator(.{ .safety = true }){}; + var gpa = std.heap.GeneralPurposeAllocator(.{ .safety = false }){}; defer _ = debug.assert(gpa.deinit() == .ok); const allocator = gpa.allocator(); @@ -66,7 +66,7 @@ pub fn main() !void { defer vm.destroy(); if (args.len == 1) { - try repl(allocator, &vm); + try repl(&vm); } else if (args.len == 2) { try run_file(allocator, &vm, args[1]); } else { diff --git a/src/object.zig b/src/object.zig index 21766fb..69de867 100644 --- a/src/object.zig +++ b/src/object.zig @@ -44,7 +44,7 @@ pub const Obj = struct { pub const Function = struct { obj: Obj, arity: usize, - chunk: Chunk, + chunk: *Chunk, name: ?*Obj.String, pub fn new(allocator: std.mem.Allocator) *Function { diff --git a/src/opcode.zig b/src/opcode.zig index 7c0cdde..0109591 100644 --- a/src/opcode.zig +++ b/src/opcode.zig @@ -22,5 +22,6 @@ pub const OpCode = enum(u8) { OP_JUMP, OP_JUMP_IF_FALSE, OP_LOOP, + OP_CALL, OP_RETURN, }; diff --git a/src/vm.zig b/src/vm.zig index babbf03..abf50f5 100644 --- a/src/vm.zig +++ b/src/vm.zig @@ -8,6 +8,7 @@ const Chunk = @import("./chunk.zig").Chunk; const OpCode = @import("./opcode.zig").OpCode; const Value = @import("./values.zig").Value; const Obj = @import("./object.zig").Obj; +const ObjType = @import("./object.zig").ObjType; const Table = @import("./table.zig").Table; const compile = @import("./compile.zig").compile; @@ -65,30 +66,22 @@ pub const VM = struct { } inline fn current_chunk(self: *VM) *Chunk { - return &self.frames[self.frame_count - 1].function.chunk; + return self.frames[self.frame_count - 1].function.chunk; } inline fn current_frame(self: *VM) *CallFrame { return &self.frames[self.frame_count - 1]; } - pub fn interpret(self: *VM, allocator: Allocator, content: []const u8) !InterpretResult { - var chunk = Chunk.new(allocator); - defer chunk.destroy(); - - const function = try compile(self, content); + pub fn interpret(self: *VM, content: []const u8) !InterpretResult { + var function = try compile(self, content); if (function == null) { return InterpretResult.COMPILE_ERROR; } defer function.?.destroy(); _ = try self.push(Value.obj_val(&function.?.obj)); - - const frame = &self.frames[self.frame_count]; - self.frame_count += 1; - frame.function = function.?; - frame.ip = 0; - frame.slots_idx = self.stack_top; + _ = self.call(function.?, 0); return try self.run(); } @@ -146,7 +139,15 @@ pub const VM = struct { debug.print("\n", .{}); }, @intFromEnum(OpCode.OP_RETURN) => { - return InterpretResult.OK; + const result = self.pop(); + self.frame_count -= 1; + if (self.frame_count == 0) { + _ = self.pop(); + return InterpretResult.OK; + } + + self.stack_top = self.frames[self.frame_count].slots_idx; + try self.push(result); }, @intFromEnum(OpCode.OP_EQUAL) => { try self.push(Value.bool_val(self.pop().equals(self.pop()))); @@ -185,11 +186,11 @@ pub const VM = struct { }, @intFromEnum(OpCode.OP_GET_LOCAL) => { const slot = self.read_byte(); - try self.push(self.stack[self.current_frame().slots_idx + slot - 1]); + try self.push(self.stack[self.current_frame().slots_idx + slot]); }, @intFromEnum(OpCode.OP_SET_LOCAL) => { const slot = self.read_byte(); - self.stack[self.current_frame().slots_idx + slot - 1] = self.peek(0); + self.stack[self.current_frame().slots_idx + slot] = self.peek(0); }, @intFromEnum(OpCode.OP_JUMP) => { const offset = self.read_short(); @@ -205,6 +206,12 @@ pub const VM = struct { const offset = self.read_short(); self.current_frame().ip -= offset; }, + @intFromEnum(OpCode.OP_CALL) => { + const arg_count = self.read_byte(); + if (!self.call_value(self.peek(arg_count), arg_count)) { + return InterpretResult.RUNTIME_ERROR; + } + }, else => { debug.print("Invalid instruction: {d}\n", .{instruction}); return InterpretResult.RUNTIME_ERROR; @@ -292,11 +299,28 @@ pub const VM = struct { } pub fn runtime_error(self: *VM, err_msg: []const u8) void { - const instruction = self.current_frame().ip; - const line = self.current_chunk().lines[instruction]; - debug.print("err: {s}\n", .{err_msg}); - debug.print("[line {d}] in script\n", .{line}); + + var frame_idx = self.frame_count - 1; + + while (true) { + const frame = self.frames[frame_idx]; + const function = frame.function; + const instruction = frame.ip; + + debug.print("[line {d}] in ", .{function.chunk.lines[instruction]}); + + if (function.name == null) { + debug.print("script\n", .{}); + } else { + debug.print("{s}()\n", .{function.name.?.chars}); + } + + if (frame_idx == 0) { + break; + } + frame_idx -= 1; + } } pub fn add_reference(self: *VM, obj: *Obj) void { @@ -347,4 +371,39 @@ pub const VM = struct { return obj_string; } + + pub fn call_value(self: *VM, callee: Value, arg_count: usize) bool { + if (callee.is_obj()) { + switch (callee.as_obj().kind) { + ObjType.Function => { + return self.call(callee.as_obj().as_function(), arg_count); + }, + else => {}, + } + } + self.runtime_error("Can only call functions and classes."); + return false; + } + + pub fn call(self: *VM, function: *Obj.Function, arg_count: usize) bool { + if (arg_count != function.arity) { + self.runtime_error("Invalid argument count."); + // runtimeError("Expected %d arguments but got %d.", function->arity, argCount); + return false; + } + + if (self.frame_count == constants.FRAMES_MAX) { + self.runtime_error("Stack overflow."); + return false; + } + + const frame = &self.frames[self.frame_count]; + self.frame_count += 1; + + frame.function = function; + frame.ip = 0; + frame.slots_idx = self.stack_top - arg_count - 1; + + return true; + } };