diff --git a/lib/init.lua b/lib/init.lua index 8058fdf..a0ee2d4 100644 --- a/lib/init.lua +++ b/lib/init.lua @@ -7,6 +7,21 @@ local ERROR_NON_LIST = "Please pass a list of promises to %s" local ERROR_NON_FUNCTION = "Please pass a handler function to %s!" local MODE_KEY_METATABLE = { __mode = "k" } +local function isCallable(value) + if type(value) == "function" then + return true + end + + if type(value) == "table" then + local metatable = getmetatable(value) + if metatable and type(rawget(metatable, "__call")) == "function" then + return true + end + end + + return false +end + --[[ Creates an enum dictionary with some metamethods to prevent common mistakes. ]] @@ -131,7 +146,7 @@ local function packResult(success, ...) end local function makeErrorHandler(traceback) - assert(traceback ~= nil) + assert(traceback ~= nil, "traceback is nil") return function(err) -- If the error object is already a table, forward it directly. @@ -592,7 +607,7 @@ end ]=] function Promise.fold(list, reducer, initialValue) assert(type(list) == "table", "Bad argument #1 to Promise.fold: must be a table") - assert(type(reducer) == "function", "Bad argument #2 to Promise.fold: must be a function") + assert(isCallable(reducer), "Bad argument #2 to Promise.fold: must be a function") local accumulator = Promise.resolve(initialValue) return Promise.each(list, function(resolvedElement, i) @@ -842,7 +857,7 @@ end ]=] function Promise.each(list, predicate) assert(type(list) == "table", string.format(ERROR_NON_LIST, "Promise.each")) - assert(type(predicate) == "function", string.format(ERROR_NON_FUNCTION, "Promise.each")) + assert(isCallable(predicate), string.format(ERROR_NON_FUNCTION, "Promise.each")) return Promise._new(debug.traceback(nil, 2), function(resolve, reject, onCancel) local results = {} @@ -951,11 +966,11 @@ function Promise.is(object) return true elseif objectMetatable == nil then -- No metatable, but we should still chain onto tables with andThen methods - return type(object.andThen) == "function" + return isCallable(object.andThen) elseif type(objectMetatable) == "table" and type(rawget(objectMetatable, "__index")) == "table" - and type(rawget(rawget(objectMetatable, "__index"), "andThen")) == "function" + and isCallable(rawget(rawget(objectMetatable, "__index"), "andThen")) then -- Maybe this came from a different or older Promise library. return true @@ -1235,14 +1250,8 @@ end @return Promise<...any> ]=] function Promise.prototype:andThen(successHandler, failureHandler) - assert( - successHandler == nil or type(successHandler) == "function", - string.format(ERROR_NON_FUNCTION, "Promise:andThen") - ) - assert( - failureHandler == nil or type(failureHandler) == "function", - string.format(ERROR_NON_FUNCTION, "Promise:andThen") - ) + assert(successHandler == nil or isCallable(successHandler), string.format(ERROR_NON_FUNCTION, "Promise:andThen")) + assert(failureHandler == nil or isCallable(failureHandler), string.format(ERROR_NON_FUNCTION, "Promise:andThen")) return self:_andThen(debug.traceback(nil, 2), successHandler, failureHandler) end @@ -1261,10 +1270,7 @@ end @return Promise<...any> ]=] function Promise.prototype:catch(failureHandler) - assert( - failureHandler == nil or type(failureHandler) == "function", - string.format(ERROR_NON_FUNCTION, "Promise:catch") - ) + assert(failureHandler == nil or isCallable(failureHandler), string.format(ERROR_NON_FUNCTION, "Promise:catch")) return self:_andThen(debug.traceback(nil, 2), nil, failureHandler) end @@ -1285,7 +1291,7 @@ end @return Promise<...any> ]=] function Promise.prototype:tap(tapHandler) - assert(type(tapHandler) == "function", string.format(ERROR_NON_FUNCTION, "Promise:tap")) + assert(isCallable(tapHandler), string.format(ERROR_NON_FUNCTION, "Promise:tap")) return self:_andThen(debug.traceback(nil, 2), function(...) local callbackReturn = tapHandler(...) @@ -1320,7 +1326,7 @@ end @return Promise ]=] function Promise.prototype:andThenCall(callback, ...) - assert(type(callback) == "function", string.format(ERROR_NON_FUNCTION, "Promise:andThenCall")) + assert(isCallable(callback), string.format(ERROR_NON_FUNCTION, "Promise:andThenCall")) local length, values = pack(...) return self:_andThen(debug.traceback(nil, 2), function() return callback(unpack(values, 1, length)) @@ -1474,10 +1480,7 @@ end @return Promise<...any> ]=] function Promise.prototype:finally(finallyHandler) - assert( - finallyHandler == nil or type(finallyHandler) == "function", - string.format(ERROR_NON_FUNCTION, "Promise:finally") - ) + assert(finallyHandler == nil or isCallable(finallyHandler), string.format(ERROR_NON_FUNCTION, "Promise:finally")) return self:_finally(debug.traceback(nil, 2), finallyHandler) end @@ -1491,7 +1494,7 @@ end @return Promise ]=] function Promise.prototype:finallyCall(callback, ...) - assert(type(callback) == "function", string.format(ERROR_NON_FUNCTION, "Promise:finallyCall")) + assert(isCallable(callback), string.format(ERROR_NON_FUNCTION, "Promise:finallyCall")) local length, values = pack(...) return self:_finally(debug.traceback(nil, 2), function() return callback(unpack(values, 1, length)) @@ -1540,7 +1543,7 @@ end @return Promise<...any> ]=] function Promise.prototype:done(doneHandler) - assert(doneHandler == nil or type(doneHandler) == "function", string.format(ERROR_NON_FUNCTION, "Promise:done")) + assert(doneHandler == nil or isCallable(doneHandler), string.format(ERROR_NON_FUNCTION, "Promise:done")) return self:_finally(debug.traceback(nil, 2), doneHandler, true) end @@ -1554,7 +1557,7 @@ end @return Promise ]=] function Promise.prototype:doneCall(callback, ...) - assert(type(callback) == "function", string.format(ERROR_NON_FUNCTION, "Promise:doneCall")) + assert(isCallable(callback), string.format(ERROR_NON_FUNCTION, "Promise:doneCall")) local length, values = pack(...) return self:_finally(debug.traceback(nil, 2), function() return callback(unpack(values, 1, length)) @@ -1903,7 +1906,7 @@ end @param ...? P ]=] function Promise.retry(callback, times, ...) - assert(type(callback) == "function", "Parameter #1 to Promise.retry must be a function") + assert(isCallable(callback), "Parameter #1 to Promise.retry must be a function") assert(type(times) == "number", "Parameter #2 to Promise.retry must be a number") local args, length = { ... }, select("#", ...) diff --git a/lib/init.spec.lua b/lib/init.spec.lua index 1bfc996..3168cf8 100644 --- a/lib/init.spec.lua +++ b/lib/init.spec.lua @@ -5,7 +5,8 @@ return function() local timeEvent = Instance.new("BindableEvent") Promise._timeEvent = timeEvent.Event - local advanceTime do + local advanceTime + do local injectedPromiseTime = 0 Promise._getTime = function() @@ -13,7 +14,7 @@ return function() end function advanceTime(delta) - delta = delta or (1/60) + delta = delta or (1 / 60) injectedPromiseTime = injectedPromiseTime + delta timeEvent:Fire(delta) @@ -145,7 +146,9 @@ return function() expect(trace:find("nestedCall")).to.be.ok() expect(trace:find("runExecutor")).to.be.ok() expect(trace:find("runPlanNode")).to.be.ok() - expect(trace:find("...Rejected because it was chained to the following Promise, which encountered an error:")).to.be.ok() + expect( + trace:find("...Rejected because it was chained to the following Promise, which encountered an error:") + ).to.be.ok() end) it("should report errors from Promises with _error (< v2)", function() @@ -158,9 +161,29 @@ return function() local trace = tostring(newPromise._values[1]) expect(trace:find("Sample error")).to.be.ok() - expect(trace:find("...Rejected because it was chained to the following Promise, which encountered an error:")).to.be.ok() + expect( + trace:find("...Rejected because it was chained to the following Promise, which encountered an error:") + ).to.be.ok() expect(trace:find("%[No stack trace available")).to.be.ok() end) + + it("should allow callable tables", function() + local promise = Promise.new(setmetatable({}, { + __call = function(_, resolve) + resolve(1) + end, + })) + + local called = false + promise:andThen(setmetatable({}, { + __call = function(_, var) + expect(var).to.equal(1) + called = true + end, + })) + + expect(called).to.equal(true) + end) end) describe("Promise.defer", function() @@ -206,7 +229,7 @@ return function() local promise = Promise.delay(2) Promise.delay(1):andThen(function() - promise:cancel() + promise:cancel() end) expect(promise:getStatus()).to.equal(Promise.Status.Started) @@ -308,15 +331,12 @@ return function() local promise = Promise.resolve(5) - local chained = promise:andThen( - function(...) - argsLength, args = pack(...) - callCount = callCount + 1 - end, - function() - badCallCount = badCallCount + 1 - end - ) + local chained = promise:andThen(function(...) + argsLength, args = pack(...) + callCount = callCount + 1 + end, function() + badCallCount = badCallCount + 1 + end) expect(badCallCount).to.equal(0) @@ -342,15 +362,12 @@ return function() local promise = Promise.reject(5) - local chained = promise:andThen( - function(...) - badCallCount = badCallCount + 1 - end, - function(...) - argsLength, args = pack(...) - callCount = callCount + 1 - end - ) + local chained = promise:andThen(function(...) + badCallCount = badCallCount + 1 + end, function(...) + argsLength, args = pack(...) + callCount = callCount + 1 + end) expect(badCallCount).to.equal(0) @@ -397,16 +414,13 @@ return function() startResolution = resolve end) - local chained = promise:andThen( - function(...) - args = {...} - argsLength = select("#", ...) - callCount = callCount + 1 - end, - function() - badCallCount = badCallCount + 1 - end - ) + local chained = promise:andThen(function(...) + args = { ... } + argsLength = select("#", ...) + callCount = callCount + 1 + end, function() + badCallCount = badCallCount + 1 + end) expect(callCount).to.equal(0) expect(badCallCount).to.equal(0) @@ -440,16 +454,13 @@ return function() startResolution = reject end) - local chained = promise:andThen( - function() - badCallCount = badCallCount + 1 - end, - function(...) - args = {...} - argsLength = select("#", ...) - callCount = callCount + 1 - end - ) + local chained = promise:andThen(function() + badCallCount = badCallCount + 1 + end, function(...) + args = { ... } + argsLength = select("#", ...) + callCount = callCount + 1 + end) expect(callCount).to.equal(0) expect(badCallCount).to.equal(0) @@ -476,9 +487,7 @@ return function() local x, y, z Promise.new(function(resolve, reject) reject(1, 2, 3) - end) - :andThen(function() end) - :catch(function(a, b, c) + end):andThen(function() end):catch(function(a, b, c) x, y, z = a, b, c end) @@ -492,11 +501,13 @@ return function() it("should mark promises as cancelled and not resolve or reject them", function() local callCount = 0 local finallyCallCount = 0 - local promise = Promise.new(function() end):andThen(function() - callCount = callCount + 1 - end):finally(function() - finallyCallCount = finallyCallCount + 1 - end) + local promise = Promise.new(function() end) + :andThen(function() + callCount = callCount + 1 + end) + :finally(function() + finallyCallCount = finallyCallCount + 1 + end) promise:cancel() promise:cancel() -- Twice to check call counts @@ -556,7 +567,9 @@ return function() it("should track consumers", function() local pending = Promise.new(function() end) local p0 = Promise.resolve() - local p1 = p0:finally(function() return pending end) + local p1 = p0:finally(function() + return pending + end) local p2 = Promise.new(function(resolve) resolve(p1) end) @@ -596,9 +609,7 @@ return function() end):finally(finally) -- Chained promise - Promise.resolve():andThen(function() - - end):finally(finally):finally(finally) + Promise.resolve():andThen(function() end):finally(finally):finally(finally) -- Rejected promise Promise.reject():finally(finally) @@ -620,11 +631,13 @@ return function() it("should forward return values", function() local value - Promise.resolve():finally(function() - return 1 - end):andThen(function(v) - value = v - end) + Promise.resolve() + :finally(function() + return 1 + end) + :andThen(function(v) + value = v + end) expect(value).to.equal(1) end) @@ -649,7 +662,7 @@ return function() it("should error if given non-promise values", function() expect(function() - Promise.all({{}, {}, {}}) + Promise.all({ {}, {}, {} }) end).to.throw() end) @@ -662,7 +675,7 @@ return function() for i = 1, testValuesLength do promises[i] = Promise.new(function(resolve) - resolveFunctions[i] = {resolve, testValues[i]} + resolveFunctions[i] = { resolve, testValues[i] } end) end @@ -698,7 +711,7 @@ return function() resolveB = resolve end) - local combinedPromise = Promise.all({a, b}) + local combinedPromise = Promise.all({ a, b }) expect(combinedPromise:getStatus()).to.equal(Promise.Status.Started) @@ -727,7 +740,7 @@ return function() resolveB = resolve end) - local combinedPromise = Promise.all({a, b}) + local combinedPromise = Promise.all({ a, b }) expect(combinedPromise:getStatus()).to.equal(Promise.Status.Started) @@ -755,7 +768,7 @@ return function() rejectB = reject end) - local combinedPromise = Promise.all({a, b}) + local combinedPromise = Promise.all({ a, b }) expect(combinedPromise:getStatus()).to.equal(Promise.Status.Started) @@ -788,7 +801,7 @@ return function() expect(Promise.all({ Promise.resolve(), Promise.reject(), - p + p, }):getStatus()).to.equal(Promise.Status.Rejected) expect(p:getStatus()).to.equal(Promise.Status.Cancelled) end) @@ -800,7 +813,7 @@ return function() local promises = { Promise.new(function() end), Promise.new(function() end), - p + p, } Promise.all(promises):cancel() @@ -824,7 +837,7 @@ return function() end) it("should accept promises in the list", function() - local sum = Promise.fold({Promise.resolve(1), 2, 3}, function(sum, element) + local sum = Promise.fold({ Promise.resolve(1), 2, 3 }, function(sum, element) return sum + element end, 0) expect(Promise.is(sum)).to.equal(true) @@ -833,7 +846,7 @@ return function() end) it("should always return a promise even if the list or reducer don't use them", function() - local sum = Promise.fold({1, 2, 3}, function(sum, element, index) + local sum = Promise.fold({ 1, 2, 3 }, function(sum, element, index) if index == 2 then return Promise.delay(1):andThenReturn(sum + element) else @@ -849,7 +862,7 @@ return function() it("should return the first rejected promise", function() local errorMessage = "foo" - local sum = Promise.fold({1, 2, 3}, function(sum, element, index) + local sum = Promise.fold({ 1, 2, 3 }, function(sum, element, index) if index == 2 then return Promise.reject(errorMessage) else @@ -864,14 +877,14 @@ return function() it("should return the first canceled promise", function() local secondPromise - local sum = Promise.fold({1, 2, 3}, function(sum, element, index) + local sum = Promise.fold({ 1, 2, 3 }, function(sum, element, index) if index == 1 then return sum + element elseif index == 2 then secondPromise = Promise.delay(1):andThenReturn(sum + element) return secondPromise else - error('this should not run if the promise is cancelled') + error("this should not run if the promise is cancelled") end end, 0) expect(Promise.is(sum)).to.equal(true) @@ -885,7 +898,7 @@ return function() it("should resolve with the first settled value", function() local promise = Promise.race({ Promise.resolve(1), - Promise.resolve(2) + Promise.resolve(2), }):andThen(function(value) expect(value).to.equal(1) end) @@ -901,7 +914,7 @@ return function() Promise.new(function() end), Promise.new(function(resolve) resolve(2) - end) + end), } local promise = Promise.race(promises) @@ -916,7 +929,7 @@ return function() expect(Promise.race({ Promise.reject(), Promise.resolve(), - p + p, }):getStatus()).to.equal(Promise.Status.Rejected) expect(p:getStatus()).to.equal(Promise.Status.Cancelled) end) @@ -937,7 +950,7 @@ return function() local promises = { Promise.new(function() end), Promise.new(function() end), - p + p, } Promise.race(promises):cancel() @@ -965,9 +978,9 @@ return function() it("should catch errors after a yield", function() local bindable = Instance.new("BindableEvent") - local test = Promise.promisify(function () + local test = Promise.promisify(function() bindable.Event:Wait() - error('errortext') + error("errortext") end) local promise = test() @@ -1026,7 +1039,7 @@ return function() it("should catch synchronous errors", function() local errorText Promise.try(function() - error('errortext') + error("errortext") end):catch(function(e) errorText = tostring(e) end) @@ -1048,7 +1061,7 @@ return function() local bindable = Instance.new("BindableEvent") local promise = Promise.try(function() bindable.Event:Wait() - error('errortext') + error("errortext") end) expect(promise:getStatus()).to.equal(Promise.Status.Started) @@ -1127,11 +1140,13 @@ return function() expect(value).to.equal(true) local never, always - Promise.reject():done(function() - never = true - end):finally(function() - always = true - end) + Promise.reject() + :done(function() + never = true + end) + :finally(function() + always = true + end) expect(never).to.never.be.ok() expect(always).to.be.ok() @@ -1143,7 +1158,7 @@ return function() local p = Promise.some({ Promise.resolve(1), Promise.reject(), - Promise.resolve(2) + Promise.resolve(2), }, 2) expect(p:getStatus()).to.equal(Promise.Status.Resolved) expect(p._values[1][1]).to.equal(1) @@ -1153,13 +1168,15 @@ return function() it("should error if the goal can't be reached", function() expect(Promise.some({ Promise.resolve(), - Promise.reject() + Promise.reject(), }, 2):getStatus()).to.equal(Promise.Status.Rejected) local reject local p = Promise.some({ Promise.resolve(), - Promise.new(function(_, r) reject = r end) + Promise.new(function(_, r) + reject = r + end), }, 2) expect(p:getStatus()).to.equal(Promise.Status.Started) @@ -1171,12 +1188,14 @@ return function() it("should cancel pending Promises once the goal is reached", function() local resolve local pending1 = Promise.new(function() end) - local pending2 = Promise.new(function(r) resolve = r end) + local pending2 = Promise.new(function(r) + resolve = r + end) local some = Promise.some({ pending1, pending2, - Promise.resolve() + Promise.resolve(), }, 2) expect(some:getStatus()).to.equal(Promise.Status.Started) @@ -1198,7 +1217,7 @@ return function() it("should return an empty array if amount is 0", function() local p = Promise.some({ - Promise.resolve(2) + Promise.resolve(2), }, 0) expect(p:getStatus()).to.equal(Promise.Status.Resolved) @@ -1226,7 +1245,7 @@ return function() local promises = { Promise.new(function() end), Promise.new(function() end), - p + p, } Promise.some(promises, 3):cancel() @@ -1241,7 +1260,7 @@ return function() local p = Promise.any({ Promise.reject(), Promise.reject(), - Promise.resolve(1) + Promise.resolve(1), }) expect(p:getStatus()).to.equal(Promise.Status.Resolved) @@ -1265,7 +1284,9 @@ return function() Promise.resolve(), Promise.reject(), Promise.resolve(), - Promise.new(function(_, r) reject = r end) + Promise.new(function(_, r) + reject = r + end), }) expect(p:getStatus()).to.equal(Promise.Status.Started) @@ -1284,7 +1305,7 @@ return function() local promises = { Promise.new(function() end), Promise.new(function() end), - p + p, } Promise.allSettled(promises):cancel() @@ -1349,9 +1370,12 @@ return function() describe("Promise.each", function() it("should iterate", function() local ok, result = Promise.each({ - "foo", "bar", "baz", "qux" + "foo", + "bar", + "baz", + "qux", }, function(...) - return {...} + return { ... } end):_unwrap() expect(ok).to.equal(true) @@ -1370,7 +1394,9 @@ return function() local callCounts = {} local promise = Promise.each({ - "foo", "bar", "baz" + "foo", + "bar", + "baz", }, function(value, index) callCounts[index] = (callCounts[index] or 0) + 1 @@ -1415,7 +1441,7 @@ return function() end) it("should reject with the value if the predicate promise rejects", function() - local promise = Promise.each({1, 2, 3}, function() + local promise = Promise.each({ 1, 2, 3 }, function() return Promise.reject("foobar") end) @@ -1430,7 +1456,7 @@ return function() end) local promise = Promise.each({ - innerPromise + innerPromise, }, function(value) return value * 2 end) @@ -1445,7 +1471,7 @@ return function() it("should reject with the value if a Promise from the list rejects", function() local called = false - local promise = Promise.each({1, 2, Promise.reject("foobar")}, function(value) + local promise = Promise.each({ 1, 2, Promise.reject("foobar") }, function(value) called = true return "never" end) @@ -1460,7 +1486,7 @@ return function() cancelled:cancel() local called = false - local promise = Promise.each({1, 2, cancelled}, function() + local promise = Promise.each({ 1, 2, cancelled }, function() called = true end) @@ -1473,13 +1499,13 @@ return function() local callCounts = {} local promise = Promise.each({ - "foo", "bar", "baz" + "foo", + "bar", + "baz", }, function(value, index) callCounts[index] = (callCounts[index] or 0) + 1 - return Promise.new(function() - - end) + return Promise.new(function() end) end) expect(promise:getStatus()).to.equal(Promise.Status.Started) @@ -1497,10 +1523,11 @@ return function() local innerPromise local promise = Promise.each({ - "foo", "bar", "baz" + "foo", + "bar", + "baz", }, function(value, index) - innerPromise = Promise.new(function() - end) + innerPromise = Promise.new(function() end) return innerPromise end) @@ -1512,7 +1539,7 @@ return function() it("should cancel Promises in the list if Promise.each is cancelled", function() local innerPromise = Promise.new(function() end) - local promise = Promise.each({innerPromise}, function() end) + local promise = Promise.each({ innerPromise }, function() end) promise:cancel() @@ -1595,7 +1622,7 @@ return function() local obj = { andThen = function() return 1 - end + end, } expect(Promise.is(obj)).to.equal(true) @@ -1606,9 +1633,7 @@ return function() OldPromise.prototype = {} OldPromise.__index = OldPromise.prototype - function OldPromise.prototype:andThen() - - end + function OldPromise.prototype:andThen() end local oldPromise = setmetatable({}, OldPromise)