diff --git a/src/Buffer/init.luau b/src/Buffer/init.luau index 425c421..2f6616a 100644 --- a/src/Buffer/init.luau +++ b/src/Buffer/init.luau @@ -5,7 +5,7 @@ export type Writer = { buf: buffer, cursor: number, capacity: number, - refs: {any}, + refs: { any }, } local DEFAULT_CAPACITY: number = 64 @@ -67,31 +67,50 @@ local T_COLSEQ = 0xF2 -- ColorSequence local T_NUMSEQ = 0xF3 -- NumberSequence local TYPED_READERS = { - [1] = function(b, o) return buffer.readu8(b, o), o + 1 end, - [2] = function(b, o) return buffer.readi8(b, o), o + 1 end, - [3] = function(b, o) return buffer.readu16(b, o), o + 2 end, - [4] = function(b, o) return buffer.readi16(b, o), o + 2 end, - [5] = function(b, o) return buffer.readu32(b, o), o + 4 end, - [6] = function(b, o) return buffer.readi32(b, o), o + 4 end, - [7] = function(b, o) return buffer.readf32(b, o), o + 4 end, - [8] = function(b, o) return buffer.readf64(b, o), o + 8 end, + [1] = function(b, o) + return buffer.readu8(b, o), o + 1 + end, + [2] = function(b, o) + return buffer.readi8(b, o), o + 1 + end, + [3] = function(b, o) + return buffer.readu16(b, o), o + 2 + end, + [4] = function(b, o) + return buffer.readi16(b, o), o + 2 + end, + [5] = function(b, o) + return buffer.readu32(b, o), o + 4 + end, + [6] = function(b, o) + return buffer.readi32(b, o), o + 4 + end, + [7] = function(b, o) + return buffer.readf32(b, o), o + 4 + end, + [8] = function(b, o) + return buffer.readf64(b, o), o + 8 + end, } local F16_MANTISSA_BITS = 1024 -local F16_DENORM do +local F16_DENORM +do F16_DENORM = table.create(1024) F16_DENORM[1] = 0 for m = 1, 1023 do F16_DENORM[m + 1] = math.ldexp(m / F16_MANTISSA_BITS, -14) end end -local F16_EXP2 do +local F16_EXP2 +do F16_EXP2 = table.create(31) for e = 1, 30 do F16_EXP2[e] = math.ldexp(1, e - 15) end end -local F16_NORM_MANTISSA do +local F16_NORM_MANTISSA +do F16_NORM_MANTISSA = table.create(1024) for m = 0, 1023 do F16_NORM_MANTISSA[m + 1] = 1 + m / F16_MANTISSA_BITS @@ -100,10 +119,18 @@ end local F32TOF16_SUBNORM_POW2 = { 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024 } local function f32ToF16(value: number): number - if value ~= value then return 0x7E00 end - if value == math.huge then return 0x7C00 end - if value == -math.huge then return 0xFC00 end - if value == 0 then return 0 end + if value ~= value then + return 0x7E00 + end + if value == math.huge then + return 0x7C00 + end + if value == -math.huge then + return 0xFC00 + end + if value == 0 then + return 0 + end local sign = 0 if value < 0 then @@ -115,7 +142,9 @@ local function f32ToF16(value: number): number exponent += 14 if exponent <= 0 then - if exponent < -10 then return sign end + if exponent < -10 then + return sign + end local scale = F32TOF16_SUBNORM_POW2[10 + exponent] or bit32.lshift(1, 10 + exponent) mantissa = math.floor(mantissa * scale + 0.5) return bit32.bor(sign, mantissa) @@ -127,7 +156,9 @@ local function f32ToF16(value: number): number if mantissa == F16_MANTISSA_BITS then mantissa = 0 exponent += 1 - if exponent >= 31 then return bit32.bor(sign, 0x7C00) end + if exponent >= 31 then + return bit32.bor(sign, 0x7C00) + end end return bit32.bor(sign, bit32.lshift(exponent, 10), mantissa) end @@ -141,7 +172,7 @@ local function f16ToF32(bits: number): number if exponent == 0 then value = mantissa == 0 and 0 or (F16_DENORM[mantissa + 1] or math.ldexp(mantissa / 1024, -14)) elseif exponent == 31 then - value = mantissa == 0 and math.huge or 0/0 + value = mantissa == 0 and math.huge or 0 / 0 else value = F16_NORM_MANTISSA[mantissa + 1] * F16_EXP2[exponent] end @@ -150,10 +181,18 @@ local function f16ToF32(bits: number): number end local function varUIntSize(value: number): number - if value < 0x80 then return 1 end - if value < 0x4000 then return 2 end - if value < 0x200000 then return 3 end - if value < 0x10000000 then return 4 end + if value < 0x80 then + return 1 + end + if value < 0x4000 then + return 2 + end + if value < 0x200000 then + return 3 + end + if value < 0x10000000 then + return 4 + end return 5 end @@ -191,7 +230,9 @@ end local function ensureSpace(w: Writer, bytes: number) local needed = w.cursor + bytes - if needed <= w.capacity then return end + if needed <= w.capacity then + return + end local newCap = w.capacity while newCap < needed do @@ -355,19 +396,29 @@ local function packString(w: Writer, s: string) end -- is the table a homogeneous number array -local function analyzeNumberArray(t: {any}): (boolean, string?, number) +local function analyzeNumberArray(t: { any }): (boolean, string?, number) local count = #t - if count == 0 then return false, nil, 0 end + if count == 0 then + return false, nil, 0 + end local minVal, maxVal = math.huge, -math.huge local allInt = true for i = 1, count do local v = t[i] - if type(v) ~= "number" then return false, nil, count end - if v ~= math.floor(v) then allInt = false end - if v < minVal then minVal = v end - if v > maxVal then maxVal = v end + if type(v) ~= "number" then + return false, nil, count + end + if v ~= math.floor(v) then + allInt = false + end + if v < minVal then + minVal = v + end + if v > maxVal then + maxVal = v + end end if not allInt then @@ -391,11 +442,17 @@ end local TYPED_CODES = { u8 = 1, i8 = 2, u16 = 3, i16 = 4, u32 = 5, i32 = 6, f32 = 7, f64 = 8 } local TYPED_SIZES = { u8 = 1, i8 = 1, u16 = 2, i16 = 2, u32 = 4, i32 = 4, f32 = 4, f64 = 8 } local TYPED_WRITERS = { - u8 = buffer.writeu8, i8 = buffer.writei8, u16 = buffer.writeu16, i16 = buffer.writei16, - u32 = buffer.writeu32, i32 = buffer.writei32, f32 = buffer.writef32, f64 = buffer.writef64, + u8 = buffer.writeu8, + i8 = buffer.writei8, + u16 = buffer.writeu16, + i16 = buffer.writei16, + u32 = buffer.writeu32, + i32 = buffer.writei32, + f32 = buffer.writef32, + f64 = buffer.writef64, } -local function packTable(w: Writer, t: {[any]: any}) +local function packTable(w: Writer, t: { [any]: any }) local count = 0 local maxIdx = 0 local isArray = true @@ -403,7 +460,9 @@ local function packTable(w: Writer, t: {[any]: any}) for k in t do count += 1 if type(k) == "number" and k > 0 and k == math.floor(k) then - if k > maxIdx then maxIdx = k end + if k > maxIdx then + maxIdx = k + end else isArray = false end @@ -651,9 +710,9 @@ local function readF16(b: buffer, o: number): number return f16ToF32(buffer.readu16(b, o)) end -local unpackValue: (buf: buffer, pos: number, refs: {any}?) -> (any, number) +local unpackValue: (buf: buffer, pos: number, refs: { any }?) -> (any, number) -unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) +unpackValue = function(buf: buffer, pos: number, refs: { any }?): (any, number) local t = buffer.readu8(buf, pos) pos += 1 @@ -667,18 +726,40 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) return t - 0x80 - 32, pos end - if t == T_NIL then return nil, pos end - if t == T_FALSE then return false, pos end - if t == T_TRUE then return true, pos end + if t == T_NIL then + return nil, pos + end + if t == T_FALSE then + return false, pos + end + if t == T_TRUE then + return true, pos + end - if t == T_U8 then return buffer.readu8(buf, pos), pos + 1 end - if t == T_U16 then return buffer.readu16(buf, pos), pos + 2 end - if t == T_U32 then return buffer.readu32(buf, pos), pos + 4 end - if t == T_I8 then return buffer.readi8(buf, pos), pos + 1 end - if t == T_I16 then return buffer.readi16(buf, pos), pos + 2 end - if t == T_I32 then return buffer.readi32(buf, pos), pos + 4 end - if t == T_F32 then return buffer.readf32(buf, pos), pos + 4 end - if t == T_F64 then return buffer.readf64(buf, pos), pos + 8 end + if t == T_U8 then + return buffer.readu8(buf, pos), pos + 1 + end + if t == T_U16 then + return buffer.readu16(buf, pos), pos + 2 + end + if t == T_U32 then + return buffer.readu32(buf, pos), pos + 4 + end + if t == T_I8 then + return buffer.readi8(buf, pos), pos + 1 + end + if t == T_I16 then + return buffer.readi16(buf, pos), pos + 2 + end + if t == T_I32 then + return buffer.readi32(buf, pos), pos + 4 + end + if t == T_F32 then + return buffer.readf32(buf, pos), pos + 4 + end + if t == T_F64 then + return buffer.readf64(buf, pos), pos + 8 + end -- inline str len (0-15) if t >= T_STR_BASE and t <= T_STR_BASE + 15 then @@ -695,7 +776,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) return buffer.readstring(buf, pos + 2, n), pos + 2 + n end if t == T_STRVAR then - local n; n, pos = readVarUInt(buf, pos) + local n + n, pos = readVarUInt(buf, pos) return buffer.readstring(buf, pos, n), pos + n end @@ -710,7 +792,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) end if t == T_ARR8 then - local count = buffer.readu8(buf, pos); pos += 1 + local count = buffer.readu8(buf, pos) + pos += 1 local arr = table.create(count) for i = 1, count do arr[i], pos = unpackValue(buf, pos, refs) @@ -718,7 +801,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) return arr, pos end if t == T_ARR16 then - local count = buffer.readu16(buf, pos); pos += 2 + local count = buffer.readu16(buf, pos) + pos += 2 local arr = table.create(count) for i = 1, count do arr[i], pos = unpackValue(buf, pos, refs) @@ -726,7 +810,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) return arr, pos end if t == T_ARRVAR then - local count; count, pos = readVarUInt(buf, pos) + local count + count, pos = readVarUInt(buf, pos) local arr = table.create(count) for i = 1, count do arr[i], pos = unpackValue(buf, pos, refs) @@ -748,7 +833,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) end if t == T_MAP8 then - local count = buffer.readu8(buf, pos); pos += 1 + local count = buffer.readu8(buf, pos) + pos += 1 local map = {} for _ = 1, count do local k, v @@ -759,7 +845,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) return map, pos end if t == T_MAP16 then - local count = buffer.readu16(buf, pos); pos += 2 + local count = buffer.readu16(buf, pos) + pos += 2 local map = {} for _ = 1, count do local k, v @@ -770,7 +857,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) return map, pos end if t == T_MAPVAR then - local count; count, pos = readVarUInt(buf, pos) + local count + count, pos = readVarUInt(buf, pos) local map = {} for _ = 1, count do local k, v @@ -783,8 +871,10 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) -- typed array if t == T_TYPED_ARR then - local subtype = buffer.readu8(buf, pos); pos += 1 - local count; count, pos = readVarUInt(buf, pos) + local subtype = buffer.readu8(buf, pos) + pos += 1 + local count + count, pos = readVarUInt(buf, pos) local reader = TYPED_READERS[subtype] local arr = table.create(count) for i = 1, count do @@ -858,12 +948,14 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) end if t == T_INSTANCE then - local idx; idx, pos = readVarUInt(buf, pos) + local idx + idx, pos = readVarUInt(buf, pos) return refs and refs[idx] or nil, pos end if t == T_ENUMITEM then - local nameLen; nameLen, pos = readVarUInt(buf, pos) + local nameLen + nameLen, pos = readVarUInt(buf, pos) local enumName = buffer.readstring(buf, pos, nameLen) pos += nameLen local val = buffer.readu16(buf, pos) @@ -871,7 +963,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) end if t == T_ENUM then - local nameLen; nameLen, pos = readVarUInt(buf, pos) + local nameLen + nameLen, pos = readVarUInt(buf, pos) local enumName = buffer.readstring(buf, pos, nameLen) return (Enum :: any)[enumName], pos + nameLen end @@ -892,9 +985,12 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) if t == T_RECT then return Rect.new( - buffer.readf32(buf, pos), buffer.readf32(buf, pos + 4), - buffer.readf32(buf, pos + 8), buffer.readf32(buf, pos + 12) - ), pos + 16 + buffer.readf32(buf, pos), + buffer.readf32(buf, pos + 4), + buffer.readf32(buf, pos + 8), + buffer.readf32(buf, pos + 12) + ), + pos + 16 end if t == T_NUMBERRANGE then @@ -905,11 +1001,13 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) return Ray.new( Vector3.new(buffer.readf32(buf, pos), buffer.readf32(buf, pos + 4), buffer.readf32(buf, pos + 8)), Vector3.new(buffer.readf32(buf, pos + 12), buffer.readf32(buf, pos + 16), buffer.readf32(buf, pos + 20)) - ), pos + 24 + ), + pos + 24 end if t == T_COLSEQ then - local count = buffer.readu8(buf, pos); pos += 1 + local count = buffer.readu8(buf, pos) + pos += 1 local keypoints = table.create(count) for i = 1, count do local time = buffer.readf32(buf, pos) @@ -923,7 +1021,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) end if t == T_NUMSEQ then - local count = buffer.readu8(buf, pos); pos += 1 + local count = buffer.readu8(buf, pos) + pos += 1 local keypoints = table.create(count) for i = 1, count do local time = buffer.readf32(buf, pos) @@ -936,7 +1035,8 @@ unpackValue = function(buf: buffer, pos: number, refs: {any}?): (any, number) end if t == T_BUFFER then - local len; len, pos = readVarUInt(buf, pos) + local len + len, pos = readVarUInt(buf, pos) local b = buffer.create(len) buffer.copy(b, 0, buf, pos, len) return b, pos + len @@ -951,7 +1051,7 @@ local function build(w: Writer): buffer return result end -local function buildWithRefs(w: Writer): (buffer, {any}?) +local function buildWithRefs(w: Writer): (buffer, { any }?) local result = buffer.create(w.cursor) buffer.copy(result, 0, w.buf, 0, w.cursor) return result, #w.refs > 0 and table.clone(w.refs) or nil @@ -994,60 +1094,118 @@ function Schema.map(key: SchemaType, value: SchemaType): SchemaType return { type = "map", key = key, value = value } end -function Schema.struct(fields: {[string]: SchemaType}): SchemaType +function Schema.struct(fields: { [string]: SchemaType }): SchemaType local orderedFields = {} for k, v in fields do table.insert(orderedFields, { key = k, schema = v }) end - table.sort(orderedFields, function(a, b) return a.key < b.key end) + table.sort(orderedFields, function(a, b) + return a.key < b.key + end) return { type = "struct", fields = orderedFields } end local function compilePacker(s: SchemaType): (Writer, any) -> () - if s.type == "u8" then return wByte end - if s.type == "i8" then return function(w, v) ensureSpace(w, 1) buffer.writei8(w.buf, w.cursor, v) w.cursor += 1 end end - if s.type == "u16" then return wU16 end - if s.type == "i16" then return wI16 end - if s.type == "u32" then return wU32 end - if s.type == "i32" then return wI32 end - if s.type == "f32" then return wF32 end - if s.type == "f64" then return wF64 end - if s.type == "f16" then return wF16 end - if s.type == "boolean" then return function(w, v) wByte(w, v and 1 or 0) end end - if s.type == "string" then return function(w, v) local len = #v wVarUInt(w, len) wString(w, v) end end - - if s.type == "vector3" then return function(w, v) wF16(w, f32ToF16(v.X)) wF16(w, f32ToF16(v.Y)) wF16(w, f32ToF16(v.Z)) end end - if s.type == "vector2" then return function(w, v) wF16(w, f32ToF16(v.X)) wF16(w, f32ToF16(v.Y)) end end - - if s.type == "cframe" then - return function(w, v) - local pos = v.Position - local rx, ry, rz = v:ToOrientation() - wF16(w, f32ToF16(pos.X)) wF16(w, f32ToF16(pos.Y)) wF16(w, f32ToF16(pos.Z)) - wF16(w, f32ToF16(rx)) wF16(w, f32ToF16(ry)) wF16(w, f32ToF16(rz)) - end + if s.type == "u8" then + return wByte + end + if s.type == "i8" then + return function(w, v) + ensureSpace(w, 1) + buffer.writei8(w.buf, w.cursor, v) + w.cursor += 1 + end + end + if s.type == "u16" then + return wU16 + end + if s.type == "i16" then + return wI16 + end + if s.type == "u32" then + return wU32 + end + if s.type == "i32" then + return wI32 + end + if s.type == "f32" then + return wF32 + end + if s.type == "f64" then + return wF64 + end + if s.type == "f16" then + return wF16 + end + if s.type == "boolean" then + return function(w, v) + wByte(w, v and 1 or 0) + end + end + if s.type == "string" then + return function(w, v) + local len = #v + wVarUInt(w, len) + wString(w, v) + end end - if s.type == "color3" then - return function(w, v) + if s.type == "vector3" then + return function(w, v) + wF16(w, f32ToF16(v.X)) + wF16(w, f32ToF16(v.Y)) + wF16(w, f32ToF16(v.Z)) + end + end + if s.type == "vector2" then + return function(w, v) + wF16(w, f32ToF16(v.X)) + wF16(w, f32ToF16(v.Y)) + end + end + + if s.type == "cframe" then + return function(w, v) + local pos = v.Position + local rx, ry, rz = v:ToOrientation() + wF16(w, f32ToF16(pos.X)) + wF16(w, f32ToF16(pos.Y)) + wF16(w, f32ToF16(pos.Z)) + wF16(w, f32ToF16(rx)) + wF16(w, f32ToF16(ry)) + wF16(w, f32ToF16(rz)) + end + end + + if s.type == "color3" then + return function(w, v) wByte(w, math.clamp(math.round(v.R * 255), 0, 255)) wByte(w, math.clamp(math.round(v.G * 255), 0, 255)) wByte(w, math.clamp(math.round(v.B * 255), 0, 255)) - end + end + end + + if s.type == "instance" then + return function(w, v) + table.insert(w.refs, v) + wVarUInt(w, #w.refs) + end end - if s.type == "instance" then return function(w, v) table.insert(w.refs, v) wVarUInt(w, #w.refs) end end - if s.type == "struct" then local fields = {} for _, field in s.fields do table.insert(fields, { key = field.key, packer = compilePacker(field.schema) }) end return function(w, v) - if type(v) ~= "table" then error(`Expected table for struct, got {typeof(v)}`) end + if type(v) ~= "table" then + error(`Expected table for struct, got {typeof(v)}`) + end for _, f in fields do local val = v[f.key] - if val == nil then error(`Schema Error: Missing required field '{f.key}'`) end + if val == nil then + error(`Schema Error: Missing required field '{f.key}'`) + end f.packer(w, val) end end @@ -1056,11 +1214,15 @@ local function compilePacker(s: SchemaType): (Writer, any) -> () if s.type == "array" then local itemPacker = compilePacker(s.item) return function(w, v) - if type(v) ~= "table" then error(`Expected table for array, got {typeof(v)}`) end + if type(v) ~= "table" then + error(`Expected table for array, got {typeof(v)}`) + end local len = #v wVarUInt(w, len) for i = 1, len do - if v[i] == nil then error(`Schema Error: Array item at index {i} is nil`) end + if v[i] == nil then + error(`Schema Error: Array item at index {i} is nil`) + end itemPacker(w, v[i]) end end @@ -1071,7 +1233,9 @@ local function compilePacker(s: SchemaType): (Writer, any) -> () local valPacker = compilePacker(s.value) return function(w, v) local count = 0 - for _ in v do count += 1 end + for _ in v do + count += 1 + end wVarUInt(w, count) for k, val in v do keyPacker(w, k) @@ -1095,22 +1259,63 @@ local function compilePacker(s: SchemaType): (Writer, any) -> () return function() end end -local function compileReader(s: SchemaType): (buffer, number, {any}?) -> (any, number) - if s.type == "u8" then return function(b, c) return buffer.readu8(b, c), c + 1 end end - if s.type == "i8" then return function(b, c) return buffer.readi8(b, c), c + 1 end end - if s.type == "u16" then return function(b, c) return buffer.readu16(b, c), c + 2 end end - if s.type == "i16" then return function(b, c) return buffer.readi16(b, c), c + 2 end end - if s.type == "u32" then return function(b, c) return buffer.readu32(b, c), c + 4 end end - if s.type == "i32" then return function(b, c) return buffer.readi32(b, c), c + 4 end end - if s.type == "f32" then return function(b, c) return buffer.readf32(b, c), c + 4 end end - if s.type == "f64" then return function(b, c) return buffer.readf64(b, c), c + 8 end end - if s.type == "f16" then return function(b, c) return f16ToF32(buffer.readu16(b, c)), c + 2 end end - if s.type == "boolean" then return function(b, c) return buffer.readu8(b, c) ~= 0, c + 1 end end - if s.type == "string" then - return function(b, c) - local len; len, c = readVarUInt(b, c) - return buffer.readstring(b, c, len), c + len - end +local function compileReader(s: SchemaType): (buffer, number, { any }?) -> (any, number) + if s.type == "u8" then + return function(b, c) + return buffer.readu8(b, c), c + 1 + end + end + if s.type == "i8" then + return function(b, c) + return buffer.readi8(b, c), c + 1 + end + end + if s.type == "u16" then + return function(b, c) + return buffer.readu16(b, c), c + 2 + end + end + if s.type == "i16" then + return function(b, c) + return buffer.readi16(b, c), c + 2 + end + end + if s.type == "u32" then + return function(b, c) + return buffer.readu32(b, c), c + 4 + end + end + if s.type == "i32" then + return function(b, c) + return buffer.readi32(b, c), c + 4 + end + end + if s.type == "f32" then + return function(b, c) + return buffer.readf32(b, c), c + 4 + end + end + if s.type == "f64" then + return function(b, c) + return buffer.readf64(b, c), c + 8 + end + end + if s.type == "f16" then + return function(b, c) + return f16ToF32(buffer.readu16(b, c)), c + 2 + end + end + if s.type == "boolean" then + return function(b, c) + return buffer.readu8(b, c) ~= 0, c + 1 + end + end + if s.type == "string" then + return function(b, c) + local len + len, c = readVarUInt(b, c) + return buffer.readstring(b, c, len), c + len + end end if s.type == "vector3" then return function(b, c) @@ -1120,7 +1325,7 @@ local function compileReader(s: SchemaType): (buffer, number, {any}?) -> (any, n return Vector3.new(x, y, z), c + 6 end end - + if s.type == "vector2" then return function(b, c) local x = f16ToF32(buffer.readu16(b, c)) @@ -1128,7 +1333,7 @@ local function compileReader(s: SchemaType): (buffer, number, {any}?) -> (any, n return Vector2.new(x, y), c + 4 end end - + if s.type == "color3" then return function(b, c) local r = buffer.readu8(b, c) @@ -1151,11 +1356,12 @@ local function compileReader(s: SchemaType): (buffer, number, {any}?) -> (any, n end if s.type == "instance" then return function(b, c, refs) - local idx; idx, c = readVarUInt(b, c) + local idx + idx, c = readVarUInt(b, c) return refs and refs[idx] or nil, c end end - + if s.type == "struct" then local fields = {} for _, field in s.fields do @@ -1173,7 +1379,8 @@ local function compileReader(s: SchemaType): (buffer, number, {any}?) -> (any, n if s.type == "array" then local itemReader = compileReader(s.item) return function(b, c, refs) - local len; len, c = readVarUInt(b, c) + local len + len, c = readVarUInt(b, c) local arr = table.create(len) for i = 1, len do arr[i], c = itemReader(b, c, refs) @@ -1186,7 +1393,8 @@ local function compileReader(s: SchemaType): (buffer, number, {any}?) -> (any, n local keyReader = compileReader(s.key) local valReader = compileReader(s.value) return function(b, c, refs) - local count; count, c = readVarUInt(b, c) + local count + count, c = readVarUInt(b, c) local map = {} for _ = 1, count do local k, val @@ -1210,8 +1418,10 @@ local function compileReader(s: SchemaType): (buffer, number, {any}?) -> (any, n end end end - - return function(_, c) return nil, c end + + return function(_, c) + return nil, c + end end local function packStrict(w: Writer, s: SchemaType, v: any) @@ -1224,7 +1434,7 @@ local function readStrict(buf: buffer, cursor: number, s: SchemaType, refs: { an return reader(buf, cursor, refs) end -local function writeEvents(w: Writer, events: {{any}}, schemas: {[string]: SchemaType}) +local function writeEvents(w: Writer, events: { { any } }, schemas: { [string]: SchemaType }) local count = #events wVarUInt(w, count) for _, event in events do @@ -1240,7 +1450,7 @@ local function writeEvents(w: Writer, events: {{any}}, schemas: {[string]: Schem end end -local function readEvents(buf: buffer, refs: {any}?, schemas: {[string]: SchemaType}): {{any}} +local function readEvents(buf: buffer, refs: { any }?, schemas: { [string]: SchemaType }): { { any } } local pos, count = 0 count, pos = readVarUInt(buf, pos) local events = table.create(count) @@ -1252,11 +1462,11 @@ local function readEvents(buf: buffer, refs: {any}?, schemas: {[string]: SchemaT if schema then local val val, pos = readStrict(buf, pos, schema, refs) - args = {val} + args = { val } else args, pos = unpackValue(buf, pos, refs) end - events[i] = {remote, args} + events[i] = { remote, args } end return events end @@ -1271,17 +1481,17 @@ BufferSerdes.compileReader = compileReader BufferSerdes.packStrict = packStrict BufferSerdes.readStrict = readStrict -function BufferSerdes.write(data: any): (buffer, {any}?) +function BufferSerdes.write(data: any): (buffer, { any }?) local w = createWriter() packValue(w, data) return buildWithRefs(w) end -function BufferSerdes.read(buf: buffer, refs: {any}?): any +function BufferSerdes.read(buf: buffer, refs: { any }?): any return (unpackValue(buf, 0, refs)) end -function BufferSerdes.readAll(buf: buffer, refs: {any}?): {any} +function BufferSerdes.readAll(buf: buffer, refs: { any }?): { any } local bufLen = buffer.len(buf) local pos = 0 local results = {} @@ -1309,4 +1519,4 @@ BufferSerdes.readTagged = unpackValue BufferSerdes.packTagged = packValue BufferSerdes.unpack = unpackValue -return BufferSerdes :: typeof(BufferSerdes) \ No newline at end of file +return BufferSerdes :: typeof(BufferSerdes) diff --git a/src/Client/init.luau b/src/Client/init.luau index 1ba99ec..0dc162c 100644 --- a/src/Client/init.luau +++ b/src/Client/init.luau @@ -7,7 +7,7 @@ local RunService = game:GetService("RunService") local Thread = require("./Thread") local Buffer = require("./Buffer") local Event: RemoteEvent = script.Parent:WaitForChild("Event") -local Function: RemoteFunction = script.Parent:WaitForChild("Function") +local UnreliableEvent: UnreliableRemoteEvent = script.Parent:WaitForChild("UnreliableEvent") local deltaT: number, cycle: number = 0, 1 / 61 local writer: Buffer.Writer = Buffer.createWriter() @@ -17,10 +17,11 @@ type Connection = { } type Event = { remote: string, - fn: (Player, ...any?) -> (...any?), + fn: (Player, ...any?) -> ...any?, } local queueEvent: { { any } } = {} +local queueUnreliableEvent: { { any } } = {} local eventListeners: { Event } = {} local eventSchemas: { [string]: Buffer.SchemaType } = {} @@ -31,7 +32,7 @@ Client.useSchema = function(remoteName: string, schema: Buffer.SchemaType) eventSchemas[remoteName] = schema end -Client.Connect = function(remoteName: string, fn: (Player, ...any?) -> (...any?)): Connection +Client.Connect = function(remoteName: string, fn: (Player, ...any?) -> ...any?): Connection local detail = { remote = remoteName, fn = fn, @@ -40,7 +41,9 @@ Client.Connect = function(remoteName: string, fn: (Player, ...any?) -> (...any?) return { Connected = true, Disconnect = function(self: Connection) - if not self.Connected then return end + if not self.Connected then + return + end self.Connected = false local idx = table.find(eventListeners, detail) if idx then @@ -64,7 +67,7 @@ end Client.Wait = function(remoteName: string): (number, ...any?) local thread, t = coroutine.running(), os.clock() Client.Once(remoteName, function(...: any?) - task.spawn(thread, os.clock()-t, ...) + task.spawn(thread, os.clock() - t, ...) end) return coroutine.yield() end @@ -81,10 +84,10 @@ Client.Destroy = function(remoteName: string) Client.DisconnectAll(remoteName) end -Client.Fire = function(remoteName: string, ...: any?) - table.insert(queueEvent, { +Client.Fire = function(remoteName: string, reliable: boolean, ...: any?) + table.insert(reliable and queueEvent or queueUnreliableEvent, { remoteName, - { ... } :: any + { ... } :: any, }) end @@ -101,72 +104,110 @@ Client.Invoke = function(remoteName: string, timeout: number?, ...: any?): ...an task.spawn(pending, nil) pendingInvokes[id] = nil end) - table.insert(queueEvent, { "\0", - { remoteName, id, { ... } :: any } :: any + { remoteName, id, { ... } :: any } :: any, }) return coroutine.yield() end if RunService:IsClient() then - Event.OnClientEvent:Connect(function(b: buffer, ref: { Instance? }) - if type(b) ~= "buffer" then return end + local function processIncoming(b: buffer, ref: { Instance }?, handleInvokes: boolean) + if type(b) ~= "buffer" then + return + end local contents = Buffer.readEvents(b, ref, eventSchemas) for _, content in contents do local remote = content[1] local content = content[2] - if remote == "\1" then - local id = content[1] - local results = content[2] - local pending = pendingInvokes[id] - if pending then - task.spawn(pending :: any, table.unpack(results)) - pendingInvokes[id] = nil - end - continue - end - if #eventListeners == 0 then continue end - if remote == "\0" then - local remoteName = content[1] - local id = content[2] - local args = content[3] - for _, connection in eventListeners do - if connection.remote == remoteName then - Thread(function() - local results = { connection.fn(table.unpack(args)) } - table.insert(queueEvent, { - "\1", - { id, results } :: any - }) - end) - break + if handleInvokes then + if remote == "\1" then + local id = content[1] + local results = content[2] + local pending = pendingInvokes[id] + if pending then + task.spawn(pending :: any, table.unpack(results)) + pendingInvokes[id] = nil end + continue end + if remote == "\0" then + if #eventListeners == 0 then + continue + end + local remoteName = content[1] + local id = content[2] + local args = content[3] + for _, connection in eventListeners do + if connection.remote == remoteName then + Thread(function() + local results = { connection.fn(table.unpack(args)) } + table.insert(queueEvent, { + "\1", + { id, results } :: any, + }) + end) + break + end + end + continue + end + end + if #eventListeners == 0 then continue end for _, connection in eventListeners do - if connection.remote ~= remote then continue end + if connection.remote ~= remote then + continue + end Thread(connection.fn, table.unpack(content)) end end + end + + Event.OnClientEvent:Connect(function(b: buffer, ref: { Instance }?) + processIncoming(b, ref, true) end) + + UnreliableEvent.OnClientEvent:Connect(function(b: buffer, ref: { Instance }?) + processIncoming(b, ref, false) + end) + RunService.PostSimulation:Connect(function(d: number) deltaT += d - if deltaT < cycle then return end - deltaT = 0 - if #queueEvent == 0 then return end - Buffer.writeEvents(writer, queueEvent, eventSchemas) - do - local buf, ref = Buffer.buildWithRefs(writer) - Buffer.reset(writer) - if not ref or #ref == 0 then - Event:FireServer(buf) - else - Event:FireServer(buf, ref) - end + if deltaT < cycle then + return + end + deltaT = 0 + + -- reliable + if #queueEvent > 0 then + Buffer.writeEvents(writer, queueEvent, eventSchemas) + do + local buf, ref = Buffer.buildWithRefs(writer) + Buffer.reset(writer) + if not ref or #ref == 0 then + Event:FireServer(buf) + else + Event:FireServer(buf, ref) + end + end + table.clear(queueEvent) + end + -- unreliable + if #queueUnreliableEvent > 0 then + Buffer.writeEvents(writer, queueUnreliableEvent, eventSchemas) + do + local buf, ref = Buffer.buildWithRefs(writer) + Buffer.reset(writer) + if not ref or #ref == 0 then + UnreliableEvent:FireServer(buf) + else + UnreliableEvent:FireServer(buf, ref) + end + end + table.clear(queueUnreliableEvent) end - table.clear(queueEvent) end) end diff --git a/src/Server/init.luau b/src/Server/init.luau index 2e20e97..defc024 100644 --- a/src/Server/init.luau +++ b/src/Server/init.luau @@ -8,7 +8,7 @@ local RunService = game:GetService("RunService") local Thread = require("./Thread") local Buffer = require("./Buffer") local Event: RemoteEvent = script.Parent:WaitForChild("Event") -local Function: RemoteFunction = script.Parent:WaitForChild("Function") +local UnreliableEvent: UnreliableRemoteEvent = script.Parent:WaitForChild("UnreliableEvent") local deltaT: number, cycle: number = 0, 1 / 61 local writer: Buffer.Writer = Buffer.createWriter() @@ -18,12 +18,15 @@ type Connection = { } type Event = { remote: string, - fn: (Player, ...any?) -> (...any?), + fn: (Player, ...any?) -> ...any?, } local queueEvent: { [Player]: { { any } }, } = {} +local queueUnreliableEvent: { + [Player]: { { any } }, +} = {} local eventListeners: { Event } = {} local eventSchemas: { [string]: Buffer.SchemaType } = {} local players_ready: { Player } = {} @@ -35,7 +38,7 @@ Server.useSchema = function(remoteName: string, schema: Buffer.SchemaType) eventSchemas[remoteName] = schema end -Server.Connect = function(remoteName: string, fn: (Player, ...any?) -> (...any?)): Connection +Server.Connect = function(remoteName: string, fn: (Player, ...any?) -> ...any?): Connection local detail = { remote = remoteName, fn = fn, @@ -44,7 +47,9 @@ Server.Connect = function(remoteName: string, fn: (Player, ...any?) -> (...any?) return { Connected = true, Disconnect = function(self: Connection) - if not self.Connected then return end + if not self.Connected then + return + end self.Connected = false local idx = table.find(eventListeners, detail) if idx then @@ -68,7 +73,7 @@ end Server.Wait = function(remoteName: string): (number, ...any?) local thread, t = coroutine.running(), os.clock() Server.Once(remoteName, function(...: any?) - task.spawn(thread, os.clock()-t, ...) + task.spawn(thread, os.clock() - t, ...) end) return coroutine.yield() end @@ -86,12 +91,13 @@ Server.Destroy = function(remoteName: string) end Server.Fire = function(remoteName: string, reliable: boolean, player: Player, ...: any?) - if not queueEvent[player] then - queueEvent[player] = {} :: any + local targetQueue = reliable and queueEvent or queueUnreliableEvent + if not targetQueue[player] then + targetQueue[player] = {} :: any end - table.insert(queueEvent[player], { + table.insert(targetQueue[player], { remoteName, - { ... } :: any + { ... } :: any, }) end @@ -114,62 +120,93 @@ Server.Invoke = function(remoteName: string, player: Player, timeout: number?, . task.spawn(pending, nil) pendingInvokes[id] = nil end) - + if not queueEvent[player] then + queueEvent[player] = {} :: any + end table.insert(queueEvent[player], { "\0", - { remoteName, id, { ... } :: any } :: any + { remoteName, id, { ... } :: any } :: any, }) return coroutine.yield() end if RunService:IsServer() then - Event.OnServerEvent:Connect(function(player: Player, b: buffer, ref: { Instance? }) - if type(b) ~= "buffer" then return end + local function processIncoming(player: Player, b: buffer, ref: { Instance }?, handleInvokes: boolean) + if type(b) ~= "buffer" then + return + end local contents = Buffer.readEvents(b, ref, eventSchemas) for _, content in contents do local remote = content[1] local content = content[2] - if remote == "\1" then - local id = content[1] - local results = content[2] - local pending = pendingInvokes[id] - if pending then - task.spawn(pending :: any, table.unpack(results)) - pendingInvokes[id] = nil - end - continue - end - if #eventListeners == 0 then continue end - if remote == "\0" then - local remoteName = content[1] - local id = content[2] - local args = content[3] - for _, connection in eventListeners do - if connection.remote == remoteName then - Thread(function() - local results = { connection.fn(table.unpack(args)) } - table.insert(queueEvent[player], { - "\1", - { id, results } :: any - }) - end) - break + if handleInvokes then + if remote == "\1" then + local id = content[1] + local results = content[2] + local pending = pendingInvokes[id] + if pending then + task.spawn(pending :: any, table.unpack(results)) + pendingInvokes[id] = nil end + continue end + if remote == "\0" then + if #eventListeners == 0 then + continue + end + local remoteName = content[1] + local id = content[2] + local args = content[3] + for _, connection in eventListeners do + if connection.remote == remoteName then + Thread(function() + local results = { connection.fn(table.unpack(args)) } + if not queueEvent[player] then + queueEvent[player] = {} :: any + end + table.insert(queueEvent[player], { + "\1", + { id, results } :: any, + }) + end) + break + end + end + continue + end + end + if #eventListeners == 0 then continue end for _, connection in eventListeners do - if connection.remote ~= remote then continue end + if connection.remote ~= remote then + continue + end Thread(connection.fn, player, table.unpack(content)) end end + end + + Event.OnServerEvent:Connect(function(player: Player, b: buffer, ref: { Instance }?) + processIncoming(player, b, ref, true) end) + + UnreliableEvent.OnServerEvent:Connect(function(player: Player, b: buffer, ref: { Instance }?) + processIncoming(player, b, ref, false) + end) + RunService.PostSimulation:Connect(function(d: number) deltaT += d - if deltaT < cycle then return end + if deltaT < cycle then + return + end deltaT = 0 + + -- reliable for player: Player, content in queueEvent do - if #content == 0 or player.Parent ~= Players then continue end + if #content == 0 or player.Parent ~= Players then + continue + end Buffer.writeEvents(writer, content, eventSchemas) do local buf, ref = Buffer.buildWithRefs(writer) @@ -182,7 +219,25 @@ if RunService:IsServer() then end table.clear(queueEvent[player]) end + -- unreliable + for player: Player, content in queueUnreliableEvent do + if #content == 0 or player.Parent ~= Players then + continue + end + Buffer.writeEvents(writer, content, eventSchemas) + do + local buf, ref = Buffer.buildWithRefs(writer) + Buffer.reset(writer) + if not ref or #ref == 0 then + UnreliableEvent:FireClient(player, buf) + else + UnreliableEvent:FireClient(player, buf, ref) + end + end + table.clear(queueUnreliableEvent[player]) + end end) + local function onAdded(player: Player) if not table.find(players_ready, player) then table.insert(players_ready, player) @@ -190,6 +245,9 @@ if RunService:IsServer() then if not queueEvent[player] then queueEvent[player] = {} :: any end + if not queueUnreliableEvent[player] then + queueUnreliableEvent[player] = {} :: any + end end Players.PlayerAdded:Connect(onAdded) Players.PlayerRemoving:Connect(function(player: Player) @@ -198,6 +256,10 @@ if RunService:IsServer() then table.clear(queueEvent[player]) queueEvent[player] = nil end + if queueUnreliableEvent[player] then + table.clear(queueUnreliableEvent[player]) + queueUnreliableEvent[player] = nil + end end) for _, player: Player in ipairs(Players:GetPlayers()) do onAdded(player) diff --git a/src/init.luau b/src/init.luau index 0746c12..fb57156 100644 --- a/src/init.luau +++ b/src/init.luau @@ -6,6 +6,9 @@ if game.RunService:IsServer() then if not script:FindFirstChild("Event") then Instance.new("RemoteEvent", script).Name = "Event" end + if not script:FindFirstChild("UnreliableEvent") then + Instance.new("UnreliableRemoteEvent", script).Name = "UnreliableEvent" + end if not script:FindFirstChild("Function") then Instance.new("RemoteFunction", script).Name = "Function" end