From 6123946e11d9c8665896164499d4d1b1940c18c1 Mon Sep 17 00:00:00 2001 From: Ukendio Date: Tue, 14 Jan 2025 11:09:18 +0100 Subject: [PATCH] Improve cached queries --- benches/query.luau | 2 +- .../ReplicatedStorage/std/changetracker.luau | 139 ----------- examples/luau/queries/changetracking.luau | 231 +++--------------- jecs.luau | 217 +++++++++------- test/tests.luau | 62 +++++ 5 files changed, 218 insertions(+), 433 deletions(-) delete mode 100644 demo/src/ReplicatedStorage/std/changetracker.luau diff --git a/benches/query.luau b/benches/query.luau index bff674c..ecd7fd5 100644 --- a/benches/query.luau +++ b/benches/query.luau @@ -9,7 +9,7 @@ local function TITLE(title: string) end local jecs = require("@jecs") -local mirror = require("../mirror/init") +local mirror = require("@mirror") type i53 = number diff --git a/demo/src/ReplicatedStorage/std/changetracker.luau b/demo/src/ReplicatedStorage/std/changetracker.luau deleted file mode 100644 index 078baa7..0000000 --- a/demo/src/ReplicatedStorage/std/changetracker.luau +++ /dev/null @@ -1,139 +0,0 @@ -local jecs = require(game:GetService("ReplicatedStorage").ecs) -type World = jecs.World - -type Tracker = { - track: ( - world: World, - fn: ( - changes: { - added: () -> () -> (number, T), - removed: () -> () -> number, - changed: () -> () -> (number, T, T), - } - ) -> () - ) -> (), -} - -type Entity = number & { __nominal_type_dont_use: T } - -local function diff(a, b) - local size = 0 - for k, v in a do - if b[k] ~= v then - return true - end - size += 1 - end - for k, v in b do - size -= 1 - end - - if size ~= 0 then - return true - end - - return false -end - -local function ChangeTracker(world: World, T: Entity): Tracker - local sparse = world.entityIndex.sparse - local PreviousT = jecs.pair(jecs.Rest, T) - local add = {} - local added - local removed - local is_trivial - - local function changes_added() - added = true - local q = world:query(T):without(PreviousT):drain() - return function() - local id, data = q.next() - if not id then - return nil - end - - is_trivial = typeof(data) ~= "table" - - add[id] = data - - return id, data - end - end - - local function changes_changed() - local q = world:query(T, PreviousT):drain() - - return function() - local id, new, old = q.next() - while true do - if not id then - return nil - end - - if not is_trivial then - if diff(new, old) then - break - end - elseif new ~= old then - break - end - - id, new, old = q.next() - end - - local record = sparse[id] - local archetype = record.archetype - local column = archetype.records[PreviousT].column - local data = if is_trivial then new else table.clone(new) - archetype.columns[column][record.row] = data - - return id, old, new - end - end - - local function changes_removed() - removed = true - - local q = world:query(PreviousT):without(T):drain() - return function() - local id = q.next() - if id then - world:remove(id, PreviousT) - end - return id - end - end - - local changes = { - added = changes_added, - changed = changes_changed, - removed = changes_removed, - } - - local function track(fn) - added = false - removed = false - - fn(changes) - - if not added then - for _ in changes_added() do - end - end - - if not removed then - for _ in changes_removed() do - end - end - - for e, data in add do - world:set(e, PreviousT, if is_trivial then data else table.clone(data)) - end - end - - local tracker = { track = track } - - return tracker -end - -return ChangeTracker diff --git a/examples/luau/queries/changetracking.luau b/examples/luau/queries/changetracking.luau index ed1de8f..bcaa032 100644 --- a/examples/luau/queries/changetracking.luau +++ b/examples/luau/queries/changetracking.luau @@ -1,164 +1,5 @@ local jecs = require("@jecs") -type World = jecs.WorldShim - -type Tracker = { - track: ( - world: World, - fn: ( - changes: { - added: () -> () -> (number, T), - removed: () -> () -> number, - changed: () -> () -> (number, T, T), - } - ) -> () - ) -> (), -} - -local function diff(a, b) - local size = 0 - for k, v in a do - if b[k] ~= v then - return true - end - size += 1 - end - for k, v in b do - size -= 1 - end - - if size ~= 0 then - return true - end - - return false -end - -type Entity = number & { __nominal_type_dont_use: T } - -local function ChangeTracker(world, T: Entity): Tracker - local PreviousT = jecs.pair(jecs.Rest, T) - local add = {} - local added - local removed - local is_trivial - - local function changes_added() - added = true - local q = world:query(T):without(PreviousT):drain() - return function() - local id, data = q.next() - if not id then - return nil - end - - is_trivial = typeof(data) ~= "table" - - add[id] = data - - return id, data - end - end - - local function changes_changed() - local q = world:query(T, PreviousT):drain() - - return function() - local id, new, old = q.next() - while true do - if not id then - return nil - end - - if not is_trivial then - if diff(new, old) then - break - end - elseif new ~= old then - break - end - - id, new, old = q.next() - end - - add[id] = new - - return id, old, new - end - end - - local function changes_removed() - removed = true - - local q = world:query(PreviousT):without(T):drain() - return function() - local id = q.next() - if id then - world:remove(id, PreviousT) - end - return id - end - end - - local changes = { - added = changes_added, - changed = changes_changed, - removed = changes_removed, - } - - local function track(fn) - added = false - removed = false - - fn(changes) - - if not added then - for _ in changes_added() do - end - end - - if not removed then - for _ in changes_removed() do - end - end - - for e, data in add do - world:set(e, PreviousT, if is_trivial then data else table.clone(data)) - end - end - - local tracker = { track = track } - - return tracker -end - -local Vector3 -do - Vector3 = {} - Vector3.__index = Vector3 - - function Vector3.new(x, y, z) - x = x or 0 - y = y or 0 - z = z or 0 - return setmetatable({ X = x, Y = y, Z = z }, Vector3) - end - - function Vector3.__add(left, right) - return Vector3.new(left.X + right.X, left.Y + right.Y, left.Z + right.Z) - end - - function Vector3.__mul(left, right) - if typeof(right) == "number" then - return Vector3.new(left.X * right, left.Y * right, left.Z * right) - end - return Vector3.new(left.X * right.X, left.Y * right.Y, left.Z * right.Z) - end - - Vector3.one = Vector3.new(1, 1, 1) - Vector3.zero = Vector3.new() -end - local world = jecs.World.new() local Name = world:component() @@ -171,10 +12,21 @@ local function name(e) return world:get(e, Name) end -local Position = named(world.component, "Position") +local Position = named(world.component, "Position") :: jecs.Entity +local Previous = jecs.Rest +local PreviousPosition = jecs.pair(Previous, Position) --- Create the ChangeTracker with the component type to track -local PositionTracker = ChangeTracker(world, Position) +local added = world + :query(Position) + :without(PreviousPosition) + :cached() +local changed = world + :query(Position, PreviousPosition) + :cached() +local removed = world + :query(PreviousPosition) + :without(Position) + :cached() local e1 = named(world.entity, "e1") world:set(e1, Position, Vector3.new(10, 20, 30)) @@ -182,52 +34,25 @@ world:set(e1, Position, Vector3.new(10, 20, 30)) local e2 = named(world.entity, "e2") world:set(e2, Position, Vector3.new(10, 20, 30)) -PositionTracker.track(function(changes) - -- You can iterate over different types of changes: Added, Changed, Removed +for e, p in added:iter() do + print(`Added {e}: \{{p.X}, {p.Y}, {p.Z}}`) + world:set(e, PreviousPosition, p) +end - -- added queries for every entity with a new Position component - for e, p in changes.added() do - print(`Added {e}: \{{p.X}, {p.Y}, {p.Z}}`) +world:set(e1, Position, "") + +for e, new, old in changed:iter() do + if new ~= old then + print(`{name(new)}'s Position changed from \{{old.X}, {old.Y}, {old.Z}\} to \{{new.X}, {new.Y}, {new.Z}\}`) + world:set(e, PreviousPosition, new) end - - -- changed queries for entities who's changed their data since - -- last was it tracked - for _ in changes.changed() do - print([[This won't print because it is the first time - we are tracking the Position component]]) - end - - -- removed queries for entities who's removed their Position component - -- since last it was tracked - for _ in changes.removed() do - print([[This won't print because it is the first time - we are tracking the Position component]]) - end -end) - -world:set(e1, Position, Vector3.new(1, 1, 2) * 999) - -PositionTracker.track(function(changes) - for e, p in changes.added() do - print([[This won't never print no Position component was added - since last time we tracked]]) - end - - for e, old, new in changes.changed() do - print(`{name(e)}'s Position changed from \{{old.X}, {old.Y}, {old.Z}\} to \{{new.X}, {new.Y}, {new.Z}\}`) - end - - -- If you don't call e.g. changes.removed() then it will automatically drain its iterator and stage their changes. - -- This ensures you will not have any off-by-one frame errors. -end) +end world:remove(e2, Position) -PositionTracker.track(function(changes) - for e in changes.removed() do - print(`Position was removed from {name(e)}`) - end -end) +for e in removed:iter() do + print(`Position was removed from {name(e)}`) +end -- Output: -- Added 265: {10, 20, 30} diff --git a/jecs.luau b/jecs.luau index 473b620..08db524 100644 --- a/jecs.luau +++ b/jecs.luau @@ -222,7 +222,7 @@ local function entity_index_is_alive(entity_index: EntityIndex, entity: number) return entity_index_try_get(entity_index, entity) ~= nil end -local function entity_index_new_id(entity_index: EntityIndex, data): i53 +local function entity_index_new_id(entity_index: EntityIndex): i53 local dense_array = entity_index.dense_array local alive_count = entity_index.alive_count if alive_count ~= #dense_array then @@ -1116,7 +1116,7 @@ do local delete = entity local component_index = world.componentIndex - local archetypes = world.archetypes + local archetypes: Archetypes = world.archetypes local tgt = ECS_PAIR(EcsWildcard, delete) local idr_t = component_index[tgt] local idr = component_index[delete] @@ -1232,6 +1232,15 @@ local EMPTY_QUERY = { setmetatable(EMPTY_QUERY, EMPTY_QUERY) +type QueryInner = { + compatible_archetypes: { Archetype }, + ids: { i53 }, + filter_with: { i53 }, + filter_without: { i53 }, + next: () -> (number, ...any), + world: World, +} + local function query_iter_init(query: QueryInner): () -> (number, ...any) local world_query_iter_next @@ -1309,6 +1318,9 @@ local function query_iter_init(query: QueryInner): () -> (number, ...any) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1332,6 +1344,9 @@ local function query_iter_init(query: QueryInner): () -> (number, ...any) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1356,6 +1371,9 @@ local function query_iter_init(query: QueryInner): () -> (number, ...any) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1381,6 +1399,9 @@ local function query_iter_init(query: QueryInner): () -> (number, ...any) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1408,6 +1429,9 @@ local function query_iter_init(query: QueryInner): () -> (number, ...any) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1547,45 +1571,6 @@ local function query_archetypes(query) end local function query_cached(query: QueryInner) - local archetypes = query.compatible_archetypes - local world = query.world :: World - -- Only need one observer for EcsArchetypeCreate and EcsArchetypeDelete respectively - -- because the event will be emitted for all components of that Archetype. - local first = query.ids[1] - local observerable = world.observerable - local on_create_action = observerable[EcsOnArchetypeCreate] - if not on_create_action then - on_create_action = {} - observerable[EcsOnArchetypeCreate] = on_create_action - end - local query_cache_on_create = on_create_action[first] - if not query_cache_on_create then - query_cache_on_create = {} - on_create_action[first] = query_cache_on_create - end - - local on_delete_action = observerable[EcsOnArchetypeDelete] - if not on_delete_action then - on_delete_action = {} - observerable[EcsOnArchetypeDelete] = on_delete_action - end - local query_cache_on_delete = on_delete_action[first] - if not query_cache_on_delete then - query_cache_on_delete = {} - on_delete_action[first] = query_cache_on_delete - end - - local function on_create_callback(archetype) - table.insert(archetypes, archetype) - end - - local function on_delete_callback(archetype) - local i = table.find(archetypes, archetype) :: number - local n = #archetypes - archetypes[i] = archetypes[n] - archetypes[n] = nil - end - local with = query.filter_with local ids = query.ids if with then @@ -1594,12 +1579,6 @@ local function query_cached(query: QueryInner) query.filter_with = ids end - local observer_for_create = { query = query, callback = on_create_callback } - local observer_for_delete = { query = query, callback = on_delete_callback } - - table.insert(query_cache_on_create, observer_for_create) - table.insert(query_cache_on_delete, observer_for_delete) - local compatible_archetypes = query.compatible_archetypes local lastArchetype = 1 @@ -1613,6 +1592,50 @@ local function query_cached(query: QueryInner) local i: number local archetype: Archetype local records: { ArchetypeRecord } + local archetypes = query.compatible_archetypes + + local world = query.world :: World + -- Only need one observer for EcsArchetypeCreate and EcsArchetypeDelete respectively + -- because the event will be emitted for all components of that Archetype. + local observerable = world.observerable + local on_create_action = observerable[EcsOnArchetypeCreate] + if not on_create_action then + on_create_action = {} + observerable[EcsOnArchetypeCreate] = on_create_action + end + local query_cache_on_create = on_create_action[A] + if not query_cache_on_create then + query_cache_on_create = {} + on_create_action[A] = query_cache_on_create + end + + local on_delete_action = observerable[EcsOnArchetypeDelete] + if not on_delete_action then + on_delete_action = {} + observerable[EcsOnArchetypeDelete] = on_delete_action + end + local query_cache_on_delete = on_delete_action[A] + if not query_cache_on_delete then + query_cache_on_delete = {} + on_delete_action[A] = query_cache_on_delete + end + + local function on_create_callback(archetype) + table.insert(archetypes, archetype) + end + + local function on_delete_callback(archetype) + local i = table.find(archetypes, archetype) :: number + local n = #archetypes + archetypes[i] = archetypes[n] + archetypes[n] = nil + end + + local observer_for_create = { query = query, callback = on_create_callback } + local observer_for_delete = { query = query, callback = on_delete_callback } + + table.insert(query_cache_on_create, observer_for_create) + table.insert(query_cache_on_delete, observer_for_delete) local function cached_query_iter() lastArchetype = 1 @@ -1685,6 +1708,9 @@ local function query_cached(query: QueryInner) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1708,6 +1734,9 @@ local function query_cached(query: QueryInner) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1732,9 +1761,12 @@ local function query_cached(query: QueryInner) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns - records = archetype.records + records = archetype.records a = columns[records[A].column] b = columns[records[B].column] c = columns[records[C].column] @@ -1757,6 +1789,9 @@ local function query_cached(query: QueryInner) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1784,6 +1819,9 @@ local function query_cached(query: QueryInner) entities = archetype.entities i = #entities + if i == 0 then + continue + end entityId = entities[i] columns = archetype.columns records = archetype.records @@ -1939,7 +1977,7 @@ local function world_each(world: World, id): () -> () return function(): any local entity = entities[row] while not entity do - archetype_id = next(idr_cache, archetype_id) + archetype_id = next(idr_cache, archetype_id) :: number if not archetype_id then return end @@ -2141,35 +2179,37 @@ function World.new() return self end -type Id = - | (number & { __jecs_pair_value: T }) - | (number & { __T: T }) +export type Id = + | Entity + | Pair, Entity> + | Pair> + | Pair, Entity> -export type Pair

= number & { - __jecs_pair_value: ecs_id_t> +export type Pair = number & { + __P: P, + __O: O, } -type function ecs_id_t(entity) - local ty = entity:components()[2] - local __T = ty:readproperty(types.singleton("__T")) - if not __T then - return ty:readproperty(types.singleton("__jecs_pair_value")) - end - return __T -end +-- type function ecs_id_t(entity) +-- local ty = entity:components()[2] +-- local __T = ty:readproperty(types.singleton("__T")) +-- if not __T then +-- return ty:readproperty(types.singleton("__jecs_pair_value")) +-- end +-- return __T +-- end -type function ecs_pair_t(first, second) - local ty = first:components()[2] - if ty:readproperty(types.singleton("__T")):is("nil") then - return second - else - return first - end -end +-- type function ecs_pair_t(first, second) +-- if ecs_id_t(first):is("nil") then +-- return second +-- else +-- return first +-- end +-- end type Item = (self: Query) -> (Entity, T...) -export type Entity = number & { __T: T } +export type Entity = number & { __T: T } type Iter = (query: Query) -> () -> (Entity, T...) @@ -2183,15 +2223,6 @@ export type Query = typeof(setmetatable({}, { cached: (self: Query) -> Query, } -type QueryInner = { - compatible_archetypes: { Archetype }, - filter_with: { i53 }?, - filter_without: { i53 }?, - ids: { i53 }, - world: {}, -- Downcasted to be serializable by the analyzer - next: () -> Item -} - type Observer = { callback: (archetype: Archetype) -> (), query: QueryInner, @@ -2208,7 +2239,13 @@ export type World = { nextEntityId: number, nextArchetypeId: number, - observerable: { [i53]: { [i53]: { { query: Query } } } }, + observerable: { + [i53]: { + [i53]: { + { query: QueryInner, callback: (Archetype) -> () } + } + } + }, } & { --- Creates a new entity entity: (self: World) -> Entity, @@ -2251,18 +2288,18 @@ export type World = { children: (self: World, id: Id) -> () -> Entity, --- Searches the world for entities that match a given query - query: ((World, A) -> Query>) - & ((World, A, B) -> Query, ecs_id_t>) - & ((World, A, B, C) -> Query, ecs_id_t, ecs_id_t>) - & ((World, A, B, C, D) -> Query, ecs_id_t, ecs_id_t, ecs_id_t>) - & ((World, A, B, C, D, E) -> Query, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t>) - & ((World, A, B, C, D, E, F) -> Query, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t>) - & ((World, A, B, C, D, E, F, G) -> Query, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t>) - & ((World, A, B, C, D, E, F, G, H) -> Query, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t, ecs_id_t>) + query: ((World, Id) -> Query) + & ((World, Id, Id) -> Query) + & ((World, Id, Id, Id) -> Query) + & ((World, Id, Id, Id, Id) -> Query) + & ((World, Id, Id, Id, Id, Id) -> Query) + & ((World, Id, Id, Id, Id, Id, Id) -> Query) + & ((World, Id, Id, Id, Id, Id, Id, Id) -> Query) + & ((World, Id, Id, Id, Id, Id, Id, Id, Id, ...Id) -> Query) } return { - World = World :: { new: () -> World }, + World = World, OnAdd = EcsOnAdd :: Entity<(entity: Entity) -> ()>, OnRemove = EcsOnRemove :: Entity<(entity: Entity) -> ()>, diff --git a/test/tests.luau b/test/tests.luau index 8da151f..16188cb 100644 --- a/test/tests.luau +++ b/test/tests.luau @@ -374,9 +374,12 @@ TEST("world:query()", function() world:set(e, Bar, false) local i = 0 + local iter = 0 for _, e in q:iter() do + iter += 1 i=1 end + CHECK (iter == 1) CHECK(i == 1) for _, e in q:iter() do i=2 @@ -1364,6 +1367,65 @@ TEST("Hooks", function() end end) +TEST("change tracking", function() + CASE "#1" do + local world = world_new() + local Foo = world:component() + local Previous = jecs.Rest + + local q1 = world + :query(Foo) + :without(pair(Previous, Foo)) + :cached() + + local e1 = world:entity() + world:set(e1, Foo, 1) + local e2 = world:entity() + world:set(e2, Foo, 2) + + local i = 0 + for e, new in q1 do + i += 1 + world:set(e, pair(Previous, Foo), new) + end + + CHECK(i == 2) + local j = 0 + for e, new in q1 do + j += 1 + world:set(e, pair(Previous, Foo), new) + end + + CHECK(j == 0) + end + + CASE "#2" do + local world = world_new() + local component = world:component() + local tag = world:entity() + local previous = jecs.Rest + + local q1 = world:query(component):without(pair(previous, component), tag):cached() + + local testEntity = world:entity() + + world:set(testEntity, component, 10) + + local i = 0 + for entity, number in q1 do + i += 1 + world:add(testEntity, tag) + end + + CHECK(i == 1) + + for e, n in q1 do + world:set(e, pair(previous, component), n) + end + end + +end) + TEST("repro", function() do CASE "#1" local world = world_new()