diff --git a/lib/init.lua b/lib/init.lua index 710b86b..1a3ffed 100644 --- a/lib/init.lua +++ b/lib/init.lua @@ -16,18 +16,14 @@ local RunService = game:GetService("RunService") Used to cajole varargs without dropping sparse values. ]] local function pack(...) - local len = select("#", ...) - - return len, { ... } + return select("#", ...), { ... } end --[[ Returns first value (success), and packs all following values. ]] local function packResult(...) - local result = (...) - - return result, pack(select(2, ...)) + return ..., pack(select(2, ...)) end --[[ @@ -185,9 +181,10 @@ end function Promise._newWithSelf(executor, ...) local args local promise = Promise.new(function(...) - args = {...} + args = { ... } end, ...) + -- we don't handle the length here since `args` will always be { resolve, reject, onCancelHook } executor(promise, unpack(args)) return promise @@ -196,7 +193,6 @@ end function Promise._new(traceback, executor, ...) return Promise._newWithSelf(function(self, ...) self._source = traceback - executor(...) end, ...) end @@ -262,7 +258,7 @@ function Promise._all(traceback, promises, amount) -- We need to check that each value is a promise here so that we can produce -- a proper error rather than a rejected promise with our error. - for i, promise in pairs(promises) do + for i, promise in ipairs(promises) do if not Promise.is(promise) then error((ERROR_NON_PROMISE_IN_LIST):format("Promise.all", tostring(i)), 3) end @@ -317,10 +313,10 @@ function Promise._all(traceback, promises, amount) -- We can assume the values inside `promises` are all promises since we -- checked above. - for i = 1, #promises do + for i, promise in ipairs(promises) do table.insert( newPromises, - promises[i]:andThen( + promise:andThen( function(...) resolveOne(i, ...) end, @@ -367,7 +363,7 @@ function Promise.allSettled(promises) -- We need to check that each value is a promise here so that we can produce -- a proper error rather than a rejected promise with our error. - for i, promise in pairs(promises) do + for i, promise in ipairs(promises) do if not Promise.is(promise) then error((ERROR_NON_PROMISE_IN_LIST):format("Promise.allSettled", tostring(i)), 2) end @@ -406,10 +402,10 @@ function Promise.allSettled(promises) -- We can assume the values inside `promises` are all promises since we -- checked above. - for i = 1, #promises do + for i, promise in ipairs(promises) do table.insert( newPromises, - promises[i]:finally( + promise:finally( function(...) resolveOne(i, ...) end @@ -426,7 +422,7 @@ end function Promise.race(promises) assert(type(promises) == "table", ERROR_NON_LIST:format("Promise.race")) - for i, promise in pairs(promises) do + for i, promise in ipairs(promises) do assert(Promise.is(promise), (ERROR_NON_PROMISE_IN_LIST):format("Promise.race", tostring(i))) end @@ -500,58 +496,96 @@ end Creates a Promise that resolves after given number of seconds. ]] do + -- uses a sorted doubly linked list (queue) to achieve O(1) remove operations and O(n) for insert + + -- the initial node in the linked list + local first local connection - local queue = {} - - local function enqueue(callback, seconds) - table.insert(queue, { - callback = callback, - startTime = tick(), - endTime = tick() + math.max(seconds, 1/60) - }) - - table.sort(queue, function(a, b) - return a.endTime < b.endTime - end) - - if not connection then - connection = RunService.Heartbeat:Connect(function() - while #queue > 0 and queue[1].endTime <= tick() do - local item = table.remove(queue, 1) - - item.callback(tick() - item.startTime) - end - - if #queue == 0 then - connection:Disconnect() - connection = nil - end - end) - end - end - - local function dequeue(callback) - for i, item in ipairs(queue) do - if item.callback == callback then - table.remove(queue, i) - break - end - end - end function Promise.delay(seconds) assert(type(seconds) == "number", "Bad argument #1 to Promise.delay, must be a number.") - -- If seconds is -INF, INF, or NaN, assume seconds is 0. + -- If seconds is -INF, INF, NaN, or less than 1 / 60, assume seconds is 1 / 60. -- This mirrors the behavior of wait() - if seconds < 0 or seconds == math.huge or seconds ~= seconds then - seconds = 0 + if not (seconds >= 1 / 60) or seconds == math.huge then + seconds = 1 / 60 end return Promise._new(debug.traceback(), function(resolve, _, onCancel) - enqueue(resolve, seconds) + local startTime = tick() + local endTime = startTime + seconds + + local node = { + resolve = resolve, + startTime = startTime, + endTime = endTime + } + + if connection == nil then -- first is nil when connection is nil + first = node + connection = RunService.Heartbeat:Connect(function() + local currentTime = tick() + + while first.endTime <= currentTime do + first.resolve(currentTime - first.startTime) + first = first.next + if first == nil then + connection:Disconnect() + connection = nil + break + end + first.previous = nil + currentTime = tick() + end + end) + else -- first is non-nil + if first.endTime < endTime then -- if `node` should be placed after `first` + -- we will insert `node` between `current` and `next` + -- (i.e. after `current` if `next` is nil) + local current = first + local next = current.next + + while next ~= nil and next.endTime < endTime do + current = next + next = current.next + end + + -- `current` must be non-nil, but `next` could be `nil` (i.e. last item in list) + current.next = node + node.previous = current + + if next ~= nil then + node.next = next + next.previous = node + end + else + -- set `node` to `first` + node.next = first + first.previous = node + first = node + end + end onCancel(function() - dequeue(resolve) + -- remove node from queue + local next = node.next + + if first == node then + if next == nil then -- if `node` is the first and last + connection:Disconnect() + connection = nil + else -- if `node` is `first` and not the last + next.previous = nil + end + first = next + else + local previous = node.previous + -- since `node` is not `first`, then we know `previous` is non-nil + previous.next = next + + if next ~= nil then + next.previous = previous + end + end end) end) end @@ -862,14 +896,23 @@ function Promise.prototype:awaitStatus() return self._status end +local function awaitHelper(status, ...) + return status == Promise.Status.Resolved, ... +end + --[[ Calls awaitStatus internally, returns (isResolved, values...) ]] function Promise.prototype:await(...) - local length, result = pack(self:awaitStatus(...)) - local status = table.remove(result, 1) + return awaitHelper(self:awaitStatus(...)) +end - return status == Promise.Status.Resolved, unpack(result, 1, length - 1) +local function expectHelper(status, ...) + if status ~= Promise.Status.Resolved then + error((...) == nil and "" or tostring((...)), 3) + end + + return ... end --[[ @@ -877,15 +920,7 @@ end Throws if the Promise rejects or gets cancelled. ]] function Promise.prototype:expect(...) - local length, result = pack(self:awaitStatus(...)) - local status = table.remove(result, 1) - - assert( - status == Promise.Status.Resolved, - tostring(result[1] == nil and "" or result[1]) - ) - - return unpack(result, 1, length - 1) + return expectHelper(self:awaitStatus(...)) end Promise.prototype.awaitValue = Promise.prototype.expect @@ -1036,4 +1071,4 @@ function Promise.prototype:_finalize() end end -return Promise \ No newline at end of file +return Promise