From df9d1079ea8bb7a8100c41d90f87256c4b60250d Mon Sep 17 00:00:00 2001 From: Patrick MARIE Date: Mon, 26 Aug 2024 16:25:20 +0200 Subject: [PATCH] implementing hash tables (ch20) --- src/compile.zig | 3 +- src/main.zig | 3 + src/object.zig | 9 +- src/table.zig | 302 ++++++++++++++++++++++++++++++++++++++++++++++++ src/utils.zig | 11 ++ src/values.zig | 23 ++-- src/vm.zig | 50 +++++++- 7 files changed, 384 insertions(+), 17 deletions(-) create mode 100644 src/table.zig diff --git a/src/compile.zig b/src/compile.zig index 9118a42..6e8766e 100644 --- a/src/compile.zig +++ b/src/compile.zig @@ -290,7 +290,8 @@ const Parser = struct { fn string(self: *Parser) ParsingError!void { const str = self.previous.?.start[1 .. self.previous.?.length - 1]; - var string_obj = Obj.String.new(self.chunk.allocator, str); + + var string_obj = self.vm.copy_string(str); self.vm.add_reference(&string_obj.obj); diff --git a/src/main.zig b/src/main.zig index 556205e..dda02ec 100644 --- a/src/main.zig +++ b/src/main.zig @@ -7,6 +7,9 @@ const OpCode = @import("./opcode.zig").OpCode; const VM = @import("./vm.zig").VM; const InterpretResult = @import("./vm.zig").InterpretResult; +// XXX imported to run tests. +const Table = @import("./table.zig"); + pub const DEBUG_TRACE_EXECUTION = true; pub fn repl(allocator: Allocator, vm: *VM) !void { diff --git a/src/object.zig b/src/object.zig index 760e530..28fc9d1 100644 --- a/src/object.zig +++ b/src/object.zig @@ -2,6 +2,8 @@ const std = @import("std"); const debug = std.debug; const Allocator = std.mem.Allocator; +const compute_hash = @import("./utils.zig").compute_hash; + pub const ObjType = enum { String, }; @@ -13,8 +15,9 @@ pub const Obj = struct { pub const String = struct { chars: []const u8, obj: Obj, + hash: u32, - pub fn new(allocator: std.mem.Allocator, str: []const u8) *String { + pub fn new(allocator: std.mem.Allocator, chars: []const u8) *String { const obj = Obj{ .kind = ObjType.String, .allocator = allocator, @@ -22,7 +25,9 @@ pub const Obj = struct { const str_obj = allocator.create(String) catch unreachable; str_obj.obj = obj; - str_obj.chars = allocator.dupe(u8, str) catch unreachable; + + str_obj.chars = chars; + str_obj.hash = compute_hash(str_obj.chars); return str_obj; } diff --git a/src/table.zig b/src/table.zig new file mode 100644 index 0000000..e7acfbc --- /dev/null +++ b/src/table.zig @@ -0,0 +1,302 @@ +const std = @import("std"); +const debug = std.debug; +const Allocator = std.mem.Allocator; + +const Obj = @import("./object.zig").Obj; +const Value = @import("./values.zig").Value; + +const grow_capacity = @import("./utils.zig").grow_capacity; +const compute_hash = @import("./utils.zig").compute_hash; + +const TABLE_MAX_LOAD = 0.75; + +const Entry = struct { + key: ?*Obj.String, + value: Value, +}; + +pub const Table = struct { + allocator: Allocator, + count: usize, + capacity: usize, + entries: []Entry, + + pub fn new(allocator: Allocator) Table { + return Table{ + .allocator = allocator, + .count = 0, + .capacity = 0, + .entries = &.{}, + }; + } + + pub fn deinit(self: *Table) void { + if (self.capacity == 0) { + return; + } + + self.allocator.free(self.entries); + } + + pub fn set(self: *Table, key: *Obj.String, value: Value) bool { + const current_count: f32 = @floatFromInt(self.count + 1); + const current_capacity: f32 = @floatFromInt(self.capacity); + + if (current_count > current_capacity * TABLE_MAX_LOAD) { + const capacity = grow_capacity(self.capacity); + self.adjust_capacity(capacity); + } + + const entry = Table.find_entry(self.entries, key); + const is_new = entry.?.key == null; + if (is_new and entry.?.value.is_nil()) { + self.count += 1; + } + + entry.?.key = key; + entry.?.value = value; + + return is_new; + } + + pub fn find_entry(entries: []Entry, key: *Obj.String) ?*Entry { + var tombstone: ?*Entry = null; + var index = key.hash % entries.len; + + while (true) { + const entry = &entries[index]; + + if (entry.key == null) { + if (entry.value.is_nil()) { + // Empty entry. + if (tombstone != null) { + return tombstone; + } else { + return entry; + } + } else { + // We found a tombestone + if (tombstone == null) { + tombstone = entry; + } + } + } else if (entry.key == key) { + // We found the key + return entry; + } + + index = (index + 1) % entries.len; + } + } + + pub fn adjust_capacity(self: *Table, capacity: usize) void { + var entries = self.allocator.alloc(Entry, capacity) catch unreachable; + + for (0..entries.len) |idx| { + entries[idx].key = null; + entries[idx].value = Value.nil_val(); + } + + self.count = 0; + for (0..self.capacity) |idx| { + const entry = self.entries[idx]; + if (entry.key == null) { + continue; + } + + const dest_entry = Table.find_entry(entries, entry.key.?); + dest_entry.?.key = entry.key; + dest_entry.?.value = entry.value; + + self.count += 1; + } + + self.capacity = capacity; + if (entries.len > 0) { + self.allocator.free(self.entries); + } + self.entries = entries; + } + + pub fn dump(self: Table) void { + std.debug.print("== Hash table count:{} capacity:{} ==\n", .{ self.count, self.capacity }); + for (self.entries, 0..) |entry, idx| { + if (entry.key != null) { + std.debug.print("{d} ({d}) - {s}: ", .{ idx, entry.key.?.hash, entry.key.?.chars }); + entry.value.print(); + std.debug.print("\n", .{}); + } + + if (entry.key == null and entry.value.as_bool()) { + std.debug.print("{d} - tombstone\n", .{idx}); + } + } + std.debug.print("== End of hash table ==\n\n", .{}); + } + + pub fn add_all(self: *Table, from: Table) void { + for (from.entries) |entry| { + if (entry.key == null) { + continue; + } + _ = self.set(entry.key.?, entry.value); + } + } + + pub fn get(self: Table, key: *Obj.String, value: *Value) bool { + if (self.count == 0) { + return false; + } + + const entry = Table.find_entry(self.entries, key); + if (entry.?.key == null) { + return false; + } + + value.* = entry.?.value; + + return true; + } + + pub fn del(self: *Table, key: *Obj.String) bool { + if (self.count == 0) { + return false; + } + + // Find the entry + const entry = Table.find_entry(self.entries, key); + if (entry.?.key == null) { + return false; + } + + // Place a tombstone in the entry + entry.?.key = null; + entry.?.value = Value.bool_val(true); + + return true; + } + + pub fn find_string(self: *Table, chars: []const u8, hash: u32) ?*Obj.String { + if (self.count == 0) { + return null; + } + + var index = hash % self.capacity; + while (true) { + const entry = &self.entries[index]; + if (entry.key == null) { + // Stop if we find an empty non-tombstone entry. + if (entry.value.is_nil()) { + return null; + } + } else if (entry.key.?.chars.len == chars.len and entry.key.?.hash == hash and std.mem.eql(u8, chars, entry.key.?.chars)) { + return entry.key; + } + + index = (index + 1) % self.capacity; + } + } +}; + +test "initialize an hash table" { + const allocator = std.testing.allocator; + + var table = Table.new(allocator); + defer table.deinit(); + try std.testing.expectEqual(0, table.count); + try std.testing.expectEqual(0, table.capacity); +} + +test "adding values" { + const allocator = std.testing.allocator; + + const key = Obj.String.new(allocator, "hello world"); + defer key.destroy(); + + var table = Table.new(allocator); + defer table.deinit(); + + var res = table.set(key, Value.nil_val()); + try std.testing.expectEqual(true, res); + try std.testing.expectEqual(1, table.count); + try std.testing.expectEqual(8, table.capacity); + + res = table.set(key, Value.nil_val()); + try std.testing.expectEqual(false, res); + try std.testing.expectEqual(1, table.count); + try std.testing.expectEqual(8, table.capacity); +} + +test "adding tables" { + const allocator = std.testing.allocator; + + const key = Obj.String.new(allocator, "hello world"); + defer key.destroy(); + + var table = Table.new(allocator); + defer table.deinit(); + + const res = table.set(key, Value.nil_val()); + try std.testing.expectEqual(true, res); + try std.testing.expectEqual(8, table.capacity); + try std.testing.expectEqual(1, table.count); + + var table2 = Table.new(allocator); + defer table2.deinit(); + + try std.testing.expectEqual(0, table2.capacity); + try std.testing.expectEqual(0, table2.count); + + table2.add_all(table); + try std.testing.expectEqual(8, table2.capacity); + try std.testing.expectEqual(1, table2.count); +} + +test "deleting from table" { + const allocator = std.testing.allocator; + + const key = Obj.String.new(allocator, "hello world"); + defer key.destroy(); + + var table = Table.new(allocator); + defer table.deinit(); + + var res = table.set(key, Value.nil_val()); + try std.testing.expectEqual(true, res); + + // table.dump(); + + res = table.del(key); + try std.testing.expectEqual(true, res); + + // table.dump(); +} + +test "find" { + const allocator = std.testing.allocator; + + const key = Obj.String.new(allocator, "hello world"); + defer key.destroy(); + + var table = Table.new(allocator); + defer table.deinit(); + + const value = Value.number_val(42.0); + + var res = table.set(key, value); + try std.testing.expectEqual(true, res); + + var entry = table.find_string("bye world", compute_hash("bye world")); + try std.testing.expectEqual(entry, null); + + entry = table.find_string("hello world", compute_hash("hello world")); + // std.debug.print("{any}\n", .{entry}); + try std.testing.expect(entry != null); + + var value_obj = Value.nil_val(); + res = table.get(entry.?, &value_obj); + try std.testing.expect(res); + // std.debug.print("{any}\n", .{value_obj}); + + try std.testing.expectEqual(value_obj.as_number(), value.as_number()); +} diff --git a/src/utils.zig b/src/utils.zig index 119cfc5..0b3fced 100644 --- a/src/utils.zig +++ b/src/utils.zig @@ -24,3 +24,14 @@ pub fn constant_instruction(opcode_name: []const u8, chunk: Chunk, offset: usize debug.print("'\n", .{}); return offset + 2; } + +pub fn compute_hash(str: []const u8) u32 { + var res_hash: u32 = 2166136261; + + for (str) |c| { + res_hash ^= c; + res_hash *%= 16777619; + } + + return res_hash; +} diff --git a/src/values.zig b/src/values.zig index e5f40f0..9e194bb 100644 --- a/src/values.zig +++ b/src/values.zig @@ -114,14 +114,18 @@ pub const Value = struct { ValueType.Nil => true, ValueType.Bool => self.as_bool() == other.as_bool(), ValueType.Number => self.as_number() == other.as_number(), - ValueType.Obj => { - const obj_string0 = self.as_cstring(); - const obj_string1 = other.as_cstring(); - - return std.mem.eql(u8, obj_string0, obj_string1); - }, + ValueType.Obj => self.as_obj() == other.as_obj(), }; } + + pub fn print(self: Value) void { + switch (self.value_type) { + ValueType.Nil => debug.print("nil", .{}), + ValueType.Bool => debug.print("{any}", .{self.as_bool()}), + ValueType.Number => debug.print("{d}", .{self.as_number()}), + ValueType.Obj => self.as_obj().print(), + } + } }; pub const ValueArray = struct { @@ -156,10 +160,5 @@ pub const ValueArray = struct { }; pub fn print_value(value: Value) void { - switch (value.value_type) { - ValueType.Nil => debug.print("nil", .{}), - ValueType.Bool => debug.print("{any}", .{value.as_bool()}), - ValueType.Number => debug.print("{d}", .{value.as_number()}), - ValueType.Obj => value.as_obj().print(), - } + value.print(); } diff --git a/src/vm.zig b/src/vm.zig index 835455d..809fd65 100644 --- a/src/vm.zig +++ b/src/vm.zig @@ -6,8 +6,10 @@ const Chunk = @import("./chunk.zig").Chunk; const OpCode = @import("./opcode.zig").OpCode; const Value = @import("./values.zig").Value; const Obj = @import("./object.zig").Obj; +const Table = @import("./table.zig").Table; const compile = @import("./compile.zig").compile; +const compute_hash = @import("./utils.zig").compute_hash; const DEBUG_TRACE_EXECUTION = @import("./main.zig").DEBUG_TRACE_EXECUTION; @@ -22,24 +24,31 @@ pub const InterpretResult = enum { }; pub const VM = struct { + allocator: Allocator, chunk: ?*Chunk, ip: ?usize, stack: std.ArrayList(Value), // Keeping creating objects in references to destroy objects on cleaning. // In the book, a linked list between objects is used to handle this. references: std.ArrayList(*Obj), + strings: Table, pub fn new(allocator: Allocator) VM { return VM{ + .allocator = allocator, .chunk = null, .ip = null, .stack = std.ArrayList(Value).init(allocator), .references = std.ArrayList(*Obj).init(allocator), + .strings = Table.new(allocator), }; } pub fn free(self: *VM) void { self.stack.deinit(); + + self.strings.dump(); + self.strings.deinit(); self.clean_references(); self.references.deinit(); } @@ -179,9 +188,8 @@ pub const VM = struct { const a = self.pop().as_cstring(); const concat_str = try std.mem.concat(self.chunk.?.allocator, u8, &.{ a, b }); - defer self.chunk.?.allocator.free(concat_str); - var string_obj = Obj.String.new(self.chunk.?.allocator, concat_str); + var string_obj = self.take_string(concat_str); self.add_reference(&string_obj.obj); @@ -201,6 +209,12 @@ pub const VM = struct { } pub fn add_reference(self: *VM, obj: *Obj) void { + // do not add duplicate references + for (self.references.items) |item| { + if (item == obj) { + return; + } + } // XXX TODO catch unreachable to prevents self.references.append(obj) catch unreachable; } @@ -210,4 +224,36 @@ pub const VM = struct { item.destroy(); } } + + pub fn copy_string(self: *VM, source: []const u8) *Obj.String { + const hash = compute_hash(source); + const obj_string = self.strings.find_string(source, hash); + + if (obj_string != null) { + return obj_string.?; + } + + const copy: []const u8 = self.allocator.dupe(u8, source) catch unreachable; + return self.allocate_string(copy); + } + + pub fn take_string(self: *VM, source: []const u8) *Obj.String { + const hash = compute_hash(source); + const obj_string = self.strings.find_string(source, hash); + + if (obj_string != null) { + // free given string + self.allocator.free(source); + return obj_string.?; + } + + return self.allocate_string(source); + } + + pub fn allocate_string(self: *VM, source: []const u8) *Obj.String { + const obj_string = Obj.String.new(self.allocator, source); + _ = self.strings.set(obj_string, Value.nil_val()); + + return obj_string; + } };