bcachefs: Add safe versions of varint encode/decode
authorKent Overstreet <kent.overstreet@gmail.com>
Tue, 13 Jul 2021 20:03:51 +0000 (16:03 -0400)
committerKent Overstreet <kent.overstreet@linux.dev>
Sun, 22 Oct 2023 21:09:08 +0000 (17:09 -0400)
This adds safe versions of bch2_varint_(encode|decode) that don't read
or write past the end of the buffer, or varint being encoded.

Signed-off-by: Kent Overstreet <kent.overstreet@gmail.com>
fs/bcachefs/alloc_background.c
fs/bcachefs/inode.c
fs/bcachefs/varint.c
fs/bcachefs/varint.h

index fc20649b19cf49c5f940259eb7010a47b36d0d4c..26aca7d3977b8988f89470ab4c4fb7341f7a9db7 100644 (file)
@@ -130,7 +130,7 @@ static int bch2_alloc_unpack_v2(struct bkey_alloc_unpacked *out,
 
 #define x(_name, _bits)                                                        \
        if (fieldnr < a.v->nr_fields) {                                 \
-               ret = bch2_varint_decode(in, end, &v);                  \
+               ret = bch2_varint_decode_fast(in, end, &v);             \
                if (ret < 0)                                            \
                        return ret;                                     \
                in += ret;                                              \
@@ -166,7 +166,7 @@ static void bch2_alloc_pack_v2(struct bkey_alloc_buf *dst,
        nr_fields++;                                                    \
                                                                        \
        if (src._name) {                                                \
-               out += bch2_varint_encode(out, src._name);              \
+               out += bch2_varint_encode_fast(out, src._name);         \
                                                                        \
                last_nonzero_field = out;                               \
                last_nonzero_fieldnr = nr_fields;                       \
index c5f93b8ca1c6fd8ec7233f9b27485312b52fa87d..565aebba30e6087e48497c75112532ca00dd5083 100644 (file)
@@ -137,7 +137,7 @@ static void bch2_inode_pack_v2(struct bkey_inode_buf *packed,
        nr_fields++;                                                    \
                                                                        \
        if (inode->_name) {                                             \
-               ret = bch2_varint_encode(out, inode->_name);            \
+               ret = bch2_varint_encode_fast(out, inode->_name);       \
                out += ret;                                             \
                                                                        \
                if (_bits > 64)                                         \
@@ -246,13 +246,13 @@ static int bch2_inode_unpack_v2(struct bkey_s_c_inode inode,
 
 #define x(_name, _bits)                                                        \
        if (fieldnr < INODE_NR_FIELDS(inode.v)) {                       \
-               ret = bch2_varint_decode(in, end, &v[0]);               \
+               ret = bch2_varint_decode_fast(in, end, &v[0]);          \
                if (ret < 0)                                            \
                        return ret;                                     \
                in += ret;                                              \
                                                                        \
                if (_bits > 64) {                                       \
-                       ret = bch2_varint_decode(in, end, &v[1]);       \
+                       ret = bch2_varint_decode_fast(in, end, &v[1]);  \
                        if (ret < 0)                                    \
                                return ret;                             \
                        in += ret;                                      \
index 0f3d06a6a685d1b11f396d70c65466106ec12920..6955ff5dc19ccfd6701bc93a05741716036dbaa4 100644 (file)
@@ -2,10 +2,18 @@
 
 #include <linux/bitops.h>
 #include <linux/math.h>
+#include <linux/string.h>
 #include <asm/unaligned.h>
 
 #include "varint.h"
 
+/**
+ * bch2_varint_encode - encode a variable length integer
+ * @out - destination to encode to
+ * @v  - unsigned integer to encode
+ *
+ * Returns the size in bytes of the encoded integer - at most 9 bytes
+ */
 int bch2_varint_encode(u8 *out, u64 v)
 {
        unsigned bits = fls64(v|1);
@@ -14,16 +22,79 @@ int bch2_varint_encode(u8 *out, u64 v)
        if (likely(bytes < 9)) {
                v <<= bytes;
                v |= ~(~0 << (bytes - 1));
+               v = cpu_to_le64(v);
+               memcpy(out, &v, bytes);
        } else {
                *out++ = 255;
                bytes = 9;
+               put_unaligned_le64(v, out);
        }
 
-       put_unaligned_le64(v, out);
        return bytes;
 }
 
+/**
+ * bch2_varint_decode - encode a variable length integer
+ * @in - varint to decode
+ * @end        - end of buffer to decode from
+ * @out        - on success, decoded integer
+ *
+ * Returns the size in bytes of the decoded integer - or -1 on failure (would
+ * have read past the end of the buffer)
+ */
 int bch2_varint_decode(const u8 *in, const u8 *end, u64 *out)
+{
+       unsigned bytes = likely(in < end)
+               ? ffz(*in & 255) + 1
+               : 1;
+       u64 v;
+
+       if (unlikely(in + bytes > end))
+               return -1;
+
+       if (likely(bytes < 9)) {
+               v = 0;
+               memcpy(&v, in, bytes);
+               v = le64_to_cpu(v);
+               v >>= bytes;
+       } else {
+               v = get_unaligned_le64(++in);
+       }
+
+       *out = v;
+       return bytes;
+}
+
+/**
+ * bch2_varint_encode_fast - fast version of bch2_varint_encode
+ *
+ * This version assumes it's always safe to write 8 bytes to @out, even if the
+ * encoded integer would be smaller.
+ */
+int bch2_varint_encode_fast(u8 *out, u64 v)
+{
+       unsigned bits = fls64(v|1);
+       unsigned bytes = DIV_ROUND_UP(bits, 7);
+
+       if (likely(bytes < 9)) {
+               v <<= bytes;
+               v |= ~(~0 << (bytes - 1));
+       } else {
+               *out++ = 255;
+               bytes = 9;
+       }
+
+       put_unaligned_le64(v, out);
+       return bytes;
+}
+
+/**
+ * bch2_varint_decode_fast - fast version of bch2_varint_decode
+ *
+ * This version assumes that it is safe to read at most 8 bytes past the end of
+ * @end (we still return an error if the varint extends past @end).
+ */
+int bch2_varint_decode_fast(const u8 *in, const u8 *end, u64 *out)
 {
        u64 v = get_unaligned_le64(in);
        unsigned bytes = ffz(v & 255) + 1;
index 8daf813576b7b7277052899acd1c0b3d43b0f8b7..92a182fb3d7aed9fdcda1600451ca996ca5620b7 100644 (file)
@@ -5,4 +5,7 @@
 int bch2_varint_encode(u8 *, u64);
 int bch2_varint_decode(const u8 *, const u8 *, u64 *);
 
+int bch2_varint_encode_fast(u8 *, u64);
+int bch2_varint_decode_fast(const u8 *, const u8 *, u64 *);
+
 #endif /* _BCACHEFS_VARINT_H */