Add wildcards (#37)

* Fix export

* Initial commit

* Uncomment cases

* Rename case

* Add tests for wildcards

* Support wildcards in records

* Add tests for relation data

* Add shorthands

* Change casing of exports

* Change function signatures

* Improve inlining of ECS_PAIR

* Delete whitespace

* Create root archetype

* Add back tests

* Fix tests
This commit is contained in:
Marcus 2024-05-13 00:53:51 +02:00 committed by GitHub
parent 582b09be66
commit 2df5f3f18e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 292 additions and 161 deletions

View file

@ -44,11 +44,118 @@ type ArchetypeDiff = {
removed: Ty, removed: Ty,
} }
local FLAGS_PAIR = 0x8
local HI_COMPONENT_ID = 256 local HI_COMPONENT_ID = 256
local ON_ADD = HI_COMPONENT_ID + 1 local ON_ADD = HI_COMPONENT_ID + 1
local ON_REMOVE = HI_COMPONENT_ID + 2 local ON_REMOVE = HI_COMPONENT_ID + 2
local ON_SET = HI_COMPONENT_ID + 3 local ON_SET = HI_COMPONENT_ID + 3
local REST = HI_COMPONENT_ID + 4 local WILDCARD = HI_COMPONENT_ID + 4
local REST = HI_COMPONENT_ID + 5
local ECS_ID_FLAGS_MASK = 0x10
local ECS_ENTITY_MASK = bit32.lshift(1, 24)
local ECS_GENERATION_MASK = bit32.lshift(1, 16)
local function addFlags(isPair: boolean)
local typeFlags = 0x0
if isPair then
typeFlags = bit32.bor(typeFlags, FLAGS_PAIR) -- HIGHEST bit in the ID.
end
if false then
typeFlags = bit32.bor(typeFlags, 0x4) -- Set the second flag to true
end
if false then
typeFlags = bit32.bor(typeFlags, 0x2) -- Set the third flag to true
end
if false then
typeFlags = bit32.bor(typeFlags, 0x1) -- LAST BIT in the ID.
end
return typeFlags
end
local function newId(source: number, target: number)
local e = source * 2^28 + target * ECS_ID_FLAGS_MASK
return e
end
local function ECS_IS_PAIR(e: number)
return (e % 2^4) // FLAGS_PAIR ~= 0
end
function separate(entity: number)
local _typeFlags = entity % 0x10
entity //= ECS_ID_FLAGS_MASK
return entity // ECS_ENTITY_MASK, entity % ECS_GENERATION_MASK, _typeFlags
end
-- HIGH 24 bits LOW 24 bits
local function ECS_GENERATION(e: i53)
e //= 0x10
return e % ECS_GENERATION_MASK
end
local function ECS_ID(e: i53)
e //= 0x10
return e // ECS_ENTITY_MASK
end
local function ECS_GENERATION_INC(e: i53)
local id, generation, flags = separate(e)
return newId(id, generation + 1) + flags
end
-- gets the high ID
local function ECS_PAIR_FIRST(entity: i53): i24
entity //= 0x10
local first = entity % ECS_ENTITY_MASK
return first
end
-- gets the low ID
local ECS_PAIR_SECOND = ECS_ID
local function ECS_PAIR(first: number, second: number)
local target = WILDCARD
local relation
if first == WILDCARD then
relation = second
elseif second == WILDCARD then
relation = first
else
relation = second
target = ECS_PAIR_SECOND(first)
end
return newId(
ECS_PAIR_SECOND(relation), target) + addFlags(--[[isPair]] true)
end
local function getAlive(entityIndex: EntityIndex, id: i53)
return entityIndex.dense[id]
end
local function ecs_get_source(entityIndex, e)
assert(ECS_IS_PAIR(e))
return getAlive(entityIndex, ECS_PAIR_FIRST(e))
end
local function ecs_get_target(entityIndex, e)
assert(ECS_IS_PAIR(e))
return getAlive(entityIndex, ECS_PAIR_SECOND(e))
end
local function nextEntityId(entityIndex, index: i24)
local id = newId(index, 0)
entityIndex.sparse[id] = {
dense = index
} :: Record
entityIndex.dense[index] = id
return id
end
local function transitionArchetype( local function transitionArchetype(
entityIndex: EntityIndex, entityIndex: EntityIndex,
@ -132,22 +239,14 @@ local function hash(arr): string | number
return table.concat(arr, "_") return table.concat(arr, "_")
end end
local function createArchetypeRecords(componentIndex: ComponentIndex, to: Archetype, _from: Archetype?) local function createArchetypeRecord(componentIndex, id, componentId, i)
local destinationIds = to.types local archetypesMap = componentIndex[componentId]
local records = to.records
local id = to.id
for i, destinationId in destinationIds do if not archetypesMap then
local archetypesMap = componentIndex[destinationId] archetypesMap = {size = 0, sparse = {}}
componentIndex[componentId] = archetypesMap
if not archetypesMap then
archetypesMap = {size = 0, sparse = {}}
componentIndex[destinationId] = archetypesMap
end
archetypesMap.sparse[id] = i
records[destinationId] = i
end end
archetypesMap.sparse[id] = i
end end
local function archetypeOf(world: World, types: {i24}, prev: Archetype?): Archetype local function archetypeOf(world: World, types: {i24}, prev: Archetype?): Archetype
@ -157,10 +256,26 @@ local function archetypeOf(world: World, types: {i24}, prev: Archetype?): Archet
world.nextArchetypeId = id world.nextArchetypeId = id
local length = #types local length = #types
local columns = table.create(length) :: {any} local columns = table.create(length)
for index in types do local records = {}
columns[index] = {} local componentIndex = world.componentIndex
local entityIndex = world.entityIndex
for i, componentId in types do
createArchetypeRecord(componentIndex, id, componentId, i)
records[componentId] = i
columns[i] = {}
if ECS_IS_PAIR(componentId) then
local first = ecs_get_source(entityIndex, componentId)
local second = ecs_get_target(entityIndex, componentId)
local firstPair = ECS_PAIR(first, WILDCARD)
local secondPair = ECS_PAIR(WILDCARD, second)
createArchetypeRecord(componentIndex, id, firstPair, i)
createArchetypeRecord(componentIndex, id, secondPair, i)
records[firstPair] = i
records[secondPair] = i
end
end end
local archetype = { local archetype = {
@ -168,15 +283,12 @@ local function archetypeOf(world: World, types: {i24}, prev: Archetype?): Archet
edges = {}; edges = {};
entities = {}; entities = {};
id = id; id = id;
records = {}; records = records;
type = ty; type = ty;
types = types; types = types;
} }
world.archetypeIndex[ty] = archetype world.archetypeIndex[ty] = archetype
world.archetypes[id] = archetype world.archetypes[id] = archetype
if length > 0 then
createArchetypeRecords(world.componentIndex, archetype, prev)
end
return archetype return archetype
end end
@ -186,8 +298,8 @@ World.__index = World
function World.new() function World.new()
local self = setmetatable({ local self = setmetatable({
archetypeIndex = {}; archetypeIndex = {};
archetypes = {}; archetypes = {} :: Archetypes;
componentIndex = {}; componentIndex = {} :: ComponentIndex;
entityIndex = { entityIndex = {
dense = {}, dense = {},
sparse = {} sparse = {}
@ -200,107 +312,10 @@ function World.new()
nextEntityId = 0; nextEntityId = 0;
ROOT_ARCHETYPE = (nil :: any) :: Archetype; ROOT_ARCHETYPE = (nil :: any) :: Archetype;
}, World) }, World)
self.ROOT_ARCHETYPE = archetypeOf(self, {})
return self return self
end end
local FLAGS_PAIR = 0x8
local function addFlags(flags)
local typeFlags = 0x0
if flags.isPair then
typeFlags = bit32.bor(typeFlags, FLAGS_PAIR) -- HIGHEST bit in the ID.
end
if false then
typeFlags = bit32.bor(typeFlags, 0x4) -- Set the second flag to true
end
if false then
typeFlags = bit32.bor(typeFlags, 0x2) -- Set the third flag to true
end
if false then
typeFlags = bit32.bor(typeFlags, 0x1) -- LAST BIT in the ID.
end
return typeFlags
end
local ECS_ID_FLAGS_MASK = 0x10
-- ECS_ENTITY_MASK (0xFFFFFFFFull << 28)
local ECS_ENTITY_MASK = bit32.lshift(1, 24)
-- ECS_GENERATION_MASK (0xFFFFull << 24)
local ECS_GENERATION_MASK = bit32.lshift(1, 16)
local function newId(source: number, target: number)
local e = source * 2^28 + target * ECS_ID_FLAGS_MASK
return e
end
local function isPair(e: number)
return (e % 2^4) // FLAGS_PAIR ~= 0
end
function separate(entity: number)
local _typeFlags = entity % 0x10
entity //= ECS_ID_FLAGS_MASK
return entity // ECS_ENTITY_MASK, entity % ECS_GENERATION_MASK, _typeFlags
end
-- HIGH 24 bits LOW 24 bits
local function ECS_GENERATION(e: i53)
e //= 0x10
return e % ECS_GENERATION_MASK
end
local function ECS_ID(e: i53)
e //= 0x10
return e // ECS_ENTITY_MASK
end
local function ECS_GENERATION_INC(e: i53)
local id, generation, flags = separate(e)
return newId(id, generation + 1) + flags
end
-- gets the high ID
local function ECS_PAIR_FIRST(entity: i53): i24
entity //= 0x10
local first = entity % ECS_ENTITY_MASK
return first
end
-- gets the low ID
local ECS_PAIR_SECOND = ECS_ID
local function ECS_PAIR(source: number, target: number)
local id = newId(ECS_PAIR_SECOND(target), ECS_PAIR_SECOND(source)) + addFlags({ isPair = true })
return id
end
local function getAlive(entityIndex: EntityIndex, id: i53)
return entityIndex.dense[id]
end
local function ecs_get_source(entityIndex, e)
assert(isPair(e))
return getAlive(entityIndex, ECS_PAIR_FIRST(e))
end
local function ecs_get_target(entityIndex, e)
assert(isPair(e))
return getAlive(entityIndex, ECS_PAIR_SECOND(e))
end
local function nextEntityId(entityIndex, index: i24)
local id = newId(index, 0)
entityIndex.sparse[id] = {
dense = index
} :: Record
entityIndex.dense[index] = id
return id
end
function World.component(world: World) function World.component(world: World)
local componentId = world.nextComponentId + 1 local componentId = world.nextComponentId + 1
if componentId > HI_COMPONENT_ID then if componentId > HI_COMPONENT_ID then
@ -402,15 +417,16 @@ local function findArchetypeWith(world: World, node: Archetype, componentId: i53
-- Component IDs are added incrementally, so inserting and sorting -- Component IDs are added incrementally, so inserting and sorting
-- them each time would be expensive. Instead this insertion sort can find the insertion -- them each time would be expensive. Instead this insertion sort can find the insertion
-- point in the types array. -- point in the types array.
local destinationType = table.clone(node.types)
local at = findInsert(types, componentId) local at = findInsert(types, componentId)
if at == -1 then if at == -1 then
-- If it finds a duplicate, it just means it is the same archetype so it can return it -- If it finds a duplicate, it just means it is the same archetype so it can return it
-- directly instead of needing to hash types for a lookup to the archetype. -- directly instead of needing to hash types for a lookup to the archetype.
return node return node
end end
local destinationType = table.clone(node.types)
table.insert(destinationType, at, componentId) table.insert(destinationType, at, componentId)
return ensureArchetype(world, destinationType, node) return ensureArchetype(world, destinationType, node)
end end
@ -425,15 +441,7 @@ local function ensureEdge(archetype: Archetype, componentId: i53)
end end
local function archetypeTraverseAdd(world: World, componentId: i53, from: Archetype): Archetype local function archetypeTraverseAdd(world: World, componentId: i53, from: Archetype): Archetype
if not from then from = from or world.ROOT_ARCHETYPE
-- If there was no source archetype then it should return the ROOT_ARCHETYPE
local ROOT_ARCHETYPE = world.ROOT_ARCHETYPE
if not ROOT_ARCHETYPE then
ROOT_ARCHETYPE = archetypeOf(world, {}, nil)
world.ROOT_ARCHETYPE = ROOT_ARCHETYPE :: never
end
from = ROOT_ARCHETYPE
end
local edge = ensureEdge(from, componentId) local edge = ensureEdge(from, componentId)
local add = edge.add local add = edge.add
@ -659,14 +667,14 @@ function World.query(world: World, ...: i53): Query
function preparedQuery:__iter() function preparedQuery:__iter()
return function() return function()
local archetype = compatibleArchetype[1] local archetype = compatibleArchetype[1]
local row = next(archetype.entities, lastRow) local row: number = next(archetype.entities, lastRow) :: number
while row == nil do while row == nil do
lastArchetype, compatibleArchetype = next(compatibleArchetypes, lastArchetype) lastArchetype, compatibleArchetype = next(compatibleArchetypes, lastArchetype)
if lastArchetype == nil then if lastArchetype == nil then
return return
end end
archetype = compatibleArchetype[1] archetype = compatibleArchetype[1]
row = next(archetype.entities, row) row = next(archetype.entities, row) :: number
end end
lastRow = row lastRow = row
@ -764,15 +772,22 @@ end
return table.freeze({ return table.freeze({
World = World; World = World;
ON_ADD = ON_ADD;
ON_REMOVE = ON_REMOVE; OnAdd = ON_ADD;
ON_SET = ON_SET; OnRemove = ON_REMOVE;
OnSet = ON_SET;
Wildcard = WILDCARD,
w = WILDCARD,
Rest = REST,
ECS_ID = ECS_ID, ECS_ID = ECS_ID,
IS_PAIR = isPair, IS_PAIR = ECS_IS_PAIR,
ECS_PAIR = ECS_PAIR, ECS_PAIR = ECS_PAIR,
ECS_GENERATION = ECS_GENERATION,
ECS_GENERATION_INC = ECS_GENERATION_INC, ECS_GENERATION_INC = ECS_GENERATION_INC,
getAlive = getAlive, ECS_GENERATION = ECS_GENERATION,
ecs_get_target = ecs_get_target, ecs_get_target = ecs_get_target,
ecs_get_source = ecs_get_source ecs_get_source = ecs_get_source,
pair = ECS_PAIR,
getAlive = getAlive,
}) })

View file

@ -309,12 +309,74 @@ return function()
elseif id == eAB then elseif id == eAB then
expect(data[A]).to.be.ok() expect(data[A]).to.be.ok()
expect(data[B]).to.be.ok() expect(data[B]).to.be.ok()
else
error("unknown entity", id)
end end
end end
expect(count).to.equal(3) expect(count).to.equal(5)
end) end)
it("should allow querying for relations", function()
local world = jecs.World.new()
local Eats = world:entity()
local Apples = world:entity()
local bob = world:entity()
world:set(bob, jecs.pair(Eats, Apples), true)
for e, bool in world:query(jecs.pair(Eats, Apples)) do
expect(e).to.equal(bob)
expect(bool).to.equal(bool)
end
end)
it("should allow wildcards in queries", function()
local world = jecs.World.new()
local Eats = world:entity()
local Apples = world:entity()
local bob = world:entity()
world:set(bob, jecs.pair(Eats, Apples), "bob eats apples")
for e, data in world:query(jecs.pair(Eats, jecs.w)) do
expect(e).to.equal(bob)
expect(data).to.equal("bob eats apples")
end
for e, data in world:query(jecs.pair(jecs.w, Apples)) do
expect(e).to.equal(bob)
expect(data).to.equal("bob eats apples")
end
end)
it("should match against multiple pairs", function()
local world = jecs.World.new()
local pair = jecs.pair
local Eats = world:entity()
local Apples = world:entity()
local Oranges =world:entity()
local bob = world:entity()
local alice = world:entity()
world:set(bob, pair(Eats, Apples), "bob eats apples")
world:set(alice, pair(Eats, Oranges), "alice eats oranges")
local w = jecs.Wildcard
local count = 0
for e, data in world:query(pair(Eats, w)) do
count += 1
if e == bob then
expect(data).to.equal("bob eats apples")
else
expect(data).to.equal("alice eats oranges")
end
end
expect(count).to.equal(2)
count = 0
for e, data in world:query(pair(w, Apples)) do
count += 1
expect(data).to.equal("bob eats apples")
end
expect(count).to.equal(1)
end)
end) end)
end end

View file

@ -11,9 +11,6 @@
}, },
"ReplicatedStorage": { "ReplicatedStorage": {
"$className": "ReplicatedStorage", "$className": "ReplicatedStorage",
"DevPackages": {
"$path": "DevPackages"
},
"Lib": { "Lib": {
"$path": "lib" "$path": "lib"
}, },
@ -25,6 +22,9 @@
}, },
"mirror": { "mirror": {
"$path": "mirror" "$path": "mirror"
},
"DevPackages": {
"$path": "DevPackages"
} }
}, },
"TestService": { "TestService": {

View file

@ -7,7 +7,6 @@ local ECS_PAIR = jecs.ECS_PAIR
local getAlive = jecs.getAlive local getAlive = jecs.getAlive
local ecs_get_source = jecs.ecs_get_source local ecs_get_source = jecs.ecs_get_source
local ecs_get_target = jecs.ecs_get_target local ecs_get_target = jecs.ecs_get_target
local REST = 256 + 4
local TEST, CASE, CHECK, FINISH, SKIP = testkit.test() local TEST, CASE, CHECK, FINISH, SKIP = testkit.test()
@ -18,7 +17,6 @@ TEST("world", function()
local world = jecs.World.new() local world = jecs.World.new()
local A = world:component() local A = world:component()
local B = world:component() local B = world:component()
local eA = world:entity() local eA = world:entity()
world:set(eA, A, true) world:set(eA, A, true)
local eB = world:entity() local eB = world:entity()
@ -48,7 +46,6 @@ TEST("world", function()
end end
do CASE "should query all matching entities" do CASE "should query all matching entities"
local world = jecs.World.new() local world = jecs.World.new()
local A = world:component() local A = world:component()
local B = world:component() local B = world:component()
@ -71,7 +68,6 @@ TEST("world", function()
end end
do CASE "should query all matching entities when irrelevant component is removed" do CASE "should query all matching entities when irrelevant component is removed"
local world = jecs.World.new() local world = jecs.World.new()
local A = world:component() local A = world:component()
local B = world:component() local B = world:component()
@ -99,7 +95,6 @@ TEST("world", function()
end end
do CASE "should query all entities without B" do CASE "should query all entities without B"
local world = jecs.World.new() local world = jecs.World.new()
local A = world:component() local A = world:component()
local B = world:component() local B = world:component()
@ -171,29 +166,24 @@ TEST("world", function()
world:remove(id, Poison) world:remove(id, Poison)
CHECK(world:get(id, Poison) == nil) CHECK(world:get(id, Poison) == nil)
print(world:get(id, Health))
CHECK(world:get(id, Health) == 50) CHECK(world:get(id, Health) == 50)
end end
do CASE "should increment generation" do CASE "should increment generation"
local world = jecs.World.new() local world = jecs.World.new()
local e = world:entity() local e = world:entity()
CHECK(ECS_ID(e) == 1 + REST) CHECK(ECS_ID(e) == 1 + jecs.Rest)
CHECK(getAlive(world.entityIndex, ECS_ID(e)) == e) CHECK(getAlive(world.entityIndex, ECS_ID(e)) == e)
CHECK(ECS_GENERATION(e) == 0) -- 0 CHECK(ECS_GENERATION(e) == 0) -- 0
e = ECS_GENERATION_INC(e) e = ECS_GENERATION_INC(e)
CHECK(ECS_GENERATION(e) == 1) -- 1 CHECK(ECS_GENERATION(e) == 1) -- 1
end end
do CASE "relations" do CASE "should get alive from index in the dense array"
local world = jecs.World.new() local world = jecs.World.new()
local _e = world:entity() local _e = world:entity()
local e2 = world:entity() local e2 = world:entity()
local e3 = world:entity() local e3 = world:entity()
CHECK(ECS_ID(e2) == 2 + REST)
CHECK(ECS_ID(e3) == 3 + REST)
CHECK(ECS_GENERATION(e2) == 0)
CHECK(ECS_GENERATION(e3) == 0)
CHECK(IS_PAIR(world:entity()) == false) CHECK(IS_PAIR(world:entity()) == false)
@ -203,6 +193,69 @@ TEST("world", function()
CHECK(ecs_get_target(world.entityIndex, pair) == e3) CHECK(ecs_get_target(world.entityIndex, pair) == e3)
end end
do CASE "should allow querying for relations"
local world = jecs.World.new()
local Eats = world:entity()
local Apples = world:entity()
local bob = world:entity()
world:set(bob, ECS_PAIR(Eats, Apples), true)
for e, bool in world:query(ECS_PAIR(Eats, Apples)) do
CHECK(e == bob)
CHECK(bool)
end
end
do CASE "should allow wildcards in queries"
local world = jecs.World.new()
local Eats = world:entity()
local Apples = world:entity()
local bob = world:entity()
world:set(bob, ECS_PAIR(Eats, Apples), "bob eats apples")
local w = jecs.Wildcard
for e, data in world:query(ECS_PAIR(Eats, w)) do
CHECK(e == bob)
CHECK(data == "bob eats apples")
end
for e, data in world:query(ECS_PAIR(w, Apples)) do
CHECK(e == bob)
CHECK(data == "bob eats apples")
end
end
do CASE "should match against multiple pairs"
local world = jecs.World.new()
local Eats = world:entity()
local Apples = world:entity()
local Oranges =world:entity()
local bob = world:entity()
local alice = world:entity()
world:set(bob, ECS_PAIR(Eats, Apples), "bob eats apples")
world:set(alice, ECS_PAIR(Eats, Oranges), "alice eats oranges")
local w = jecs.Wildcard
local count = 0
for e, data in world:query(ECS_PAIR(Eats, w)) do
count += 1
if e == bob then
CHECK(data == "bob eats apples")
else
CHECK(data == "alice eats oranges")
end
end
CHECK(count == 2)
count = 0
for e, data in world:query(ECS_PAIR(w, Apples)) do
count += 1
CHECK(data == "bob eats apples")
end
CHECK(count == 1)
end
end) end)
FINISH() FINISH()

View file

@ -6,4 +6,5 @@ realm = "shared"
include = ["default.project.json", "lib/**", "lib", "wally.toml", "README.md"] include = ["default.project.json", "lib/**", "lib", "wally.toml", "README.md"]
exclude = ["**"] exclude = ["**"]
[dev-dependencies] [dev-dependencies]
TestEZ = "roblox/testez@0.4.1"