From 2df5f3f18e494d5177a17af398e833f60d8d0c3a Mon Sep 17 00:00:00 2001 From: Marcus Date: Mon, 13 May 2024 00:53:51 +0200 Subject: [PATCH] Add wildcards (#37) * Fix export * Initial commit * Uncomment cases * Rename case * Add tests for wildcards * Support wildcards in records * Add tests for relation data * Add shorthands * Change casing of exports * Change function signatures * Improve inlining of ECS_PAIR * Delete whitespace * Create root archetype * Add back tests * Fix tests --- lib/init.lua | 299 ++++++++++++++++++++++++---------------------- lib/init.spec.lua | 68 ++++++++++- test.project.json | 6 +- tests/world.lua | 77 ++++++++++-- wally.toml | 3 +- 5 files changed, 292 insertions(+), 161 deletions(-) diff --git a/lib/init.lua b/lib/init.lua index ca6a573..b7486e4 100644 --- a/lib/init.lua +++ b/lib/init.lua @@ -44,11 +44,118 @@ type ArchetypeDiff = { removed: Ty, } +local FLAGS_PAIR = 0x8 local HI_COMPONENT_ID = 256 local ON_ADD = HI_COMPONENT_ID + 1 local ON_REMOVE = HI_COMPONENT_ID + 2 local ON_SET = HI_COMPONENT_ID + 3 -local REST = HI_COMPONENT_ID + 4 +local WILDCARD = HI_COMPONENT_ID + 4 +local REST = HI_COMPONENT_ID + 5 + +local ECS_ID_FLAGS_MASK = 0x10 +local ECS_ENTITY_MASK = bit32.lshift(1, 24) +local ECS_GENERATION_MASK = bit32.lshift(1, 16) + +local function addFlags(isPair: boolean) + local typeFlags = 0x0 + + if isPair then + typeFlags = bit32.bor(typeFlags, FLAGS_PAIR) -- HIGHEST bit in the ID. + end + if false then + typeFlags = bit32.bor(typeFlags, 0x4) -- Set the second flag to true + end + if false then + typeFlags = bit32.bor(typeFlags, 0x2) -- Set the third flag to true + end + if false then + typeFlags = bit32.bor(typeFlags, 0x1) -- LAST BIT in the ID. + end + + return typeFlags +end + +local function newId(source: number, target: number) + local e = source * 2^28 + target * ECS_ID_FLAGS_MASK + return e +end + +local function ECS_IS_PAIR(e: number) + return (e % 2^4) // FLAGS_PAIR ~= 0 +end + +function separate(entity: number) + local _typeFlags = entity % 0x10 + entity //= ECS_ID_FLAGS_MASK + return entity // ECS_ENTITY_MASK, entity % ECS_GENERATION_MASK, _typeFlags +end + +-- HIGH 24 bits LOW 24 bits +local function ECS_GENERATION(e: i53) + e //= 0x10 + return e % ECS_GENERATION_MASK +end + +local function ECS_ID(e: i53) + e //= 0x10 + return e // ECS_ENTITY_MASK +end + +local function ECS_GENERATION_INC(e: i53) + local id, generation, flags = separate(e) + + return newId(id, generation + 1) + flags +end + +-- gets the high ID +local function ECS_PAIR_FIRST(entity: i53): i24 + entity //= 0x10 + local first = entity % ECS_ENTITY_MASK + return first +end + +-- gets the low ID +local ECS_PAIR_SECOND = ECS_ID + +local function ECS_PAIR(first: number, second: number) + local target = WILDCARD + local relation + + if first == WILDCARD then + relation = second + elseif second == WILDCARD then + relation = first + else + relation = second + target = ECS_PAIR_SECOND(first) + end + + return newId( + ECS_PAIR_SECOND(relation), target) + addFlags(--[[isPair]] true) +end + +local function getAlive(entityIndex: EntityIndex, id: i53) + return entityIndex.dense[id] +end + +local function ecs_get_source(entityIndex, e) + assert(ECS_IS_PAIR(e)) + return getAlive(entityIndex, ECS_PAIR_FIRST(e)) +end +local function ecs_get_target(entityIndex, e) + assert(ECS_IS_PAIR(e)) + return getAlive(entityIndex, ECS_PAIR_SECOND(e)) +end + +local function nextEntityId(entityIndex, index: i24) + local id = newId(index, 0) + entityIndex.sparse[id] = { + dense = index + } :: Record + entityIndex.dense[index] = id + + return id +end local function transitionArchetype( entityIndex: EntityIndex, @@ -132,22 +239,14 @@ local function hash(arr): string | number return table.concat(arr, "_") end -local function createArchetypeRecords(componentIndex: ComponentIndex, to: Archetype, _from: Archetype?) - local destinationIds = to.types - local records = to.records - local id = to.id +local function createArchetypeRecord(componentIndex, id, componentId, i) + local archetypesMap = componentIndex[componentId] - for i, destinationId in destinationIds do - local archetypesMap = componentIndex[destinationId] - - if not archetypesMap then - archetypesMap = {size = 0, sparse = {}} - componentIndex[destinationId] = archetypesMap - end - - archetypesMap.sparse[id] = i - records[destinationId] = i + if not archetypesMap then + archetypesMap = {size = 0, sparse = {}} + componentIndex[componentId] = archetypesMap end + archetypesMap.sparse[id] = i end local function archetypeOf(world: World, types: {i24}, prev: Archetype?): Archetype @@ -157,10 +256,26 @@ local function archetypeOf(world: World, types: {i24}, prev: Archetype?): Archet world.nextArchetypeId = id local length = #types - local columns = table.create(length) :: {any} + local columns = table.create(length) - for index in types do - columns[index] = {} + local records = {} + local componentIndex = world.componentIndex + local entityIndex = world.entityIndex + for i, componentId in types do + createArchetypeRecord(componentIndex, id, componentId, i) + records[componentId] = i + columns[i] = {} + + if ECS_IS_PAIR(componentId) then + local first = ecs_get_source(entityIndex, componentId) + local second = ecs_get_target(entityIndex, componentId) + local firstPair = ECS_PAIR(first, WILDCARD) + local secondPair = ECS_PAIR(WILDCARD, second) + createArchetypeRecord(componentIndex, id, firstPair, i) + createArchetypeRecord(componentIndex, id, secondPair, i) + records[firstPair] = i + records[secondPair] = i + end end local archetype = { @@ -168,15 +283,12 @@ local function archetypeOf(world: World, types: {i24}, prev: Archetype?): Archet edges = {}; entities = {}; id = id; - records = {}; + records = records; type = ty; types = types; } world.archetypeIndex[ty] = archetype world.archetypes[id] = archetype - if length > 0 then - createArchetypeRecords(world.componentIndex, archetype, prev) - end return archetype end @@ -186,8 +298,8 @@ World.__index = World function World.new() local self = setmetatable({ archetypeIndex = {}; - archetypes = {}; - componentIndex = {}; + archetypes = {} :: Archetypes; + componentIndex = {} :: ComponentIndex; entityIndex = { dense = {}, sparse = {} @@ -200,107 +312,10 @@ function World.new() nextEntityId = 0; ROOT_ARCHETYPE = (nil :: any) :: Archetype; }, World) + self.ROOT_ARCHETYPE = archetypeOf(self, {}) return self end -local FLAGS_PAIR = 0x8 - -local function addFlags(flags) - local typeFlags = 0x0 - if flags.isPair then - typeFlags = bit32.bor(typeFlags, FLAGS_PAIR) -- HIGHEST bit in the ID. - end - if false then - typeFlags = bit32.bor(typeFlags, 0x4) -- Set the second flag to true - end - if false then - typeFlags = bit32.bor(typeFlags, 0x2) -- Set the third flag to true - end - if false then - typeFlags = bit32.bor(typeFlags, 0x1) -- LAST BIT in the ID. - end - - return typeFlags -end - -local ECS_ID_FLAGS_MASK = 0x10 - --- ECS_ENTITY_MASK (0xFFFFFFFFull << 28) -local ECS_ENTITY_MASK = bit32.lshift(1, 24) - --- ECS_GENERATION_MASK (0xFFFFull << 24) -local ECS_GENERATION_MASK = bit32.lshift(1, 16) - -local function newId(source: number, target: number) - local e = source * 2^28 + target * ECS_ID_FLAGS_MASK - return e -end - -local function isPair(e: number) - return (e % 2^4) // FLAGS_PAIR ~= 0 -end - -function separate(entity: number) - local _typeFlags = entity % 0x10 - entity //= ECS_ID_FLAGS_MASK - return entity // ECS_ENTITY_MASK, entity % ECS_GENERATION_MASK, _typeFlags -end - --- HIGH 24 bits LOW 24 bits -local function ECS_GENERATION(e: i53) - e //= 0x10 - return e % ECS_GENERATION_MASK -end - -local function ECS_ID(e: i53) - e //= 0x10 - return e // ECS_ENTITY_MASK -end - -local function ECS_GENERATION_INC(e: i53) - local id, generation, flags = separate(e) - - return newId(id, generation + 1) + flags -end - --- gets the high ID -local function ECS_PAIR_FIRST(entity: i53): i24 - entity //= 0x10 - local first = entity % ECS_ENTITY_MASK - return first -end - --- gets the low ID -local ECS_PAIR_SECOND = ECS_ID - -local function ECS_PAIR(source: number, target: number) - local id = newId(ECS_PAIR_SECOND(target), ECS_PAIR_SECOND(source)) + addFlags({ isPair = true }) - return id -end - -local function getAlive(entityIndex: EntityIndex, id: i53) - return entityIndex.dense[id] -end - -local function ecs_get_source(entityIndex, e) - assert(isPair(e)) - return getAlive(entityIndex, ECS_PAIR_FIRST(e)) -end -local function ecs_get_target(entityIndex, e) - assert(isPair(e)) - return getAlive(entityIndex, ECS_PAIR_SECOND(e)) -end - -local function nextEntityId(entityIndex, index: i24) - local id = newId(index, 0) - entityIndex.sparse[id] = { - dense = index - } :: Record - entityIndex.dense[index] = id - - return id -end - function World.component(world: World) local componentId = world.nextComponentId + 1 if componentId > HI_COMPONENT_ID then @@ -402,15 +417,16 @@ local function findArchetypeWith(world: World, node: Archetype, componentId: i53 -- Component IDs are added incrementally, so inserting and sorting -- them each time would be expensive. Instead this insertion sort can find the insertion -- point in the types array. + + local destinationType = table.clone(node.types) local at = findInsert(types, componentId) if at == -1 then -- If it finds a duplicate, it just means it is the same archetype so it can return it -- directly instead of needing to hash types for a lookup to the archetype. return node end - - local destinationType = table.clone(node.types) table.insert(destinationType, at, componentId) + return ensureArchetype(world, destinationType, node) end @@ -425,15 +441,7 @@ local function ensureEdge(archetype: Archetype, componentId: i53) end local function archetypeTraverseAdd(world: World, componentId: i53, from: Archetype): Archetype - if not from then - -- If there was no source archetype then it should return the ROOT_ARCHETYPE - local ROOT_ARCHETYPE = world.ROOT_ARCHETYPE - if not ROOT_ARCHETYPE then - ROOT_ARCHETYPE = archetypeOf(world, {}, nil) - world.ROOT_ARCHETYPE = ROOT_ARCHETYPE :: never - end - from = ROOT_ARCHETYPE - end + from = from or world.ROOT_ARCHETYPE local edge = ensureEdge(from, componentId) local add = edge.add @@ -659,14 +667,14 @@ function World.query(world: World, ...: i53): Query function preparedQuery:__iter() return function() local archetype = compatibleArchetype[1] - local row = next(archetype.entities, lastRow) + local row: number = next(archetype.entities, lastRow) :: number while row == nil do lastArchetype, compatibleArchetype = next(compatibleArchetypes, lastArchetype) if lastArchetype == nil then return end archetype = compatibleArchetype[1] - row = next(archetype.entities, row) + row = next(archetype.entities, row) :: number end lastRow = row @@ -764,15 +772,22 @@ end return table.freeze({ World = World; - ON_ADD = ON_ADD; - ON_REMOVE = ON_REMOVE; - ON_SET = ON_SET; + + OnAdd = ON_ADD; + OnRemove = ON_REMOVE; + OnSet = ON_SET; + Wildcard = WILDCARD, + w = WILDCARD, + Rest = REST, + ECS_ID = ECS_ID, - IS_PAIR = isPair, + IS_PAIR = ECS_IS_PAIR, ECS_PAIR = ECS_PAIR, - ECS_GENERATION = ECS_GENERATION, ECS_GENERATION_INC = ECS_GENERATION_INC, - getAlive = getAlive, + ECS_GENERATION = ECS_GENERATION, ecs_get_target = ecs_get_target, - ecs_get_source = ecs_get_source + ecs_get_source = ecs_get_source, + + pair = ECS_PAIR, + getAlive = getAlive, }) diff --git a/lib/init.spec.lua b/lib/init.spec.lua index 553c9a4..8de8de9 100644 --- a/lib/init.spec.lua +++ b/lib/init.spec.lua @@ -309,12 +309,74 @@ return function() elseif id == eAB then expect(data[A]).to.be.ok() expect(data[B]).to.be.ok() - else - error("unknown entity", id) end end - expect(count).to.equal(3) + expect(count).to.equal(5) end) + + it("should allow querying for relations", function() + local world = jecs.World.new() + local Eats = world:entity() + local Apples = world:entity() + local bob = world:entity() + + world:set(bob, jecs.pair(Eats, Apples), true) + for e, bool in world:query(jecs.pair(Eats, Apples)) do + expect(e).to.equal(bob) + expect(bool).to.equal(bool) + end + end) + + it("should allow wildcards in queries", function() + local world = jecs.World.new() + local Eats = world:entity() + local Apples = world:entity() + local bob = world:entity() + + world:set(bob, jecs.pair(Eats, Apples), "bob eats apples") + for e, data in world:query(jecs.pair(Eats, jecs.w)) do + expect(e).to.equal(bob) + expect(data).to.equal("bob eats apples") + end + for e, data in world:query(jecs.pair(jecs.w, Apples)) do + expect(e).to.equal(bob) + expect(data).to.equal("bob eats apples") + end + end) + + it("should match against multiple pairs", function() + local world = jecs.World.new() + local pair = jecs.pair + local Eats = world:entity() + local Apples = world:entity() + local Oranges =world:entity() + local bob = world:entity() + local alice = world:entity() + + world:set(bob, pair(Eats, Apples), "bob eats apples") + world:set(alice, pair(Eats, Oranges), "alice eats oranges") + + local w = jecs.Wildcard + + local count = 0 + for e, data in world:query(pair(Eats, w)) do + count += 1 + if e == bob then + expect(data).to.equal("bob eats apples") + else + expect(data).to.equal("alice eats oranges") + end + end + + expect(count).to.equal(2) + count = 0 + + for e, data in world:query(pair(w, Apples)) do + count += 1 + expect(data).to.equal("bob eats apples") + end + expect(count).to.equal(1) + end) end) end \ No newline at end of file diff --git a/test.project.json b/test.project.json index b931a84..bdcbd0b 100644 --- a/test.project.json +++ b/test.project.json @@ -11,9 +11,6 @@ }, "ReplicatedStorage": { "$className": "ReplicatedStorage", - "DevPackages": { - "$path": "DevPackages" - }, "Lib": { "$path": "lib" }, @@ -25,6 +22,9 @@ }, "mirror": { "$path": "mirror" + }, + "DevPackages": { + "$path": "DevPackages" } }, "TestService": { diff --git a/tests/world.lua b/tests/world.lua index d95097d..1aff493 100644 --- a/tests/world.lua +++ b/tests/world.lua @@ -7,7 +7,6 @@ local ECS_PAIR = jecs.ECS_PAIR local getAlive = jecs.getAlive local ecs_get_source = jecs.ecs_get_source local ecs_get_target = jecs.ecs_get_target -local REST = 256 + 4 local TEST, CASE, CHECK, FINISH, SKIP = testkit.test() @@ -18,7 +17,6 @@ TEST("world", function() local world = jecs.World.new() local A = world:component() local B = world:component() - local eA = world:entity() world:set(eA, A, true) local eB = world:entity() @@ -48,7 +46,6 @@ TEST("world", function() end do CASE "should query all matching entities" - local world = jecs.World.new() local A = world:component() local B = world:component() @@ -71,7 +68,6 @@ TEST("world", function() end do CASE "should query all matching entities when irrelevant component is removed" - local world = jecs.World.new() local A = world:component() local B = world:component() @@ -99,7 +95,6 @@ TEST("world", function() end do CASE "should query all entities without B" - local world = jecs.World.new() local A = world:component() local B = world:component() @@ -171,29 +166,24 @@ TEST("world", function() world:remove(id, Poison) CHECK(world:get(id, Poison) == nil) - print(world:get(id, Health)) CHECK(world:get(id, Health) == 50) end do CASE "should increment generation" local world = jecs.World.new() local e = world:entity() - CHECK(ECS_ID(e) == 1 + REST) + CHECK(ECS_ID(e) == 1 + jecs.Rest) CHECK(getAlive(world.entityIndex, ECS_ID(e)) == e) CHECK(ECS_GENERATION(e) == 0) -- 0 e = ECS_GENERATION_INC(e) CHECK(ECS_GENERATION(e) == 1) -- 1 end - do CASE "relations" + do CASE "should get alive from index in the dense array" local world = jecs.World.new() local _e = world:entity() local e2 = world:entity() local e3 = world:entity() - CHECK(ECS_ID(e2) == 2 + REST) - CHECK(ECS_ID(e3) == 3 + REST) - CHECK(ECS_GENERATION(e2) == 0) - CHECK(ECS_GENERATION(e3) == 0) CHECK(IS_PAIR(world:entity()) == false) @@ -203,6 +193,69 @@ TEST("world", function() CHECK(ecs_get_target(world.entityIndex, pair) == e3) end + do CASE "should allow querying for relations" + local world = jecs.World.new() + local Eats = world:entity() + local Apples = world:entity() + local bob = world:entity() + + world:set(bob, ECS_PAIR(Eats, Apples), true) + for e, bool in world:query(ECS_PAIR(Eats, Apples)) do + CHECK(e == bob) + CHECK(bool) + end + end + + do CASE "should allow wildcards in queries" + local world = jecs.World.new() + local Eats = world:entity() + local Apples = world:entity() + local bob = world:entity() + + world:set(bob, ECS_PAIR(Eats, Apples), "bob eats apples") + + local w = jecs.Wildcard + for e, data in world:query(ECS_PAIR(Eats, w)) do + CHECK(e == bob) + CHECK(data == "bob eats apples") + end + for e, data in world:query(ECS_PAIR(w, Apples)) do + CHECK(e == bob) + CHECK(data == "bob eats apples") + end + end + + do CASE "should match against multiple pairs" + local world = jecs.World.new() + local Eats = world:entity() + local Apples = world:entity() + local Oranges =world:entity() + local bob = world:entity() + local alice = world:entity() + + world:set(bob, ECS_PAIR(Eats, Apples), "bob eats apples") + world:set(alice, ECS_PAIR(Eats, Oranges), "alice eats oranges") + + local w = jecs.Wildcard + local count = 0 + for e, data in world:query(ECS_PAIR(Eats, w)) do + count += 1 + if e == bob then + CHECK(data == "bob eats apples") + else + CHECK(data == "alice eats oranges") + end + end + + CHECK(count == 2) + count = 0 + + for e, data in world:query(ECS_PAIR(w, Apples)) do + count += 1 + CHECK(data == "bob eats apples") + end + CHECK(count == 1) + end end) FINISH() \ No newline at end of file diff --git a/wally.toml b/wally.toml index 7102c41..f17e660 100644 --- a/wally.toml +++ b/wally.toml @@ -6,4 +6,5 @@ realm = "shared" include = ["default.project.json", "lib/**", "lib", "wally.toml", "README.md"] exclude = ["**"] -[dev-dependencies] \ No newline at end of file +[dev-dependencies] +TestEZ = "roblox/testez@0.4.1" \ No newline at end of file