Support callable tables where functions are allowed

Closes #64
This commit is contained in:
eryn L. K 2021-12-27 21:43:37 -05:00
parent e7033308ec
commit 4e04458816
2 changed files with 169 additions and 141 deletions

View file

@ -7,6 +7,21 @@ local ERROR_NON_LIST = "Please pass a list of promises to %s"
local ERROR_NON_FUNCTION = "Please pass a handler function to %s!"
local MODE_KEY_METATABLE = { __mode = "k" }
local function isCallable(value)
if type(value) == "function" then
return true
end
if type(value) == "table" then
local metatable = getmetatable(value)
if metatable and type(rawget(metatable, "__call")) == "function" then
return true
end
end
return false
end
--[[
Creates an enum dictionary with some metamethods to prevent common mistakes.
]]
@ -131,7 +146,7 @@ local function packResult(success, ...)
end
local function makeErrorHandler(traceback)
assert(traceback ~= nil)
assert(traceback ~= nil, "traceback is nil")
return function(err)
-- If the error object is already a table, forward it directly.
@ -592,7 +607,7 @@ end
]=]
function Promise.fold(list, reducer, initialValue)
assert(type(list) == "table", "Bad argument #1 to Promise.fold: must be a table")
assert(type(reducer) == "function", "Bad argument #2 to Promise.fold: must be a function")
assert(isCallable(reducer), "Bad argument #2 to Promise.fold: must be a function")
local accumulator = Promise.resolve(initialValue)
return Promise.each(list, function(resolvedElement, i)
@ -842,7 +857,7 @@ end
]=]
function Promise.each(list, predicate)
assert(type(list) == "table", string.format(ERROR_NON_LIST, "Promise.each"))
assert(type(predicate) == "function", string.format(ERROR_NON_FUNCTION, "Promise.each"))
assert(isCallable(predicate), string.format(ERROR_NON_FUNCTION, "Promise.each"))
return Promise._new(debug.traceback(nil, 2), function(resolve, reject, onCancel)
local results = {}
@ -951,11 +966,11 @@ function Promise.is(object)
return true
elseif objectMetatable == nil then
-- No metatable, but we should still chain onto tables with andThen methods
return type(object.andThen) == "function"
return isCallable(object.andThen)
elseif
type(objectMetatable) == "table"
and type(rawget(objectMetatable, "__index")) == "table"
and type(rawget(rawget(objectMetatable, "__index"), "andThen")) == "function"
and isCallable(rawget(rawget(objectMetatable, "__index"), "andThen"))
then
-- Maybe this came from a different or older Promise library.
return true
@ -1235,14 +1250,8 @@ end
@return Promise<...any>
]=]
function Promise.prototype:andThen(successHandler, failureHandler)
assert(
successHandler == nil or type(successHandler) == "function",
string.format(ERROR_NON_FUNCTION, "Promise:andThen")
)
assert(
failureHandler == nil or type(failureHandler) == "function",
string.format(ERROR_NON_FUNCTION, "Promise:andThen")
)
assert(successHandler == nil or isCallable(successHandler), string.format(ERROR_NON_FUNCTION, "Promise:andThen"))
assert(failureHandler == nil or isCallable(failureHandler), string.format(ERROR_NON_FUNCTION, "Promise:andThen"))
return self:_andThen(debug.traceback(nil, 2), successHandler, failureHandler)
end
@ -1261,10 +1270,7 @@ end
@return Promise<...any>
]=]
function Promise.prototype:catch(failureHandler)
assert(
failureHandler == nil or type(failureHandler) == "function",
string.format(ERROR_NON_FUNCTION, "Promise:catch")
)
assert(failureHandler == nil or isCallable(failureHandler), string.format(ERROR_NON_FUNCTION, "Promise:catch"))
return self:_andThen(debug.traceback(nil, 2), nil, failureHandler)
end
@ -1285,7 +1291,7 @@ end
@return Promise<...any>
]=]
function Promise.prototype:tap(tapHandler)
assert(type(tapHandler) == "function", string.format(ERROR_NON_FUNCTION, "Promise:tap"))
assert(isCallable(tapHandler), string.format(ERROR_NON_FUNCTION, "Promise:tap"))
return self:_andThen(debug.traceback(nil, 2), function(...)
local callbackReturn = tapHandler(...)
@ -1320,7 +1326,7 @@ end
@return Promise
]=]
function Promise.prototype:andThenCall(callback, ...)
assert(type(callback) == "function", string.format(ERROR_NON_FUNCTION, "Promise:andThenCall"))
assert(isCallable(callback), string.format(ERROR_NON_FUNCTION, "Promise:andThenCall"))
local length, values = pack(...)
return self:_andThen(debug.traceback(nil, 2), function()
return callback(unpack(values, 1, length))
@ -1474,10 +1480,7 @@ end
@return Promise<...any>
]=]
function Promise.prototype:finally(finallyHandler)
assert(
finallyHandler == nil or type(finallyHandler) == "function",
string.format(ERROR_NON_FUNCTION, "Promise:finally")
)
assert(finallyHandler == nil or isCallable(finallyHandler), string.format(ERROR_NON_FUNCTION, "Promise:finally"))
return self:_finally(debug.traceback(nil, 2), finallyHandler)
end
@ -1491,7 +1494,7 @@ end
@return Promise
]=]
function Promise.prototype:finallyCall(callback, ...)
assert(type(callback) == "function", string.format(ERROR_NON_FUNCTION, "Promise:finallyCall"))
assert(isCallable(callback), string.format(ERROR_NON_FUNCTION, "Promise:finallyCall"))
local length, values = pack(...)
return self:_finally(debug.traceback(nil, 2), function()
return callback(unpack(values, 1, length))
@ -1540,7 +1543,7 @@ end
@return Promise<...any>
]=]
function Promise.prototype:done(doneHandler)
assert(doneHandler == nil or type(doneHandler) == "function", string.format(ERROR_NON_FUNCTION, "Promise:done"))
assert(doneHandler == nil or isCallable(doneHandler), string.format(ERROR_NON_FUNCTION, "Promise:done"))
return self:_finally(debug.traceback(nil, 2), doneHandler, true)
end
@ -1554,7 +1557,7 @@ end
@return Promise
]=]
function Promise.prototype:doneCall(callback, ...)
assert(type(callback) == "function", string.format(ERROR_NON_FUNCTION, "Promise:doneCall"))
assert(isCallable(callback), string.format(ERROR_NON_FUNCTION, "Promise:doneCall"))
local length, values = pack(...)
return self:_finally(debug.traceback(nil, 2), function()
return callback(unpack(values, 1, length))
@ -1903,7 +1906,7 @@ end
@param ...? P
]=]
function Promise.retry(callback, times, ...)
assert(type(callback) == "function", "Parameter #1 to Promise.retry must be a function")
assert(isCallable(callback), "Parameter #1 to Promise.retry must be a function")
assert(type(times) == "number", "Parameter #2 to Promise.retry must be a number")
local args, length = { ... }, select("#", ...)

View file

@ -5,7 +5,8 @@ return function()
local timeEvent = Instance.new("BindableEvent")
Promise._timeEvent = timeEvent.Event
local advanceTime do
local advanceTime
do
local injectedPromiseTime = 0
Promise._getTime = function()
@ -13,7 +14,7 @@ return function()
end
function advanceTime(delta)
delta = delta or (1/60)
delta = delta or (1 / 60)
injectedPromiseTime = injectedPromiseTime + delta
timeEvent:Fire(delta)
@ -145,7 +146,9 @@ return function()
expect(trace:find("nestedCall")).to.be.ok()
expect(trace:find("runExecutor")).to.be.ok()
expect(trace:find("runPlanNode")).to.be.ok()
expect(trace:find("...Rejected because it was chained to the following Promise, which encountered an error:")).to.be.ok()
expect(
trace:find("...Rejected because it was chained to the following Promise, which encountered an error:")
).to.be.ok()
end)
it("should report errors from Promises with _error (< v2)", function()
@ -158,9 +161,29 @@ return function()
local trace = tostring(newPromise._values[1])
expect(trace:find("Sample error")).to.be.ok()
expect(trace:find("...Rejected because it was chained to the following Promise, which encountered an error:")).to.be.ok()
expect(
trace:find("...Rejected because it was chained to the following Promise, which encountered an error:")
).to.be.ok()
expect(trace:find("%[No stack trace available")).to.be.ok()
end)
it("should allow callable tables", function()
local promise = Promise.new(setmetatable({}, {
__call = function(_, resolve)
resolve(1)
end,
}))
local called = false
promise:andThen(setmetatable({}, {
__call = function(_, var)
expect(var).to.equal(1)
called = true
end,
}))
expect(called).to.equal(true)
end)
end)
describe("Promise.defer", function()
@ -206,7 +229,7 @@ return function()
local promise = Promise.delay(2)
Promise.delay(1):andThen(function()
promise:cancel()
promise:cancel()
end)
expect(promise:getStatus()).to.equal(Promise.Status.Started)
@ -308,15 +331,12 @@ return function()
local promise = Promise.resolve(5)
local chained = promise:andThen(
function(...)
argsLength, args = pack(...)
callCount = callCount + 1
end,
function()
badCallCount = badCallCount + 1
end
)
local chained = promise:andThen(function(...)
argsLength, args = pack(...)
callCount = callCount + 1
end, function()
badCallCount = badCallCount + 1
end)
expect(badCallCount).to.equal(0)
@ -342,15 +362,12 @@ return function()
local promise = Promise.reject(5)
local chained = promise:andThen(
function(...)
badCallCount = badCallCount + 1
end,
function(...)
argsLength, args = pack(...)
callCount = callCount + 1
end
)
local chained = promise:andThen(function(...)
badCallCount = badCallCount + 1
end, function(...)
argsLength, args = pack(...)
callCount = callCount + 1
end)
expect(badCallCount).to.equal(0)
@ -397,16 +414,13 @@ return function()
startResolution = resolve
end)
local chained = promise:andThen(
function(...)
args = {...}
argsLength = select("#", ...)
callCount = callCount + 1
end,
function()
badCallCount = badCallCount + 1
end
)
local chained = promise:andThen(function(...)
args = { ... }
argsLength = select("#", ...)
callCount = callCount + 1
end, function()
badCallCount = badCallCount + 1
end)
expect(callCount).to.equal(0)
expect(badCallCount).to.equal(0)
@ -440,16 +454,13 @@ return function()
startResolution = reject
end)
local chained = promise:andThen(
function()
badCallCount = badCallCount + 1
end,
function(...)
args = {...}
argsLength = select("#", ...)
callCount = callCount + 1
end
)
local chained = promise:andThen(function()
badCallCount = badCallCount + 1
end, function(...)
args = { ... }
argsLength = select("#", ...)
callCount = callCount + 1
end)
expect(callCount).to.equal(0)
expect(badCallCount).to.equal(0)
@ -476,9 +487,7 @@ return function()
local x, y, z
Promise.new(function(resolve, reject)
reject(1, 2, 3)
end)
:andThen(function() end)
:catch(function(a, b, c)
end):andThen(function() end):catch(function(a, b, c)
x, y, z = a, b, c
end)
@ -492,11 +501,13 @@ return function()
it("should mark promises as cancelled and not resolve or reject them", function()
local callCount = 0
local finallyCallCount = 0
local promise = Promise.new(function() end):andThen(function()
callCount = callCount + 1
end):finally(function()
finallyCallCount = finallyCallCount + 1
end)
local promise = Promise.new(function() end)
:andThen(function()
callCount = callCount + 1
end)
:finally(function()
finallyCallCount = finallyCallCount + 1
end)
promise:cancel()
promise:cancel() -- Twice to check call counts
@ -556,7 +567,9 @@ return function()
it("should track consumers", function()
local pending = Promise.new(function() end)
local p0 = Promise.resolve()
local p1 = p0:finally(function() return pending end)
local p1 = p0:finally(function()
return pending
end)
local p2 = Promise.new(function(resolve)
resolve(p1)
end)
@ -596,9 +609,7 @@ return function()
end):finally(finally)
-- Chained promise
Promise.resolve():andThen(function()
end):finally(finally):finally(finally)
Promise.resolve():andThen(function() end):finally(finally):finally(finally)
-- Rejected promise
Promise.reject():finally(finally)
@ -620,11 +631,13 @@ return function()
it("should forward return values", function()
local value
Promise.resolve():finally(function()
return 1
end):andThen(function(v)
value = v
end)
Promise.resolve()
:finally(function()
return 1
end)
:andThen(function(v)
value = v
end)
expect(value).to.equal(1)
end)
@ -649,7 +662,7 @@ return function()
it("should error if given non-promise values", function()
expect(function()
Promise.all({{}, {}, {}})
Promise.all({ {}, {}, {} })
end).to.throw()
end)
@ -662,7 +675,7 @@ return function()
for i = 1, testValuesLength do
promises[i] = Promise.new(function(resolve)
resolveFunctions[i] = {resolve, testValues[i]}
resolveFunctions[i] = { resolve, testValues[i] }
end)
end
@ -698,7 +711,7 @@ return function()
resolveB = resolve
end)
local combinedPromise = Promise.all({a, b})
local combinedPromise = Promise.all({ a, b })
expect(combinedPromise:getStatus()).to.equal(Promise.Status.Started)
@ -727,7 +740,7 @@ return function()
resolveB = resolve
end)
local combinedPromise = Promise.all({a, b})
local combinedPromise = Promise.all({ a, b })
expect(combinedPromise:getStatus()).to.equal(Promise.Status.Started)
@ -755,7 +768,7 @@ return function()
rejectB = reject
end)
local combinedPromise = Promise.all({a, b})
local combinedPromise = Promise.all({ a, b })
expect(combinedPromise:getStatus()).to.equal(Promise.Status.Started)
@ -788,7 +801,7 @@ return function()
expect(Promise.all({
Promise.resolve(),
Promise.reject(),
p
p,
}):getStatus()).to.equal(Promise.Status.Rejected)
expect(p:getStatus()).to.equal(Promise.Status.Cancelled)
end)
@ -800,7 +813,7 @@ return function()
local promises = {
Promise.new(function() end),
Promise.new(function() end),
p
p,
}
Promise.all(promises):cancel()
@ -824,7 +837,7 @@ return function()
end)
it("should accept promises in the list", function()
local sum = Promise.fold({Promise.resolve(1), 2, 3}, function(sum, element)
local sum = Promise.fold({ Promise.resolve(1), 2, 3 }, function(sum, element)
return sum + element
end, 0)
expect(Promise.is(sum)).to.equal(true)
@ -833,7 +846,7 @@ return function()
end)
it("should always return a promise even if the list or reducer don't use them", function()
local sum = Promise.fold({1, 2, 3}, function(sum, element, index)
local sum = Promise.fold({ 1, 2, 3 }, function(sum, element, index)
if index == 2 then
return Promise.delay(1):andThenReturn(sum + element)
else
@ -849,7 +862,7 @@ return function()
it("should return the first rejected promise", function()
local errorMessage = "foo"
local sum = Promise.fold({1, 2, 3}, function(sum, element, index)
local sum = Promise.fold({ 1, 2, 3 }, function(sum, element, index)
if index == 2 then
return Promise.reject(errorMessage)
else
@ -864,14 +877,14 @@ return function()
it("should return the first canceled promise", function()
local secondPromise
local sum = Promise.fold({1, 2, 3}, function(sum, element, index)
local sum = Promise.fold({ 1, 2, 3 }, function(sum, element, index)
if index == 1 then
return sum + element
elseif index == 2 then
secondPromise = Promise.delay(1):andThenReturn(sum + element)
return secondPromise
else
error('this should not run if the promise is cancelled')
error("this should not run if the promise is cancelled")
end
end, 0)
expect(Promise.is(sum)).to.equal(true)
@ -885,7 +898,7 @@ return function()
it("should resolve with the first settled value", function()
local promise = Promise.race({
Promise.resolve(1),
Promise.resolve(2)
Promise.resolve(2),
}):andThen(function(value)
expect(value).to.equal(1)
end)
@ -901,7 +914,7 @@ return function()
Promise.new(function() end),
Promise.new(function(resolve)
resolve(2)
end)
end),
}
local promise = Promise.race(promises)
@ -916,7 +929,7 @@ return function()
expect(Promise.race({
Promise.reject(),
Promise.resolve(),
p
p,
}):getStatus()).to.equal(Promise.Status.Rejected)
expect(p:getStatus()).to.equal(Promise.Status.Cancelled)
end)
@ -937,7 +950,7 @@ return function()
local promises = {
Promise.new(function() end),
Promise.new(function() end),
p
p,
}
Promise.race(promises):cancel()
@ -965,9 +978,9 @@ return function()
it("should catch errors after a yield", function()
local bindable = Instance.new("BindableEvent")
local test = Promise.promisify(function ()
local test = Promise.promisify(function()
bindable.Event:Wait()
error('errortext')
error("errortext")
end)
local promise = test()
@ -1026,7 +1039,7 @@ return function()
it("should catch synchronous errors", function()
local errorText
Promise.try(function()
error('errortext')
error("errortext")
end):catch(function(e)
errorText = tostring(e)
end)
@ -1048,7 +1061,7 @@ return function()
local bindable = Instance.new("BindableEvent")
local promise = Promise.try(function()
bindable.Event:Wait()
error('errortext')
error("errortext")
end)
expect(promise:getStatus()).to.equal(Promise.Status.Started)
@ -1127,11 +1140,13 @@ return function()
expect(value).to.equal(true)
local never, always
Promise.reject():done(function()
never = true
end):finally(function()
always = true
end)
Promise.reject()
:done(function()
never = true
end)
:finally(function()
always = true
end)
expect(never).to.never.be.ok()
expect(always).to.be.ok()
@ -1143,7 +1158,7 @@ return function()
local p = Promise.some({
Promise.resolve(1),
Promise.reject(),
Promise.resolve(2)
Promise.resolve(2),
}, 2)
expect(p:getStatus()).to.equal(Promise.Status.Resolved)
expect(p._values[1][1]).to.equal(1)
@ -1153,13 +1168,15 @@ return function()
it("should error if the goal can't be reached", function()
expect(Promise.some({
Promise.resolve(),
Promise.reject()
Promise.reject(),
}, 2):getStatus()).to.equal(Promise.Status.Rejected)
local reject
local p = Promise.some({
Promise.resolve(),
Promise.new(function(_, r) reject = r end)
Promise.new(function(_, r)
reject = r
end),
}, 2)
expect(p:getStatus()).to.equal(Promise.Status.Started)
@ -1171,12 +1188,14 @@ return function()
it("should cancel pending Promises once the goal is reached", function()
local resolve
local pending1 = Promise.new(function() end)
local pending2 = Promise.new(function(r) resolve = r end)
local pending2 = Promise.new(function(r)
resolve = r
end)
local some = Promise.some({
pending1,
pending2,
Promise.resolve()
Promise.resolve(),
}, 2)
expect(some:getStatus()).to.equal(Promise.Status.Started)
@ -1198,7 +1217,7 @@ return function()
it("should return an empty array if amount is 0", function()
local p = Promise.some({
Promise.resolve(2)
Promise.resolve(2),
}, 0)
expect(p:getStatus()).to.equal(Promise.Status.Resolved)
@ -1226,7 +1245,7 @@ return function()
local promises = {
Promise.new(function() end),
Promise.new(function() end),
p
p,
}
Promise.some(promises, 3):cancel()
@ -1241,7 +1260,7 @@ return function()
local p = Promise.any({
Promise.reject(),
Promise.reject(),
Promise.resolve(1)
Promise.resolve(1),
})
expect(p:getStatus()).to.equal(Promise.Status.Resolved)
@ -1265,7 +1284,9 @@ return function()
Promise.resolve(),
Promise.reject(),
Promise.resolve(),
Promise.new(function(_, r) reject = r end)
Promise.new(function(_, r)
reject = r
end),
})
expect(p:getStatus()).to.equal(Promise.Status.Started)
@ -1284,7 +1305,7 @@ return function()
local promises = {
Promise.new(function() end),
Promise.new(function() end),
p
p,
}
Promise.allSettled(promises):cancel()
@ -1349,9 +1370,12 @@ return function()
describe("Promise.each", function()
it("should iterate", function()
local ok, result = Promise.each({
"foo", "bar", "baz", "qux"
"foo",
"bar",
"baz",
"qux",
}, function(...)
return {...}
return { ... }
end):_unwrap()
expect(ok).to.equal(true)
@ -1370,7 +1394,9 @@ return function()
local callCounts = {}
local promise = Promise.each({
"foo", "bar", "baz"
"foo",
"bar",
"baz",
}, function(value, index)
callCounts[index] = (callCounts[index] or 0) + 1
@ -1415,7 +1441,7 @@ return function()
end)
it("should reject with the value if the predicate promise rejects", function()
local promise = Promise.each({1, 2, 3}, function()
local promise = Promise.each({ 1, 2, 3 }, function()
return Promise.reject("foobar")
end)
@ -1430,7 +1456,7 @@ return function()
end)
local promise = Promise.each({
innerPromise
innerPromise,
}, function(value)
return value * 2
end)
@ -1445,7 +1471,7 @@ return function()
it("should reject with the value if a Promise from the list rejects", function()
local called = false
local promise = Promise.each({1, 2, Promise.reject("foobar")}, function(value)
local promise = Promise.each({ 1, 2, Promise.reject("foobar") }, function(value)
called = true
return "never"
end)
@ -1460,7 +1486,7 @@ return function()
cancelled:cancel()
local called = false
local promise = Promise.each({1, 2, cancelled}, function()
local promise = Promise.each({ 1, 2, cancelled }, function()
called = true
end)
@ -1473,13 +1499,13 @@ return function()
local callCounts = {}
local promise = Promise.each({
"foo", "bar", "baz"
"foo",
"bar",
"baz",
}, function(value, index)
callCounts[index] = (callCounts[index] or 0) + 1
return Promise.new(function()
end)
return Promise.new(function() end)
end)
expect(promise:getStatus()).to.equal(Promise.Status.Started)
@ -1497,10 +1523,11 @@ return function()
local innerPromise
local promise = Promise.each({
"foo", "bar", "baz"
"foo",
"bar",
"baz",
}, function(value, index)
innerPromise = Promise.new(function()
end)
innerPromise = Promise.new(function() end)
return innerPromise
end)
@ -1512,7 +1539,7 @@ return function()
it("should cancel Promises in the list if Promise.each is cancelled", function()
local innerPromise = Promise.new(function() end)
local promise = Promise.each({innerPromise}, function() end)
local promise = Promise.each({ innerPromise }, function() end)
promise:cancel()
@ -1595,7 +1622,7 @@ return function()
local obj = {
andThen = function()
return 1
end
end,
}
expect(Promise.is(obj)).to.equal(true)
@ -1606,9 +1633,7 @@ return function()
OldPromise.prototype = {}
OldPromise.__index = OldPromise.prototype
function OldPromise.prototype:andThen()
end
function OldPromise.prototype:andThen() end
local oldPromise = setmetatable({}, OldPromise)