From 3dbb121906b10660a8eb61605303b7d8de42b1e8 Mon Sep 17 00:00:00 2001 From: Eryn Lynn Date: Fri, 27 Sep 2019 18:46:10 -0400 Subject: [PATCH] Make error more robust when misusing all/race --- lib/init.lua | 20 ++++++++++---------- lib/init.spec.lua | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/lib/init.lua b/lib/init.lua index 88f2bd9..b73c02a 100644 --- a/lib/init.lua +++ b/lib/init.lua @@ -211,19 +211,19 @@ function Promise.all(promises) error("Please pass a list of promises to Promise.all", 2) end + -- We need to check that each value is a promise here so that we can produce + -- a proper error rather than a rejected promise with our error. + for i, promise in pairs(promises) do + if not Promise.is(promise) then + error(("Non-promise value passed into Promise.all at index %s"):format(tostring(i)), 2) + end + end + -- If there are no values then return an already resolved promise. if #promises == 0 then return Promise.resolve({}) end - -- We need to check that each value is a promise here so that we can produce - -- a proper error rather than a rejected promise with our error. - for i = 1, #promises do - if not Promise.is(promises[i]) then - error(("Non-promise value passed into Promise.all at index #%d"):format(i), 2) - end - end - return Promise.new(function(resolve, reject) -- An array to contain our resolved values from the given promises. local resolvedValues = {} @@ -264,8 +264,8 @@ end function Promise.race(promises) assert(type(promises) == "table", "Please pass a list of promises to Promise.race") - for i, promise in ipairs(promises) do - assert(Promise.is(promise), ("Non-promise value passed into Promise.race at index #%d"):format(i)) + for i, promise in pairs(promises) do + assert(Promise.is(promise), ("Non-promise value passed into Promise.race at index %s"):format(tostring(i))) end return Promise.new(function(resolve, reject, onCancel) diff --git a/lib/init.spec.lua b/lib/init.spec.lua index 3fdeb4f..24a0108 100644 --- a/lib/init.spec.lua +++ b/lib/init.spec.lua @@ -550,6 +550,15 @@ return function() expect(first).to.equal("foo") expect(second).to.equal("bar") end) + + it("should error if a non-array table is passed in", function() + local ok, err = pcall(function() + Promise.all(Promise.new(function() end)) + end) + + expect(ok).to.be.ok() + expect(err:find("Non%-promise")).to.be.ok() + end) end) describe("Promise.race", function() @@ -581,6 +590,15 @@ return function() expect(promises[1]:getStatus()).to.equal(Promise.Status.Cancelled) expect(promises[2]:getStatus()).to.equal(Promise.Status.Resolved) end) + + it("should error if a non-array table is passed in", function() + local ok, err = pcall(function() + Promise.race(Promise.new(function() end)) + end) + + expect(ok).to.be.ok() + expect(err:find("Non%-promise")).to.be.ok() + end) end) describe("Promise.promisify", function()