diff --git a/benches/visual/query.bench.luau b/benches/visual/query.bench.luau index 37df156..97e27c9 100644 --- a/benches/visual/query.bench.luau +++ b/benches/visual/query.bench.luau @@ -57,7 +57,7 @@ local function flip() end local common = 0 -local N = 500 +local N = 2^16-2 local archetypes = {} local hm = 0 @@ -171,18 +171,14 @@ return { Functions = { Mirror = function() - for i = 1, 1000 do - for entityId, firstComponent in mcs:query(E1, E4) do - end + for entityId, firstComponent in mcs:query(E1, E2, E3, E4) do end end, Jecs = function() - for i = 1, 1000 do - for entityId, firstComponent in ecs:query(D1, D4) do - end - end + for entityId, firstComponent in ecs:query(D1, D2, D3, D4) do + end end, }, } diff --git a/src/init.luau b/src/init.luau index f5c6c86..6716b58 100644 --- a/src/init.luau +++ b/src/init.luau @@ -720,15 +720,18 @@ do return nil :: any end + local Arm = function(self: Query, ...) + return self + end local EmptyQuery: Query = { __iter = function(): Item return noop end, + drain = Arm, next = noop :: Item, replace = noop :: (Query, ...any) -> (), - without = function(self: Query, ...) - return self - end + with = Arm, + without = Arm, } setmetatable(EmptyQuery, EmptyQuery) @@ -741,11 +744,75 @@ do local i: number local compatible_archetypes: { Archetype } - local column_indices: { { number} } local ids: { number } - local tr local columns + local A, B, C, D, E, F, G, H + local a, b, c, d, e, f, g, h + + local function query_init(query) + lastArchetype = 1 + archetype = compatible_archetypes[lastArchetype] + + if not archetype then + return + end + + queryOutput = {} + queryLength = #ids + + entities = archetype.entities + i = #entities + columns = archetype.columns + + local records = archetype.records + if not B then + a = records[A] + elseif not C then + a = records[A] + b = records[B] + elseif not D then + a = records[A] + b = records[B] + c = records[C] + elseif not E then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + elseif not F then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + elseif not G then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + f = records[F] + elseif not H then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + f = records[F] + g = records[G] + elseif H then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + f = records[F] + g = records[G] + h = records[H] + end + end + local function world_query_next(): any local entityId = entities[i] while entityId == nil do @@ -754,7 +821,52 @@ do if not archetype then return nil end - tr = column_indices[lastArchetype] + local records = archetype.records + if not B then + a = records[A] + elseif not C then + a = records[A] + b = records[B] + elseif not D then + a = records[A] + b = records[B] + c = records[C] + elseif not E then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + elseif not F then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + elseif not G then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + f = records[F] + elseif not H then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + f = records[F] + g = records[G] + elseif H then + a = records[A] + b = records[B] + c = records[C] + d = records[D] + e = records[E] + f = records[F] + g = records[G] + h = records[H] + end columns = archetype.columns entities = archetype.entities i = #entities @@ -765,60 +877,57 @@ do i-=1 if queryLength == 1 then - return entityId, columns[tr[1]][row] + return entityId, columns[a][row] elseif queryLength == 2 then - return entityId, columns[tr[1]][row], columns[tr[2]][row] + return entityId, columns[a][row], columns[b][row] elseif queryLength == 3 then - return entityId, columns[tr[1]][row], columns[tr[2]][row], columns[tr[3]][row] + return entityId, columns[a][row], columns[b][row], columns[c][row] elseif queryLength == 4 then - return entityId, columns[tr[1]][row], columns[tr[2]][row], columns[tr[3]][row], columns[tr[4]][row] + return entityId, columns[a][row], columns[b][row], columns[c][row], columns[d][row] elseif queryLength == 5 then return entityId, - columns[tr[1]][row], - columns[tr[2]][row], - columns[tr[3]][row], - columns[tr[4]][row], - columns[tr[5]][row] + columns[a][row], + columns[b][row], + columns[c][row], + columns[d][row], + columns[e][row] elseif queryLength == 6 then return entityId, - columns[tr[1]][row], - columns[tr[2]][row], - columns[tr[3]][row], - columns[tr[4]][row], - columns[tr[5]][row], - columns[tr[6]][row] + columns[a][row], + columns[b][row], + columns[c][row], + columns[d][row], + columns[e][row], + columns[f][row] elseif queryLength == 7 then return entityId, - columns[tr[1]][row], - columns[tr[2]][row], - columns[tr[3]][row], - columns[tr[4]][row], - columns[tr[5]][row], - columns[tr[6]][row], - columns[tr[7]][row] + columns[a][row], + columns[b][row], + columns[c][row], + columns[d][row], + columns[e][row], + columns[f][row], + columns[g][row] elseif queryLength == 8 then return entityId, - columns[tr[1]][row], - columns[tr[2]][row], - columns[tr[3]][row], - columns[tr[4]][row], - columns[tr[5]][row], - columns[tr[6]][row], - columns[tr[7]][row], - columns[tr[8]][row] + columns[a][row], + columns[b][row], + columns[c][row], + columns[d][row], + columns[e][row], + columns[f][row], + columns[g][row], + columns[h][row] end - for j in ids do - queryOutput[j] = columns[tr[j]][row] - end + local field = archetype.records + for j, id in ids do + queryOutput[j] = columns[field[id]][row] + end return entityId, unpack(queryOutput, 1, queryLength) end - local function world_query_iter() - return world_query_next - end - local function world_query_without(self, ...) local withoutComponents = { ... } for i = #compatible_archetypes, 1, -1 do @@ -837,23 +946,11 @@ do local last = #compatible_archetypes if last ~= i then compatible_archetypes[i] = compatible_archetypes[last] - column_indices[i] = column_indices[last] end compatible_archetypes[last] = nil - column_indices[last] = nil end end - archetype = compatible_archetypes[lastArchetype] - if not archetype then - return EmptyQuery - end - - entities = archetype.entities - columns = archetype.columns - tr = column_indices[lastArchetype] - i = #entities - return self end @@ -863,40 +960,42 @@ do end end - local function world_query_replace(_, fn: (...any) -> (...any)) - for i, archetype in compatible_archetypes do - local tr = column_indices[i] - local columns = archetype.columns + local function world_query_replace(query, fn: (...any) -> (...any)) + query_init(query) + for i, archetype in compatible_archetypes do + local columns = archetype.columns + local tr = archetype.records for row in archetype.entities do if queryLength == 1 then - local a = columns[tr[1]] - local pa = fn(a[row]) + local va = columns[tr[a]] + local pa = fn(va[row]) - a[row] = pa + va[row] = pa elseif queryLength == 2 then - local a = columns[tr[1]] - local b = columns[tr[2]] + local va = columns[tr[a]] + local vb = columns[tr[b]] - a[row], b[row] = fn(a[row], b[row]) + va[row], vb[row] = fn(va[row], vb[row]) elseif queryLength == 3 then - local a = columns[tr[1]] - local b = columns[tr[2]] - local c = columns[tr[3]] + local va = columns[tr[a]] + local vb = columns[tr[b]] + local vc = columns[tr[c]] - a[row], b[row], c[row] = fn(a[row], b[row], c[row]) + va[row], vb[row], vc[row] = fn(va[row], vb[row], vc[row]) elseif queryLength == 4 then - local a = columns[tr[1]] - local b = columns[tr[2]] - local c = columns[tr[3]] - local d = columns[tr[4]] + local a = columns[tr[a]] + local b = columns[tr[b]] + local c = columns[tr[c]] + local d = columns[tr[d]] a[row], b[row], c[row], d[row] = fn( a[row], b[row], c[row], d[row]) else - for i = 1, queryLength do - queryOutput[i] = columns[tr[i]][row] - end + local field = archetype.records + for j, id in ids do + queryOutput[j] = columns[field[id]][row] + end world_query_replace_values(row, columns, fn(unpack(queryOutput))) end @@ -922,23 +1021,11 @@ do local last = #compatible_archetypes if last ~= i then compatible_archetypes[i] = compatible_archetypes[last] - column_indices[i] = column_indices[last] end compatible_archetypes[last] = nil - column_indices[last] = nil end end - archetype = compatible_archetypes[lastArchetype] - if not archetype then - return EmptyQuery - end - - entities = archetype.entities - columns = archetype.columns - tr = column_indices[lastArchetype] - i = #entities - return query end @@ -949,8 +1036,24 @@ do return compatible_archetypes end + local drain + + local function world_query_drain(query) + drain = true + query_init(query) + return query + end + + local function world_query_iter(query) + if not drain then + query_init(query) + end + return world_query_next + end + local it = { __iter = world_query_iter, + drain = world_query_drain, next = world_query_next, with = world_query_with, without = world_query_without, @@ -966,11 +1069,11 @@ do error("Missing components") end - local indices = {} compatible_archetypes = {} local length = 0 local components = { ... } :: any + A, B, C, D, E, F, G, H = ... local archetypes = world.archetypes local firstArchetypeMap: ArchetypeMap @@ -991,7 +1094,6 @@ do local compatibleArchetype = archetypes[id] local archetypeRecords = compatibleArchetype.records - local records = {} local skip = false for i, componentId in components do @@ -1000,8 +1102,6 @@ do skip = true break end - -- index should be index.offset - records[i] = index end if skip then @@ -1010,27 +1110,11 @@ do length += 1 compatible_archetypes[length] = compatibleArchetype - indices[length] = records end - column_indices = indices + drain = false ids = components - lastArchetype = 1 - archetype = compatible_archetypes[lastArchetype] - - if not archetype then - return EmptyQuery - end - - queryOutput = {} - queryLength = #ids - - entities = archetype.entities - i = #entities - tr = column_indices[lastArchetype] - columns = archetype.columns - return it end end diff --git a/test/tests.luau b/test/tests.luau index 2d1293a..c931b9d 100644 --- a/test/tests.luau +++ b/test/tests.luau @@ -246,7 +246,8 @@ TEST("world:query()", function() world:set(eAB, A, true) world:set(eAB, B, true) - local q = world:query(A) + -- Should drain the iterator + local q = world:query(A):drain() local i = 0 local j = 0 @@ -740,7 +741,7 @@ do local function changes_added() added = true - local q = world:query(T):without(PreviousT) + local q = world:query(T):without(PreviousT):drain() return function() local id, data = q:next() if not id then @@ -758,7 +759,7 @@ do end local function changes_changed() - local q = world:query(T, PreviousT) + local q = world:query(T, PreviousT):drain() return function() local id, new, old = q:next() @@ -785,7 +786,7 @@ do local function changes_removed() removed = true - local q = world:query(PreviousT):without(T) + local q = world:query(PreviousT):without(T):drain() return function() local id = q:next() if id then