diff --git a/.typos.toml b/.typos.toml
index 0b4a95179..7b8092737 100644
--- a/.typos.toml
+++ b/.typos.toml
@@ -4,7 +4,8 @@ extend-ignore-identifiers-re = [
"mmaped",
"arange",
"Nd",
- "nin"
+ "nin",
+ "cudaDevAttrMaxSharedMemoryPerBlockOptin"
]
[files]
diff --git a/Cargo.lock b/Cargo.lock
index 121b6c542..fcef40396 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -20,9 +20,9 @@ checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d"
[[package]]
name = "addr2line"
-version = "0.24.1"
+version = "0.24.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375"
+checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1"
dependencies = [
"gimli",
]
@@ -354,18 +354,18 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "bytemuck"
-version = "1.18.0"
+version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae"
+checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d"
dependencies = [
"bytemuck_derive",
]
[[package]]
name = "bytemuck_derive"
-version = "1.7.1"
+version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26"
+checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec"
dependencies = [
"proc-macro2",
"quote",
@@ -393,13 +393,14 @@ checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3"
[[package]]
name = "candle-core"
version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562"
dependencies = [
"accelerate-src",
"byteorder",
"candle-kernels",
"candle-metal-kernels",
"cudarc",
+ "float8",
"gemm",
"half",
"intel-mkl-src",
@@ -420,7 +421,7 @@ dependencies = [
[[package]]
name = "candle-flash-attn"
version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562"
dependencies = [
"anyhow",
"bindgen_cuda 0.1.5",
@@ -431,7 +432,7 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562"
dependencies = [
"bindgen_cuda 0.1.5",
]
@@ -439,7 +440,7 @@ dependencies = [
[[package]]
name = "candle-metal-kernels"
version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562"
dependencies = [
"metal",
"once_cell",
@@ -450,7 +451,7 @@ dependencies = [
[[package]]
name = "candle-nn"
version = "0.7.2"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=20a57c4#20a57c4bcf300e4bc6c1e48f7f3702668ae8cb80"
+source = "git+https://github.com/EricLBuehler/candle.git?rev=f2b6941#f2b6941a4856ae5d5b3413789f58d4ee3aa5a562"
dependencies = [
"accelerate-src",
"candle-core",
@@ -467,9 +468,9 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.1.24"
+version = "1.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "812acba72f0a070b003d3697490d2b55b837230ae7c6c6497f05cc2ddbb8d938"
+checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945"
dependencies = [
"shlex",
]
@@ -538,9 +539,9 @@ dependencies = [
[[package]]
name = "clap"
-version = "4.5.19"
+version = "4.5.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615"
+checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8"
dependencies = [
"clap_builder",
"clap_derive",
@@ -548,9 +549,9 @@ dependencies = [
[[package]]
name = "clap_builder"
-version = "4.5.19"
+version = "4.5.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b"
+checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54"
dependencies = [
"anstream",
"anstyle",
@@ -882,18 +883,18 @@ dependencies = [
[[package]]
name = "derive_builder"
-version = "0.20.1"
+version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b"
+checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
-version = "0.20.1"
+version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38"
+checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [
"darling 0.20.10",
"proc-macro2",
@@ -903,9 +904,9 @@ dependencies = [
[[package]]
name = "derive_builder_macro"
-version = "0.20.1"
+version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc"
+checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn 2.0.79",
@@ -1119,6 +1120,19 @@ dependencies = [
"miniz_oxide 0.8.0",
]
+[[package]]
+name = "float8"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7c3475274d374d263c4c40c43ad854c5bdf733c7db775bbd3c1ca2ad7427978"
+dependencies = [
+ "cudarc",
+ "half",
+ "num-traits",
+ "rand",
+ "rand_distr",
+]
+
[[package]]
name = "flume"
version = "0.11.0"
@@ -1187,9 +1201,9 @@ dependencies = [
[[package]]
name = "futures"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
+checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
@@ -1202,9 +1216,9 @@ dependencies = [
[[package]]
name = "futures-channel"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
+checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
@@ -1212,15 +1226,15 @@ dependencies = [
[[package]]
name = "futures-core"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
+checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d"
+checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
@@ -1229,15 +1243,15 @@ dependencies = [
[[package]]
name = "futures-io"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
+checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
+checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
@@ -1246,21 +1260,21 @@ dependencies = [
[[package]]
name = "futures-sink"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
+checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]]
name = "futures-task"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
+checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-util"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
+checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
@@ -1447,9 +1461,9 @@ dependencies = [
[[package]]
name = "gimli"
-version = "0.31.0"
+version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64"
+checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
@@ -1583,9 +1597,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
-version = "1.4.1"
+version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05"
+checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a"
dependencies = [
"bytes",
"futures-channel",
@@ -1794,9 +1808,9 @@ dependencies = [
[[package]]
name = "ipnet"
-version = "2.10.0"
+version = "2.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4"
+checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
[[package]]
name = "is_terminal_polyfill"
@@ -1845,9 +1859,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0"
[[package]]
name = "js-sys"
-version = "0.3.70"
+version = "0.3.72"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a"
+checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9"
dependencies = [
"wasm-bindgen",
]
@@ -2155,6 +2169,7 @@ dependencies = [
"derive_more",
"dirs",
"either",
+ "float8",
"futures",
"galil-seiferas",
"half",
@@ -2208,6 +2223,7 @@ dependencies = [
"anyhow",
"bindgen_cuda 0.1.6",
"candle-core",
+ "float8",
"half",
]
@@ -2243,8 +2259,10 @@ dependencies = [
"byteorder",
"candle-core",
"candle-nn",
+ "float8",
"half",
"lazy_static",
+ "once_cell",
"paste",
"rayon",
"serde",
@@ -2484,9 +2502,9 @@ dependencies = [
[[package]]
name = "object"
-version = "0.36.4"
+version = "0.36.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a"
+checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e"
dependencies = [
"memchr",
]
@@ -2535,12 +2553,9 @@ dependencies = [
[[package]]
name = "once_cell"
-version = "1.20.1"
+version = "1.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1"
-dependencies = [
- "portable-atomic",
-]
+checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "onig"
@@ -2815,9 +2830,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
-version = "1.0.86"
+version = "1.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
+checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a"
dependencies = [
"unicode-ident",
]
@@ -2836,9 +2851,9 @@ dependencies = [
[[package]]
name = "pyo3"
-version = "0.22.3"
+version = "0.22.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225"
+checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51"
dependencies = [
"anyhow",
"cfg-if",
@@ -2867,9 +2882,9 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
-version = "0.22.3"
+version = "0.22.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3"
+checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179"
dependencies = [
"once_cell",
"target-lexicon",
@@ -2877,9 +2892,9 @@ dependencies = [
[[package]]
name = "pyo3-ffi"
-version = "0.22.3"
+version = "0.22.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c"
+checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d"
dependencies = [
"libc",
"pyo3-build-config",
@@ -2887,9 +2902,9 @@ dependencies = [
[[package]]
name = "pyo3-macros"
-version = "0.22.3"
+version = "0.22.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28"
+checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
@@ -2899,9 +2914,9 @@ dependencies = [
[[package]]
name = "pyo3-macros-backend"
-version = "0.22.3"
+version = "0.22.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1"
+checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce"
dependencies = [
"heck",
"proc-macro2",
@@ -3295,9 +3310,9 @@ dependencies = [
[[package]]
name = "rustls"
-version = "0.23.13"
+version = "0.23.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8"
+checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8"
dependencies = [
"log",
"once_cell",
@@ -3319,9 +3334,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
-version = "1.9.0"
+version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55"
+checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b"
[[package]]
name = "rustls-webpki"
@@ -3336,9 +3351,9 @@ dependencies = [
[[package]]
name = "rustversion"
-version = "1.0.17"
+version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6"
+checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248"
[[package]]
name = "ryu"
@@ -3367,9 +3382,9 @@ dependencies = [
[[package]]
name = "schannel"
-version = "0.1.24"
+version = "0.1.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b"
+checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1"
dependencies = [
"windows-sys 0.59.0",
]
@@ -4215,9 +4230,9 @@ dependencies = [
[[package]]
name = "unicode-bidi"
-version = "0.3.15"
+version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75"
+checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893"
[[package]]
name = "unicode-ident"
@@ -4329,9 +4344,9 @@ dependencies = [
[[package]]
name = "utoipa-gen"
-version = "4.3.0"
+version = "4.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7bf0e16c02bc4bf5322ab65f10ab1149bdbcaa782cba66dc7057370a3f8190be"
+checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392"
dependencies = [
"proc-macro-error",
"proc-macro2",
@@ -4446,9 +4461,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
-version = "0.2.93"
+version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5"
+checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e"
dependencies = [
"cfg-if",
"once_cell",
@@ -4457,9 +4472,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-backend"
-version = "0.2.93"
+version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b"
+checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358"
dependencies = [
"bumpalo",
"log",
@@ -4472,9 +4487,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-futures"
-version = "0.4.43"
+version = "0.4.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed"
+checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b"
dependencies = [
"cfg-if",
"js-sys",
@@ -4484,9 +4499,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
-version = "0.2.93"
+version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf"
+checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
@@ -4494,9 +4509,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
-version = "0.2.93"
+version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836"
+checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68"
dependencies = [
"proc-macro2",
"quote",
@@ -4507,15 +4522,15 @@ dependencies = [
[[package]]
name = "wasm-bindgen-shared"
-version = "0.2.93"
+version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484"
+checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d"
[[package]]
name = "web-sys"
-version = "0.3.70"
+version = "0.3.72"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0"
+checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112"
dependencies = [
"js-sys",
"wasm-bindgen",
diff --git a/Cargo.toml b/Cargo.toml
index 109717ae2..4d4a2c746 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,8 +25,8 @@ license = "MIT"
[workspace.dependencies]
anyhow = "1.0.80"
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
-candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
+candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "f2b6941" }
+candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "f2b6941" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
@@ -37,7 +37,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
futures = "0.3"
clap = { version = "4.5.1", features = ["derive"] }
-pyo3 = { version = "0.22.0", features = ["full", "extension-module", "either"] }
+pyo3 = { version = "0.22.4", features = ["full", "extension-module", "either"] }
tokio = { version = "1.36.0", features = ["full", "rt-multi-thread"] }
once_cell = "1.19.0"
# All features but avif, avif increases the msrv dramatically
@@ -49,3 +49,4 @@ rayon = "1.1.0"
url = "2.5.2"
data-url = "0.3.1"
buildstructor = "0.5.4"
+float8 = "0.1.1"
diff --git a/README.md b/README.md
index 5d4cefc52..bbfbdfd30 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,9 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis
*After following installation instructions*
+- Check out UQFF for pre-quantized models covering a variety of quantization methods!
+ - Models can be found [here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c).
+
- 🦙📷 Run the **Llama 3.2 Vision** Model: [documentation and guide here](docs/VLLAMA.md)
@@ -70,7 +73,7 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis
- Other models: [see a support matrix](#support-matrix) and [how to run them](#run-with-the-cli)
-Mistal.rs supports several model categories:
+Mistral.rs supports several model categories:
- Text to Text
- Text+Image to Text: Vision (see [the docs](docs/VISION_MODELS.md))
- Text to Image: Image Generation (see [the docs](docs/IMAGEGEN_MODELS.md))
@@ -91,7 +94,7 @@ Mistal.rs supports several model categories:
**Quantization**:
- [Details](docs/QUANTS.md)
- GGML: 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit, with ISQ support.
-- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit
+- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit, with [Marlin](https://github.com/IST-DASLab/marlin) kernel support in 4-bit and 8-bit.
- HQQ: 4-bit and 8 bit, with ISQ support
**Powerful**:
@@ -106,7 +109,7 @@ Mistal.rs supports several model categories:
- [PagedAttention](docs/PAGED_ATTENTION.md) and continuous batching
- Prefix caching
- [Topology](docs/TOPOLOGY.md): Configure ISQ and device mapping easily
-- [UQFF](docs/UQFF.md): Quantized file format for easy mixing of quants, see some [models](docs/UQFF.md#list-of-models) which have already been converted.
+- [UQFF](docs/UQFF.md): Quantized file format for easy mixing of quants, [collection here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c).
- Speculative Decoding: Mix supported models as the draft model or the target model
- Dynamic LoRA adapter activation with adapter preloading: [examples and docs](docs/ADAPTER_MODELS.md#adapter-model-dynamic-adapter-activation)
@@ -202,7 +205,7 @@ Enabling features is done by passing `--features ...` to the build system. When
- Install the [Python package here](mistralrs-pyo3/README.md).
-1) Install required packages
+1) Install required packages:
- `OpenSSL` (*Example on Ubuntu:* `sudo apt install libssl-dev`)
- *Linux only:* `pkg-config` (*Example on Ubuntu:* `sudo apt install pkg-config`)
@@ -220,13 +223,13 @@ Enabling features is done by passing `--features ...` to the build system. When
huggingface-cli login
```
-4) Download the code
+4) Download the code:
```bash
git clone https://github.com/EricLBuehler/mistral.rs.git
cd mistral.rs
```
-5) Build or install
+5) Build or install:
- Base build command
```bash
cargo build --release
@@ -257,14 +260,14 @@ Enabling features is done by passing `--features ...` to the build system. When
```bash
cargo install --path mistralrs-server --features cuda
```
-6) The build process will output a binary `misralrs-server` at `./target/release/mistralrs-server` which may be copied into the working directory with the following command:
+6) The build process will output a binary `mistralrs-server` at `./target/release/mistralrs-server` which may be copied into the working directory with the following command:
*Example on Ubuntu:*
```
cp ./target/release/mistralrs-server ./mistralrs-server
```
-7) Use our APIs and integrations
+7) Use our APIs and integrations:
[APIs and integrations list](#apis-and-integrations)
@@ -377,8 +380,6 @@ please consider using the method demonstrated in examples below, where the token
Mistral.rs uses subcommands to control the model type. They are generally of format `-`. Please run `./mistralrs-server --help` to see the subcommands.
-Additionally, for models without quantization, the model architecture should be provided as the `--arch` or `-a` argument in contrast to GGUF models which encode the architecture in the file.
-
### Architecture for plain models
> Note: for plain models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dtype`/`-d` parameter after the model architecture (`plain`).
diff --git a/docs/ISQ.md b/docs/ISQ.md
index 76cff4fc0..bfaad1a04 100644
--- a/docs/ISQ.md
+++ b/docs/ISQ.md
@@ -21,6 +21,7 @@ To set the ISQ type for individual layers, use a model [`topology`](TOPOLOGY.md)
- Q8K (*not available on CUDA*)
- HQQ4
- HQQ8
+- FP8
When using ISQ, it will automatically load ISQ-able weights into CPU memory before applying ISQ. The ISQ application process moves the weights to device memory. This process is implemented to avoid memory spikes from loading the model in full precision.
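For reference, the new FP8 type is selected through the same `--isq` flag as the other types listed above; a minimal sketch of an interactive run, assuming the same CLI shape as the existing Q4K examples in these docs (the model ID is only a placeholder):

```
./mistralrs-server --isq FP8 -i plain -m microsoft/Phi-3.5-mini-instruct
```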
diff --git a/docs/QUANTS.md b/docs/QUANTS.md
index 7daa93a1c..6b37d35a0 100644
--- a/docs/QUANTS.md
+++ b/docs/QUANTS.md
@@ -12,6 +12,7 @@ Mistral.rs supports the following quantization:
- Supported in all plain and adapter models
- CUDA only
- 2, 3, 4, 8 bit
+ - [Marlin](https://github.com/IST-DASLab/marlin) kernel support in 4-bit and 8-bit.
- HQQ
- Supported in all plain and adapter models via ISQ
- CUDA and CPU only
@@ -41,6 +42,7 @@ cargo run --features cuda -- -i --isq Q4K plain -m microsoft/Phi-3-mini-4k-instr
- Use the `plain` (cli) / `Plain` (Python) model selector
- Provide the model ID for the GPTQ model
- Mistral.rs will automatically detect and use GPTQ quantization.
+- The [Marlin](https://github.com/IST-DASLab/marlin) kernel will automatically be used in 4-bit and 8-bit.
```
cargo run --features cuda -- -i plain -m kaitchup/Phi-3-mini-4k-instruct-gptq-4bit -a phi3
diff --git a/docs/UQFF.md b/docs/UQFF.md
index 7dfa4a30b..6a9686ac3 100644
--- a/docs/UQFF.md
+++ b/docs/UQFF.md
@@ -51,24 +51,31 @@ The following quantization formats are supported in UQFF. One can, of course, be
- HQQ4
- HQQ8
+- FP8:
+ - FP8 E4M3 (4-bit exponent, 3-bit mantissa)
+
## Loading a UQFF model
-To load a UQFF model, one should specify the artifact path. This can be either be a path to a UQFF file locally, or a Hugging Face model ID with the format `/`. For example, the following work:
+To load a UQFF model, one should specify the UQFF filename. The file is located via the model ID, and can
+be loaded either locally or from the Hugging Face Hub.
-- `EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff`
+- `phi3.5-mini-instruct-q4k.uqff`
- `../UQFF/phi3.5-mini-instruct-q4k.uqff`
-> Note: when loading an UQFF model, it will take precedence over any ISQ setting.
+You can find a [collection of UQFF models here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c), each of which includes a simple
+command to get started.
+
+> Note: when loading a UQFF model, *any* ISQ setting will be ignored.
### Running with the CLI
```
-cargo run --features cuda -- -i plain -m microsoft/Phi-3.5-mini-instruct --from-uqff EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff
+./mistralrs-server -i plain -m EricB/Phi-3.5-mini-instruct-UQFF --from-uqff phi3.5-mini-instruct-f8e4m3.uqff
```
### Using with the Rust API
-Modify the Normal or Vision config as follows:
+Modify the Normal or Vision config as follows and update the model ID to point to a UQFF model:
```diff
NormalSpecificConfig {
@@ -78,7 +85,7 @@ NormalSpecificConfig {
organization: Default::default(),
write_uqff: None,
- from_uqff: None,
-+ from_uqff: Some("EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff".to_string()),
++ from_uqff: Some("phi3.5-mini-instruct-q4k.uqff".to_string()), // Pull from specified HF hub repo
}
```
@@ -89,7 +96,7 @@ VisionSpecificConfig {
topology: None,
write_uqff: None,
- from_uqff: None,
-+ from_uqff: Some("../UQFF/phi3.5-mini-instruct-q4k.uqff".to_string()),
++ from_uqff: Some("../phi3.5-mini-instruct-q4k.uqff".to_string()), // Local path
}
```
@@ -97,8 +104,8 @@ VisionSpecificConfig {
Modify the `Which` instantiation as follows:
```diff
Which.Plain(
- model_id="microsoft/Phi-3.5-mini-instruct",
-+ from_uqff="EricB/Phi-3.5-mini-instruct-ISQ/phi3.5-mini-instruct-q4k.uqff"
+ model_id="EricB/Phi-3.5-mini-instruct-UQFF",
++ from_uqff="phi3.5-mini-instruct-q4k.uqff"
),
```
@@ -109,6 +116,11 @@ Creating a UQFF model requires you to generate the UQFF file.
- This means specifying a local path to a file ending in `.uqff`, where your new UQFF model will be created.
- The quantization of a UQFF model is determined from the ISQ or model topology (see the [topology docs](TOPOLOGY.md) for more details on how ISQ and the topology mix).
+Along with the UQFF file, the generation process will also output several `.json` configuration files and `residual.safetensors`. All of these files together constitute the
+UQFF model, so keep them together and upload them as a set.
+
+> Note: Only the `.uqff` files are unique to the quantization level(s). If you generate multiple UQFF files for the same model, it is fine for the accompanying `.json` and `residual.safetensors` files to be overwritten.
+
After creating the UQFF file, you can upload the model to Hugging Face. To do this:
1) [Create a new model](https://huggingface.co/docs/transformers/v4.17.0/en/create_a_model).
2) Upload the UQFF file:
@@ -120,7 +132,7 @@ After creating the UQFF file, you can upload the model to Hugging Face. To do th
### Creating with the CLI
```
-cargo run --features cuda -- --isq Q4K -i plain -m microsoft/Phi-3.5-mini-instruct --write-uqff phi3.5-mini-instruct-q4k.uqff
+./mistralrs-server --isq Q4K -i plain -m microsoft/Phi-3.5-mini-instruct --write-uqff phi3.5-mini-instruct-q4k.uqff
```
### Creating with the Rust API
@@ -151,7 +163,7 @@ VisionSpecificConfig {
```
### Creating with the Python API
-Modify the `Which` instantiation as follows:
+Modify the `Which` instantiation as follows. Be sure to also set `in_situ_quant`.
```diff
Which.Plain(
model_id="microsoft/Phi-3.5-mini-instruct",
@@ -170,10 +182,6 @@ After this, you can use Git to track, commit, and push files.
## List of models
-Have you created a UQFF model on Hugging Face? If so, please [create an issue](https://github.com/EricLBuehler/mistral.rs/issues/new) and we will include it here!
+You can find a list of models in the [Hugging Face model collection](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c).
-| Name | Base model | UQFF model |
-| -- | -- | -- |
-| Phi 3.5 Mini Instruct | microsoft/Phi-3.5-mini-instruct | [EricB/Phi-3.5-mini-instruct-UQFF](EricB/Phi-3.5-mini-instruct-UQFF) |
-| Llama 3.2 Vision | meta-llama/Llama-3.2-11B-Vision-Instruct | [EricB/Llama-3.2-11B-Vision-Instruct-UQFF](https://huggingface.co/EricB/Llama-3.2-11B-Vision-Instruct-UQFF) |
-| Mistral Nemo 2407 | mistralai/Mistral-Nemo-Instruct-2407 | [EricB/Mistral-Nemo-Instruct-2407-UQFF](https://huggingface.co/EricB/Mistral-Nemo-Instruct-2407-UQFF) |
+Have you created a UQFF model on Hugging Face? If so, please [create an issue](https://github.com/EricLBuehler/mistral.rs/issues/new).
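Since the generated `.json` configs and `residual.safetensors` are part of the UQFF model alongside the `.uqff` file(s), here is a minimal sketch of publishing them together using the Git-based flow the doc already describes (file names are the ones from the examples above):

```
git add phi3.5-mini-instruct-q4k.uqff residual.safetensors *.json
git commit -m "Add UQFF model files"
git push
```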
diff --git a/docs/UQFF/LAYOUT.md b/docs/UQFF/LAYOUT.md
index ceabd7c88..c0b29e65d 100644
--- a/docs/UQFF/LAYOUT.md
+++ b/docs/UQFF/LAYOUT.md
@@ -6,6 +6,7 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0
- [GGUF quantization](#gguf-quantization)
- [HQQ quantization](#hqq-quantization)
- [Unquantized layers](#unquantized-layers)
+- [FP8 layers](#fp8-layers)
- [Standard tensors](#standard-tensors)
@@ -32,6 +33,18 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0
| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
+## FP8 layers
+| ID | Element type | Endianness |
+| -------- | -------- | -------- |
+| HQFF version | u32 | little endian |
+| ISQ type (1) | u8 | little endian |
+| Whether bias data is included (boolean) | u8 | little endian |
+| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
+| Dequant W scalar | f32 | little endian |
+| Dequant X scalar | f32 | little endian |
+| Quant scalar | f32 | little endian |
+| Quantization type | u32 | little endian |
+| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
## HQQ quantization
| ID | Element type | Endianness |
@@ -51,6 +64,19 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0
| CFG round zeroes (boolean) | u8 | little endian |
| CFG channel wise (boolean) | u8 | little endian |
+## FP8 layers
+| ID | Element type | Endianness |
+| -------- | -------- | -------- |
+| HQFF version | u32 | little endian |
+| ISQ type (3) | u8 | little endian |
+| Whether bias data is included (boolean) | u8 | little endian |
+| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
+| Dequant scale W | f32 | little endian |
+| Dequant scale X | f32 | little endian |
+| Quant scale | f32 | little endian |
+| Layer dtype | u32 | little endian |
+| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
+
## Standard tensors
| ID | Element type | Endianness |
| -------- | -------- | -------- |
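To make the FP8 layer layout tables above concrete, here is an illustrative Rust sketch of the record in field order; the struct and type names are invented for this note and are not the actual mistralrs-quant (de)serialization types:

```rust
/// Illustrative only: field order follows the FP8 layer tables above.
/// All multi-byte integers and floats are stored little endian.
struct Fp8LayerRecord {
    hqff_version: u32,          // HQFF version
    isq_type: u8,               // ISQ type tag for FP8
    has_bias: u8,               // non-zero if a bias tensor is appended
    weight: TensorRecord,       // weight tensor data (see "Standard tensors")
    dequant_scale_w: f32,       // dequantization scale for W
    dequant_scale_x: f32,       // dequantization scale for X
    quant_scale: f32,           // quantization scale
    layer_dtype: u32,           // original layer dtype tag
    bias: Option<TensorRecord>, // present only when has_bias != 0
}

/// Placeholder for the standard tensor encoding documented in this file.
struct TensorRecord;
```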
diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml
index dbcfab697..3a6e8a060 100644
--- a/mistralrs-core/Cargo.toml
+++ b/mistralrs-core/Cargo.toml
@@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
-candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", optional = true }
+candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "f2b6941", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
@@ -78,10 +78,11 @@ regex = "1.10.6"
safetensors = "0.4.5"
serde_plain = "1.0.2"
as-any = "0.3.1"
+float8.workspace = true
[features]
pyo3_macros = ["pyo3"]
-cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"]
+cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda", "float8/cuda"]
cudnn = ["candle-core/cudnn"]
metal = ["candle-core/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
diff --git a/mistralrs-core/src/common_models/t5/mod.rs b/mistralrs-core/src/common_models/t5/mod.rs
index dc4918ee4..b51592725 100644
--- a/mistralrs-core/src/common_models/t5/mod.rs
+++ b/mistralrs-core/src/common_models/t5/mod.rs
@@ -5,6 +5,7 @@
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, VarBuilder};
+use float8::F8E4M3;
use serde::Deserialize;
use std::sync::Arc;
@@ -596,6 +597,7 @@ impl TensorInfExtend for Tensor {
DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
+            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
}
}
}
@@ -611,6 +613,7 @@ fn clamp_for_f16(xs: &Tensor) -> Result {
DType::BF16 => half::bf16::MAX.to_f64_const() - 1000.,
DType::F32 => f32::MAX as f64 - 1000.,
DType::F64 => f64::MAX - 1000.,
+ DType::F8E4M3 => F8E4M3::MAX.to_f64() - 1000.,
};
if xs.is_inf()?.any()? {
max -= 1000.;
diff --git a/mistralrs-core/src/cublaslt/api.rs b/mistralrs-core/src/cublaslt/api.rs
index 24aca6ba2..8bb11d028 100644
--- a/mistralrs-core/src/cublaslt/api.rs
+++ b/mistralrs-core/src/cublaslt/api.rs
@@ -1,13 +1,14 @@
-pub use candle_core::cuda_backend::cudarc::cublaslt::Activation;
+use candle_core::cuda::cudarc::driver::DevicePtr;
+use float8::F8E4M3;
use std::ffi::c_int;
use candle_core::backend::BackendStorage;
use candle_core::cuda_backend::WrapErr;
-use candle_core::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor};
+use candle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor};
use half::{bf16, f16};
use std::sync::Arc;
-use candle_core::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig};
+use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig};
#[derive(Debug, Clone)]
pub struct CublasLt(Arc<CudaBlasLT>);
@@ -858,11 +859,12 @@ pub fn fused_batch_matmul(
a.apply_op2(b, op)
}
}
-
#[cfg(test)]
mod tests {
+ use std::f32::consts::PI;
+
use super::*;
- use candle_core::{DType, Device};
+ use candle_core::{DType, Device, IndexOp};
fn to_vec2_round(t: Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
let b = 10f32.powi(digits);
diff --git a/mistralrs-core/src/cublaslt/matmul.rs b/mistralrs-core/src/cublaslt/matmul.rs
new file mode 100644
index 000000000..898a30522
--- /dev/null
+++ b/mistralrs-core/src/cublaslt/matmul.rs
@@ -0,0 +1,453 @@
+use candle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute;
+use candle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys};
+use candle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream};
+use candle_core::cuda::cudarc::driver::{
+ CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError,
+};
+use core::ffi::c_int;
+use core::mem;
+use float8::F8E4M3;
+use half::bf16;
+use std::sync::Arc;
+
+/// Wrapper around [sys::cublasLtHandle_t]
+///
+/// 1. Create with [CudaBlasLT::new()]
+/// 2. Execute the matmul kernel with [Matmul::matmul()]. f32 is supported; f16 and bf16 are supported
+/// if feature `half` is activated
+///
+/// Note: This maintains an instance of [`Arc<CudaDevice>`], so it will prevent the device
+/// from being dropped. Kernels will be launched on the device's default stream.
+#[derive(Debug)]
+pub struct CudaBlasLT {
+ handle: sys::cublasLtHandle_t,
+ workspace: Workspace,
+ device: Arc<CudaDevice>,
+}
+
+unsafe impl Send for CudaBlasLT {}
+
+unsafe impl Sync for CudaBlasLT {}
+
+impl CudaBlasLT {
+ /// Creates a new cublasLt handle.
+ pub fn new(device: Arc<CudaDevice>) -> Result<Self, CublasError> {
+ let handle = result::create_handle()?;
+ let workspace = Workspace::new(device.clone()).unwrap();
+
+ Ok(Self {
+ handle,
+ workspace,
+ device,
+ })
+ }
+}
+
+impl Drop for CudaBlasLT {
+ fn drop(&mut self) {
+ let handle = mem::replace(&mut self.handle, std::ptr::null_mut());
+ if !handle.is_null() {
+ unsafe { result::destroy_handle(handle) }.unwrap();
+ }
+ }
+}
+
+/// User owned CublasLt workspace buffer.
+/// The workspace is initialised following the Nvidia recommendations:
+///
+/// 1. NVIDIA Hopper Architecture: 32 MiB
+/// 2. Other: 4 MiB
+#[derive(Debug, Clone)]
+pub struct Workspace {
+ pub(crate) buffer: CudaSlice<u8>,
+ pub(crate) size: usize,
+}
+
+impl Workspace {
+ /// Creates a CublasLt workspace buffer on the provided device
+ pub fn new(device: Arc<CudaDevice>) -> Result<Self, DriverError> {
+ device.bind_to_thread()?;
+
+ let major =
+ device.attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
+ let workspace_size = if major >= 9 { 33_554_432 } else { 4_194_304 };
+
+ let buffer = unsafe { device.alloc::<u8>(workspace_size)? };
+ Ok(Self {
+ buffer,
+ size: workspace_size,
+ })
+ }
+}
+
+/// Available activation for kernel fusing in matmul
+#[derive(Debug, Clone)]
+pub enum Activation {
+ Relu,
+ Gelu,
+}
+
+/// MatrixLayout helper type
+struct MatrixLayout {
+ handle: sys::cublasLtMatrixLayout_t,
+}
+
+impl MatrixLayout {
+ fn new(
+ matrix_type: sys::cudaDataType,
+ rows: u64,
+ cols: u64,
+ ld: i64,
+ ) -> Result<Self, CublasError> {
+ let handle = result::create_matrix_layout(matrix_type, rows, cols, ld)?;
+ Ok(Self { handle })
+ }
+
+ fn set_batch(&self, size: c_int, stride: i64) -> Result<(), CublasError> {
+ unsafe {
+ // Set batch size
+ set_matrix_layout_attribute(
+ self.handle,
+ sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+ (&size) as *const _ as *const _,
+ mem::size_of::<c_int>(),
+ )?;
+ // Set batch stride
+ set_matrix_layout_attribute(
+ self.handle,
+ sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+ (&stride) as *const _ as *const _,
+ mem::size_of::<i64>(),
+ )?;
+ }
+ Ok(())
+ }
+}
+
+impl Drop for MatrixLayout {
+ fn drop(&mut self) {
+ // panic on failure
+ unsafe {
+ result::destroy_matrix_layout(self.handle).expect("Unable to destroy matrix layout")
+ }
+ }
+}
+
+enum Matrix {
+ A,
+ B,
+ C,
+ D,
+}
+
+/// MatmulDesc helper type
+struct MatmulDesc {
+ handle: sys::cublasLtMatmulDesc_t,
+}
+
+impl MatmulDesc {
+ fn new(
+ compute_type: sys::cublasComputeType_t,
+ scale_type: sys::cudaDataType,
+ ) -> Result<Self, CublasError> {
+ let handle = result::create_matmul_desc(compute_type, scale_type)?;
+ Ok(Self { handle })
+ }
+
+ fn set_transpose(&self, transpose: bool, matrix: Matrix) -> Result<(), CublasError> {
+ // Set transpose
+ // 1 == T, 0 == N
+ let transpose = transpose as i32;
+ let attr = match matrix {
+ Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
+ Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
+ Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSC,
+ Matrix::D => unreachable!(),
+ };
+
+ unsafe {
+ result::set_matmul_desc_attribute(
+ self.handle,
+ attr,
+ (&transpose) as *const _ as *const _,
+ mem::size_of::<i32>(),
+ )?;
+ }
+ Ok(())
+ }
+
+ // Epilogue system can be leveraged to fuse add and activation operations
+ fn set_epilogue(
+ &self,
+ act: Option<&Activation>,
+ bias_ptr: Option<&CUdeviceptr>,
+ stride_bias: Option<i64>,
+ ) -> Result<(), CublasError> {
+ let epilogue = if let Some(bias_ptr) = bias_ptr {
+ let epilogue = act
+ .map(|act| match act {
+ // Act + bias
+ Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
+ Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
+ })
+ // Only bias
+ .unwrap_or(sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS);
+
+ // Set bias CUdeviceptr in matmul_desc
+ unsafe {
+ result::set_matmul_desc_attribute(
+ self.handle,
+ sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+ bias_ptr as *const CUdeviceptr as *const _,
+ mem::size_of::<CUdeviceptr>(),
+ )?;
+ }
+
+ if let Some(stride_bias) = stride_bias {
+ // Set bias batch stride
+ unsafe {
+ result::set_matmul_desc_attribute(
+ self.handle,
+ sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE,
+ (&stride_bias) as *const _ as *const _,
+ mem::size_of::<i64>(),
+ )?;
+ }
+ }
+ epilogue
+ } else if let Some(act) = act {
+ // Only Act
+ match act {
+ Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
+ Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
+ }
+ } else {
+ // No epilogue
+ sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT
+ };
+
+ // Set epilogue
+ unsafe {
+ result::set_matmul_desc_attribute(
+ self.handle,
+ sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE,
+ (&epilogue) as *const _ as *const _,
+ mem::size_of::<sys::cublasLtEpilogue_t>(),
+ )?;
+ }
+ Ok(())
+ }
+}
+
+impl Drop for MatmulDesc {
+ fn drop(&mut self) {
+ unsafe { result::destroy_matmul_desc(self.handle).expect("Unable to destroy matmul desc") }
+ }
+}
+
+/// MatmulPref helper type
+struct MatmulPref {
+ handle: sys::cublasLtMatmulPreference_t,
+}
+
+impl MatmulPref {
+ fn new() -> Result<Self, CublasError> {
+ let handle = result::create_matmul_pref()?;
+ Ok(Self { handle })
+ }
+
+ fn set_workspace_size(&self, size: usize) -> Result<(), CublasError> {
+ unsafe {
+ // Set workspace size
+ result::set_matmul_pref_attribute(
+ self.handle,
+ sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+ (&size) as *const _ as *const _,
+ mem::size_of::<usize>(),
+ )?;
+ }
+ Ok(())
+ }
+}
+
+impl Drop for MatmulPref {
+ fn drop(&mut self) {
+ unsafe { result::destroy_matmul_pref(self.handle).expect("Unable to destroy matmul pref") }
+ }
+}
+
+/// [Matmul] super-trait
+pub trait MatmulShared {
+ /// Returns a reference to the underlying cublasLt handle.
+ fn handle(&self) -> &sys::cublasLtHandle_t;
+
+ /// Returns a reference to the underlying cublasLt workspace
+ fn workspace(&self) -> &Workspace;
+
+ /// Returns a reference to the underlying stream
+ fn stream(&self) -> &CUstream;
+}
+
+/// Configuration for [Matmul]
+#[derive(Debug, Copy, Clone)]
+pub struct MatmulConfig {
+ pub transa: bool,
+ pub transb: bool,
+ pub m: u64,
+ pub n: u64,
+ pub k: u64,
+ pub alpha: f32,
+ pub lda: i64,
+ pub ldb: i64,
+ pub beta: f32,
+ pub ldc: i64,
+ pub stride_a: Option<i64>,
+ pub stride_b: Option<i64>,
+ pub stride_c: Option<i64>,
+ pub stride_bias: Option<i64>,
+ pub batch_size: Option<c_int>,
+}
+
+/// Matrix matrix multiplication with elements of type `T`.
+pub trait Matmul<T>: MatmulShared {
+ /// Underlying CUDA Type for `T`
+ fn matrix_type() -> sys::cudaDataType;
+
+ /// Underlying CUDA Compute Type for `T`
+ fn compute_type() -> sys::cublasComputeType_t;
+
+ /// Matrix matrix multiplication. See
+ /// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul)
+ ///
+ /// # Safety
+ /// This is unsafe because improper arguments may lead to invalid
+ /// memory accesses.
+ unsafe fn matmul<I: DevicePtr<T>, O: DevicePtrMut<T>>(
+ &self,
+ cfg: MatmulConfig,
+ a: &I,
+ b: &I,
+ c: &mut O,
+ bias: Option<&I>,
+ act: Option<&Activation>,
+ ) -> Result<(), CublasError> {
+ let (a_rows, a_cols) = if cfg.transa {
+ (cfg.k, cfg.m)
+ } else {
+ (cfg.m, cfg.k)
+ };
+ let (b_rows, b_cols) = if cfg.transb {
+ (cfg.n, cfg.k)
+ } else {
+ (cfg.k, cfg.n)
+ };
+
+ // Creates matrix layouts
+ let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
+ if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
+ a_layout.set_batch(batch_size, stride_a)?;
+ }
+
+ let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
+ if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
+ b_layout.set_batch(batch_size, stride_b)?;
+ }
+
+ let c_layout = MatrixLayout::new(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
+ if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
+ c_layout.set_batch(batch_size, stride_c)?;
+ }
+
+ // Matmul description
+ let matmul_desc = MatmulDesc::new(Self::compute_type(), sys::cudaDataType_t::CUDA_R_32F)?;
+
+ // Set transa
+ matmul_desc.set_transpose(cfg.transa, Matrix::A)?;
+ // Set transb
+ matmul_desc.set_transpose(cfg.transb, Matrix::B)?;
+
+ // Epilogue system can be leveraged to fuse add and activation operations
+ matmul_desc.set_epilogue(act, bias.map(|b| b.device_ptr()), cfg.stride_bias)?;
+
+ // Create matmul heuristic search preferences
+ let matmul_pref = MatmulPref::new()?;
+
+ // Set workspace size
+ matmul_pref.set_workspace_size(self.workspace().size)?;
+
+ // Get heuristic given Config, bias, act and workspace size
+ let heuristic = result::get_matmul_algo_heuristic(
+ *self.handle(),
+ matmul_desc.handle,
+ a_layout.handle,
+ b_layout.handle,
+ c_layout.handle,
+ c_layout.handle,
+ matmul_pref.handle,
+ )?;
+
+ // Launch matmul kernel
+ result::matmul(
+ *self.handle(),
+ matmul_desc.handle,
+ (&cfg.alpha) as *const _ as *const _,
+ (&cfg.beta) as *const _ as *const _,
+ *a.device_ptr() as *const _,
+ a_layout.handle,
+ *b.device_ptr() as *const _,
+ b_layout.handle,
+ *c.device_ptr_mut() as *const _,
+ c_layout.handle,
+ *c.device_ptr_mut() as *mut _,
+ c_layout.handle,
+ (&heuristic.algo) as *const _,
+ *self.workspace().buffer.device_ptr() as *const CUdeviceptr as *mut _,
+ self.workspace().size,
+ *self.stream() as *mut _,
+ )
+ }
+}
+
+impl MatmulShared for CudaBlasLT {
+ fn handle(&self) -> &sys::cublasLtHandle_t {
+ &self.handle
+ }
+
+ fn workspace(&self) -> &Workspace {
+ &self.workspace
+ }
+
+ fn stream(&self) -> &CUstream {
+ self.device.cu_stream()
+ }
+}
+
+impl Matmul<f32> for CudaBlasLT {
+ fn matrix_type() -> sys::cudaDataType {
+ sys::cudaDataType_t::CUDA_R_32F
+ }
+
+ fn compute_type() -> sys::cublasComputeType_t {
+ sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
+ }
+}
+
+impl Matmul<half::f16> for CudaBlasLT {
+ fn matrix_type() -> sys::cudaDataType {
+ sys::cudaDataType_t::CUDA_R_16F
+ }
+
+ fn compute_type() -> sys::cublasComputeType_t {
+ sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+ }
+}
+
+impl Matmul<bf16> for CudaBlasLT {
+ fn matrix_type() -> sys::cudaDataType {
+ sys::cudaDataType_t::CUDA_R_16BF
+ }
+
+ fn compute_type() -> sys::cublasComputeType_t {
+ sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+ }
+}
diff --git a/mistralrs-core/src/cublaslt/mod.rs b/mistralrs-core/src/cublaslt/mod.rs
index 7657186ca..9d6046b38 100644
--- a/mistralrs-core/src/cublaslt/mod.rs
+++ b/mistralrs-core/src/cublaslt/mod.rs
@@ -9,9 +9,11 @@ use std::sync::{Mutex, Once};
#[cfg(feature = "cuda")]
mod api;
+#[cfg(feature = "cuda")]
+mod matmul;
#[cfg(feature = "cuda")]
-use api::{fused_batch_matmul, fused_matmul, Activation, CublasLt};
+use api::{fused_batch_matmul, fused_matmul, CublasLt};
static INIT: Once = Once::new();
static mut CUBLASLT: Option = None;
@@ -70,8 +72,8 @@ impl CublasLtWrapper {
#[cfg(feature = "cuda")]
{
let inner_act = act.map(|a| match a {
- CandleActivation::Relu => Activation::Relu,
- CandleActivation::Gelu => Activation::Gelu,
+ CandleActivation::Relu => matmul::Activation::Relu,
+ CandleActivation::Gelu => matmul::Activation::Gelu,
_ => unreachable!("Unsupported activation in cublaslt matmul"),
});
let mut result = fused_batch_matmul(
diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 14be94939..4d9afe5e8 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -18,7 +18,7 @@ use candle_nn::{
Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, Linear, Module, VarBuilder,
};
use mistralrs_quant::QuantMethod;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
@@ -51,9 +51,21 @@ impl RmsNorm {
Ok(Self { eps, weight: w })
}
+ /// Gemma uses weight + 1.0. Undo for UQFF generation.
+    pub fn undo_gemma(&self) -> Result<Self> {
+ Ok(Self {
+ eps: self.eps,
+ weight: (&self.weight - 1.0)?,
+ })
+ }
+
pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
Ok(Self { eps, weight: w })
}
+
+ pub fn weight(&self) -> &Tensor {
+ &self.weight
+ }
}
impl Module for RmsNorm {
@@ -92,7 +104,8 @@ pub struct PhiRotaryEmbedding {
original_max_position_embeddings: usize,
}
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
#[serde(alias = "su")]
#[serde(alias = "longrope")]
@@ -114,7 +127,7 @@ impl FromStr for ScaledRopeType {
}
}
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
Classic {
@@ -395,7 +408,7 @@ pub enum Llama3RotaryEmbedding {
Default(RotaryEmbedding),
}
-#[derive(Debug, Clone, Deserialize, Default)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
#[serde(rename = "llama3")]
Llama3,
@@ -404,7 +417,7 @@ pub enum Llama3RopeType {
Default,
}
-#[derive(Debug, Clone, Deserialize, Default)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
pub factor: f32,
pub low_freq_factor: f32,
@@ -872,6 +885,51 @@ impl RotaryEmbedding {
}
}
+#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum Activation {
+ #[default]
+ #[serde(alias = "gelu")]
+ Gelu,
+ #[serde(alias = "gelu_new")]
+ NewGelu,
+ Relu,
+ Relu2,
+ Relu6,
+ Silu,
+ Sigmoid,
+ HardSigmoid,
+ Swiglu,
+ Swish,
+ HardSwish,
+ Elu(f64),
+ LeakyRelu(f64),
+ #[serde(alias = "gelu_pytorch_tanh")]
+ GeluPytorchTanh,
+}
+
+impl Module for Activation {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::Gelu => xs.gelu_erf(),
+ // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
+ Self::NewGelu => xs.gelu(),
+ Self::Relu => xs.relu(),
+ Self::Relu2 => xs.relu()?.sqr(),
+ Self::Relu6 => xs.clamp(0f32, 6f32),
+ Self::Silu => xs.silu(),
+ Self::Sigmoid => candle_nn::ops::sigmoid(xs),
+ Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
+ Self::Swiglu => candle_nn::ops::swiglu(xs),
+ Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
+ Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
+ &Self::Elu(alpha) => xs.elu(alpha),
+ &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
+ Self::GeluPytorchTanh => xs.gelu(),
+ }
+ }
+}
+
fn lp_norm(xs: &Tensor, p: usize, dim: usize) -> Result<Tensor> {
let l2_norm = xs.powf(p as f64)?.sum_keepdim(dim)?.sqrt()?;
Ok(l2_norm) //xs.broadcast_div(&l2_norm)
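The `Activation` enum added above carries serde aliases so that `hidden_act` strings from Hugging Face `config.json` files deserialize directly; a small standalone sketch of that behavior (it mirrors the attributes above but is not code from this PR):

```rust
use serde::Deserialize;

// Mirrors the serde attributes on the new Activation enum: lowercase names
// plus aliases such as "gelu_new" and "gelu_pytorch_tanh".
#[derive(Debug, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Activation {
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
    Silu,
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
}

fn main() -> serde_json::Result<()> {
    // Typical `hidden_act` values seen in model configs.
    assert_eq!(serde_json::from_str::<Activation>("\"silu\"")?, Activation::Silu);
    assert_eq!(
        serde_json::from_str::<Activation>("\"gelu_pytorch_tanh\"")?,
        Activation::GeluPytorchTanh
    );
    Ok(())
}
```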
diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs
index 25674e5c4..e38ae1632 100644
--- a/mistralrs-core/src/lib.rs
+++ b/mistralrs-core/src/lib.rs
@@ -1,5 +1,6 @@
#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
+use candle_core::Device;
use cublaslt::setup_cublas_lt_wrapper;
use engine::Engine;
pub use engine::{EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP};
@@ -107,6 +108,11 @@ pub use utils::paged_attn_supported;
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
static ENGINE_ID: AtomicUsize = AtomicUsize::new(0);
+pub struct MistralRsConfig {
+ pub kind: ModelKind,
+ pub device: Device,
+}
+
/// The MistralRs struct handles sending requests to the engine.
/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
/// `Sender` and `Receiver` primitives to send and receive requests to the
@@ -121,6 +127,7 @@ pub struct MistralRs {
engine_handler: RwLock>,
engine_id: usize,
category: ModelCategory,
+ config: MistralRsConfig,
}
#[derive(Clone)]
@@ -324,6 +331,10 @@ impl MistralRs {
let sender = RwLock::new(tx);
let id = pipeline.try_lock().unwrap().name();
+ let kind = pipeline.try_lock().unwrap().get_metadata().kind.clone();
+ let device = pipeline.try_lock().unwrap().device();
+ let config = MistralRsConfig { kind, device };
+
let engine_handler = thread::spawn(move || {
let rt = Runtime::new().unwrap();
rt.block_on(async move {
@@ -357,6 +368,7 @@ impl MistralRs {
reboot_state,
engine_handler: RwLock::new(engine_handler),
category,
+ config,
})
}
@@ -483,4 +495,8 @@ impl MistralRs {
.expect("Unable to write data");
}
}
+
+ pub fn config(&self) -> &MistralRsConfig {
+ &self.config
+ }
}
diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs
index 16d17d4cd..8a1b4ef1f 100644
--- a/mistralrs-core/src/models/gemma.rs
+++ b/mistralrs-core/src/models/gemma.rs
@@ -3,7 +3,7 @@
use std::{collections::HashMap, sync::Arc};
use candle_core::{DType, Device, Module, Result, Tensor};
-use candle_nn::{Activation, Linear, RotaryEmbedding, VarBuilder};
+use candle_nn::{Linear, RotaryEmbedding, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear};
use crate::{
@@ -14,7 +14,7 @@ use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
get_delta_from_lora_ab,
- layers::{CausalMasker, MatMul, RmsNorm, Sdpa},
+ layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
@@ -23,7 +23,7 @@ use crate::{
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
fn default_max_position_embeddings() -> usize {
@@ -32,7 +32,7 @@ fn default_max_position_embeddings() -> usize {
serde_default_fn!(bool, word_emb_default, false);
-#[derive(serde::Deserialize, Debug, Clone, Default)]
+#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)]
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
@@ -75,7 +75,7 @@ struct MLP {
gate_proj: Arc<dyn QuantMethod>,
up_proj: Arc<dyn QuantMethod>,
down_proj: Arc<dyn QuantMethod>,
- act_fn: candle_nn::Activation,
+ act_fn: Activation,
params: Vec<usize>,
}
@@ -538,7 +538,7 @@ impl Model {
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: None,
- head_dim: None,
+ head_dim: Some(cfg.head_dim),
},
})
}
@@ -615,6 +615,26 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l
+ .pp("input_layernorm")
+ .add(&layer.input_layernorm.undo_gemma().unwrap());
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs
index 3edba1491..6448e6407 100644
--- a/mistralrs-core/src/models/gemma2.rs
+++ b/mistralrs-core/src/models/gemma2.rs
@@ -3,7 +3,7 @@
use std::{collections::HashMap, sync::Arc};
use candle_core::{DType, Device, Module, Result, Tensor};
-use candle_nn::{Activation, Linear, RotaryEmbedding, VarBuilder};
+use candle_nn::{Linear, RotaryEmbedding, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear};
use crate::{
@@ -14,17 +14,17 @@ use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
get_delta_from_lora_ab,
- layers::{CausalMasker, MatMul, RmsNorm, Sdpa},
+ layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
@@ -69,7 +69,7 @@ struct MLP {
gate_proj: Arc<dyn QuantMethod>,
up_proj: Arc<dyn QuantMethod>,
down_proj: Arc<dyn QuantMethod>,
- act_fn: candle_nn::Activation,
+ act_fn: Activation,
params: Vec<usize>,
}
@@ -687,6 +687,32 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l
+ .pp("input_layernorm")
+ .add(&layer.input_layernorm.undo_gemma().unwrap());
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
+ uvb_l
+ .pp("pre_feedforward_layernorm")
+ .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap());
+ uvb_l
+ .pp("post_feedforward_layernorm")
+ .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap());
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs
index 7969b637d..0f10d03e9 100644
--- a/mistralrs-core/src/models/llama.rs
+++ b/mistralrs-core/src/models/llama.rs
@@ -3,7 +3,7 @@
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
use crate::{
@@ -23,12 +23,12 @@ use crate::{
IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
serde_default_fn!(bool, word_emb_default, false);
-#[derive(Debug, Clone, Deserialize, Default)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
@@ -585,6 +585,22 @@ impl IsqModel for Llama {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.wte);
+ uvb_m.pp("norm").add(&self.ln_f);
+
+ for (layer_idx, layer) in self.blocks.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.rms_1);
+ uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Llama {
diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs
index 9aad4cc25..af86aa02d 100644
--- a/mistralrs-core/src/models/mistral.rs
+++ b/mistralrs-core/src/models/mistral.rs
@@ -2,8 +2,9 @@
/// Mistral LLM, https://github.com/mistralai/mistral-src
use candle_core::{DType, Device, Module, Result, Tensor};
-use candle_nn::{Activation, VarBuilder};
+use candle_nn::VarBuilder;
use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear};
+use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
use crate::{
@@ -14,7 +15,7 @@ use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
get_delta_from_lora_ab,
- layers::{CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
+ layers::{Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
@@ -22,10 +23,10 @@ use crate::{
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone, Default, Serialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
@@ -658,6 +659,24 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm);
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs
index ba37dab52..147204187 100644
--- a/mistralrs-core/src/models/mixtral.rs
+++ b/mistralrs-core/src/models/mixtral.rs
@@ -4,16 +4,16 @@
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle_core::{DType, Device, Module, Result, Tensor};
-use candle_nn::{Activation, RotaryEmbedding, VarBuilder};
+use candle_nn::{RotaryEmbedding, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
use crate::{
amoe::AnyMoeBaseModelMixin,
attention::SdpaParams,
device_map::DeviceMapper,
- layers::{CausalMasker, MatMul, RmsNorm, Sdpa},
+ layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
@@ -22,13 +22,13 @@ use crate::{
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
serde_default_fn!(bool, word_emb_default, false);
/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
@@ -663,6 +663,24 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm);
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs
index 913275e2e..9f07075cc 100644
--- a/mistralrs-core/src/models/phi2.rs
+++ b/mistralrs-core/src/models/phi2.rs
@@ -7,11 +7,9 @@ use std::{collections::HashMap, sync::Arc};
/// This corresponds to the model update made with the following commit:
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
use candle_core::{DType, Device, Result, Tensor};
-use candle_nn::{
- embedding, layer_norm, Activation, Embedding, LayerNorm, RotaryEmbedding, VarBuilder,
-};
+use candle_nn::{embedding, layer_norm, Embedding, LayerNorm, RotaryEmbedding, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantizedConfig};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
use crate::{
amoe::{
@@ -21,7 +19,7 @@ use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
get_delta_from_lora_ab,
- layers::{CausalMasker, MatMul, Sdpa},
+ layers::{Activation, CausalMasker, MatMul, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
@@ -30,13 +28,13 @@ use crate::{
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
serde_default_fn!(bool, word_emb_default, false);
// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
-#[derive(Debug, Clone, Deserialize, Default)]
+#[derive(Debug, Clone, Deserialize, Default, Serialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
@@ -599,6 +597,21 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.final_layernorm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs
index b0d9ddce2..c371ab0f9 100644
--- a/mistralrs-core/src/models/phi3.rs
+++ b/mistralrs-core/src/models/phi3.rs
@@ -16,8 +16,8 @@ use crate::{
device_map::DeviceMapper,
get_delta_from_lora_ab,
layers::{
- CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding, RmsNorm,
- Sdpa,
+ Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding,
+ RmsNorm, Sdpa,
},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
@@ -27,16 +27,16 @@ use crate::{
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
serde_default_fn!(bool, word_emb_default, false);
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
-#[derive(Debug, Clone, serde::Deserialize, Default)]
+#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
pub vocab_size: usize,
- pub hidden_act: candle_nn::Activation,
+ pub hidden_act: Activation,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
@@ -239,7 +239,7 @@ impl Attention {
struct Mlp {
gate_up_proj: Arc<dyn QuantMethod>,
down_proj: Arc<dyn QuantMethod>,
- act_fn: candle_nn::Activation,
+ act_fn: Activation,
i_size: usize,
params: Vec<usize>,
}
@@ -597,6 +597,24 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm);
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/phi3_5_moe.rs b/mistralrs-core/src/models/phi3_5_moe.rs
index c1f032692..0f02c1c68 100644
--- a/mistralrs-core/src/models/phi3_5_moe.rs
+++ b/mistralrs-core/src/models/phi3_5_moe.rs
@@ -11,7 +11,10 @@ use crate::{
amoe::AnyMoeBaseModelMixin,
attention::SdpaParams,
device_map::DeviceMapper,
- layers::{CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding, Sdpa},
+ layers::{
+ Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding,
+ Sdpa,
+ },
layers_masker::{masked_fill, PastKvLenCache},
ops::NonZeroOp,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
@@ -21,16 +24,16 @@ use crate::{
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
serde_default_fn!(bool, word_emb_default, false);
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
-#[derive(Debug, Clone, serde::Deserialize, Default)]
+#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
pub(crate) vocab_size: usize,
- pub(crate) hidden_act: candle_nn::Activation,
+ pub(crate) hidden_act: Activation,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_hidden_layers: usize,
@@ -250,7 +253,7 @@ struct Mlp {
w1: Arc<dyn QuantMethod>,
w2: Arc<dyn QuantMethod>,
w3: Arc<dyn QuantMethod>,
- act_fn: candle_nn::Activation,
+ act_fn: Activation,
}
impl Mlp {
@@ -732,6 +735,48 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm);
+ }
+
+ uvb.to_safetensors()
+ }
+
+    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm);
+
+ let uvb_attn = uvb_l.pp("self_attn");
+ uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
+ uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
+ uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
+ uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
+ }
+
+ Some(uvb.to_safetensors())
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs
index ee4a64f88..7d3cf7b03 100644
--- a/mistralrs-core/src/models/qwen2.rs
+++ b/mistralrs-core/src/models/qwen2.rs
@@ -1,7 +1,7 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::{DType, Device, Module, Result, Tensor};
-use candle_nn::{Activation, RotaryEmbedding, VarBuilder};
+use candle_nn::{RotaryEmbedding, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear};
use std::{collections::HashMap, sync::Arc};
@@ -13,7 +13,7 @@ use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
get_delta_from_lora_ab,
- layers::{CausalMasker, MatMul, RmsNorm, Sdpa},
+ layers::{Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
@@ -22,12 +22,12 @@ use crate::{
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
serde_default_fn!(bool, word_emb_default, false);
-#[derive(Debug, Clone, serde::Deserialize, Default)]
+#[derive(Debug, Clone, serde::Deserialize, Default, serde::Serialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
@@ -599,6 +599,24 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm);
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/models/starcoder2.rs b/mistralrs-core/src/models/starcoder2.rs
index 3c3e8d849..458c48cea 100644
--- a/mistralrs-core/src/models/starcoder2.rs
+++ b/mistralrs-core/src/models/starcoder2.rs
@@ -10,7 +10,7 @@ use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
get_delta_from_lora_ab,
- layers::{CausalMasker, MatMul, RotaryEmbedding, Sdpa},
+ layers::{Activation, CausalMasker, MatMul, RotaryEmbedding, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
@@ -19,13 +19,13 @@ use crate::{
Cache, IsqModel, NormalLoadingMetadata, NormalModel,
},
serde_default_fn,
- utils::progress::NiceProgressBar,
+ utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
AnyMoeConfig, AnyMoeExpertType,
};
serde_default_fn!(bool, word_emb_default, false);
-#[derive(Debug, Clone, serde::Deserialize, Default)]
+#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
@@ -33,7 +33,7 @@ pub struct Config {
pub(crate) num_hidden_layers: usize,
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: usize,
- pub(crate) hidden_act: candle_nn::Activation,
+ pub(crate) hidden_act: Activation,
pub(crate) max_position_embeddings: usize,
pub(crate) norm_epsilon: f64,
pub(crate) rope_theta: f64,
@@ -51,7 +51,7 @@ pub struct Config {
struct MLP {
c_fc: Arc<dyn QuantMethod>,
c_proj: Arc<dyn QuantMethod>,
- act: candle_nn::Activation,
+ act: Activation,
params: Vec<usize>,
}
@@ -587,6 +587,24 @@ impl IsqModel for Model {
}
(tensors, &*self.mapper)
}
+
+ fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ let uvb_m = uvb.pp("model");
+ uvb_m.pp("embed_tokens").add(&self.embed_tokens);
+ uvb_m.pp("norm").add(&self.norm);
+
+ for (layer_idx, layer) in self.layers.iter().enumerate() {
+ let uvb_l = uvb_m.pp("layers").pp(layer_idx);
+ uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
+ uvb_l
+ .pp("post_attention_layernorm")
+ .add(&layer.post_attention_layernorm);
+ }
+
+ uvb.to_safetensors()
+ }
}
impl NormalModel for Model {
diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs
index 0d7b5321d..020bfe04f 100644
--- a/mistralrs-core/src/ops.rs
+++ b/mistralrs-core/src/ops.rs
@@ -123,6 +123,7 @@ impl CustomOp2 for BitWise {
CpuStorage::F16(_) => Err(Error::UnsupportedDTypeForOp(DType::F16, "bitwise")),
CpuStorage::F32(_) => Err(Error::UnsupportedDTypeForOp(DType::F32, "bitwise")),
CpuStorage::F64(_) => Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise")),
+ CpuStorage::F8E4M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise")),
}
}
#[cfg(feature = "cuda")]
@@ -191,6 +192,9 @@ impl CustomOp2 for BitWise {
DType::F64 => {
return Err(Error::UnsupportedDTypeForOp(DType::F64, "bitwise"));
}
+ DType::F8E4M3 => {
+ return Err(Error::UnsupportedDTypeForOp(DType::F8E4M3, "bitwise"));
+ }
};
let dst = match s1.dtype() {
DType::U8 => {
@@ -397,6 +401,7 @@ fn count_nonzero_cuda(dtype: candle_core::DType, d_in: *const c_void, n: u32) ->
candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n),
candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n),
candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n),
+ candle_core::DType::F8E4M3 => todo!(),
}
}
}
@@ -438,6 +443,7 @@ fn nonzero_cuda(
candle_core::DType::F64 => {
ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out)
}
+ candle_core::DType::F8E4M3 => todo!(),
}
}
}
@@ -461,6 +467,7 @@ impl CustomOp1 for NonZero {
candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout),
candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout),
candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout),
+ candle_core::CpuStorage::F8E4M3(_vs) => todo!(),
};
let index_len = layout.dims().len();
let result_len = result.len() / index_len;
@@ -488,6 +495,7 @@ impl CustomOp1 for NonZero {
candle_core::DType::F16 => *storage.as_cuda_slice::<half::f16>()?.device_ptr(),
candle_core::DType::F32 => *storage.as_cuda_slice::<f32>()?.device_ptr(),
candle_core::DType::F64 => *storage.as_cuda_slice::<f64>()?.device_ptr(),
+ candle_core::DType::F8E4M3 => todo!(),
} as *const c_void;
let n = layout.shape().elem_count();
let num_nonzero = count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?);
diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs
index 0b64fc206..352db6a5b 100644
--- a/mistralrs-core/src/pipeline/ggml.rs
+++ b/mistralrs-core/src/pipeline/ggml.rs
@@ -109,7 +109,7 @@ impl GGMLLoaderBuilder {
quantized_model_id: String,
quantized_filename: String,
) -> Self {
- let kind = ModelKind::Quantized {
+ let kind = ModelKind::GgufQuantized {
quant: QuantizationKind::Ggml,
};
@@ -339,8 +339,8 @@ impl Loader for GGMLLoader {
// Config into model:
// NOTE: No architecture to infer like GGUF, Llama model is implicitly matched
let model = match self.kind {
- ModelKind::Quantized { .. } => Model::Llama(QLlama::try_from(model_config)?),
- ModelKind::AdapterQuantized { .. } => {
+ ModelKind::GgufQuantized { .. } => Model::Llama(QLlama::try_from(model_config)?),
+ ModelKind::GgufAdapter { .. } => {
Model::XLoraLlama(XLoraQLlama::try_from(model_config)?)
}
_ => unreachable!(),
@@ -410,7 +410,8 @@ impl Loader for GGMLLoader {
self,
self.quantized_model_id,
Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
- silent
+ silent,
+ false // Never loading UQFF
);
self.load_model_from_path(
&paths?,
diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs
index febd76e99..432d1e050 100644
--- a/mistralrs-core/src/pipeline/gguf.rs
+++ b/mistralrs-core/src/pipeline/gguf.rs
@@ -124,7 +124,7 @@ impl GGUFLoaderBuilder {
quantized_filenames: Vec<String>,
config: GGUFSpecificConfig,
) -> Self {
- let kind = ModelKind::Quantized {
+ let kind = ModelKind::GgufQuantized {
quant: QuantizationKind::Gguf,
};
@@ -394,7 +394,7 @@ impl Loader for GGUFLoader {
let has_adapter = self.kind.is_adapted();
let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
- let paged_attn_config = if matches!(self.kind, ModelKind::AdapterQuantized { .. }) {
+ let paged_attn_config = if matches!(self.kind, ModelKind::GgufAdapter { .. }) {
warn!("Adapter models do not currently support PagedAttention, running without");
None
} else {
@@ -431,7 +431,7 @@ impl Loader for GGUFLoader {
// Config into model:
let model = match self.kind {
- ModelKind::Quantized { .. } => match arch {
+ ModelKind::GgufQuantized { .. } => match arch {
GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?),
GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?),
GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?),
@@ -440,7 +440,7 @@ impl Loader for GGUFLoader {
}
a => bail!("Unsupported architecture `{a:?}` for GGUF"),
},
- ModelKind::AdapterQuantized { adapter, .. } => match arch {
+ ModelKind::GgufAdapter { adapter, .. } => match arch {
GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?),
GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?),
a => bail!(
diff --git a/mistralrs-core/src/pipeline/isq.rs b/mistralrs-core/src/pipeline/isq.rs
index d0b187040..5bcdc5b9f 100644
--- a/mistralrs-core/src/pipeline/isq.rs
+++ b/mistralrs-core/src/pipeline/isq.rs
@@ -1,6 +1,7 @@
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
+ fs::File,
path::PathBuf,
str::FromStr,
sync::{atomic::AtomicUsize, Arc},
@@ -8,17 +9,21 @@ use std::{
};
use anyhow::Result;
-use candle_core::{Device, Tensor};
+use candle_core::{Context, Device, Tensor};
use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle};
use mistralrs_quant::{
- GgufMatMul, HqqLayer, IsqType, QuantMethod, QuantizedSerde, QuantizedSerdeType, UnquantLinear,
+ FP8Linear, GgufMatMul, HqqLayer, IsqType, QuantMethod, QuantizedSerde, QuantizedSerdeType,
+ UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
use serde::Deserialize;
+use tokenizers::Tokenizer;
use tracing::info;
-use crate::{device_map::DeviceMapper, serde_default_fn, topology::LayerTopology, Topology};
+use crate::{device_map::DeviceMapper, topology::LayerTopology, Topology};
+
+pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
/// Parse ISQ value: one of
/// - `Q4_0`
@@ -54,10 +59,11 @@ pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
"q8k" => IsqType::Q8K,
"hqq8" => IsqType::HQQ8,
"hqq4" => IsqType::HQQ4,
+ "fp8" => IsqType::F8E4M3,
// "hqq3" => IsqType::HQQ3,
// "hqq2" => IsqType::HQQ2,
// "hqq1" => IsqType::HQQ1,
- _ => return Err(format!("ISQ type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`.")),
+ _ => return Err(format!("ISQ type {s} unknown, choose one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`.")),
};
#[cfg(feature = "cuda")]
{
@@ -74,11 +80,12 @@ pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
| IsqType::Q5K
| IsqType::Q6K
| IsqType::HQQ8
- | IsqType::HQQ4 // | IsqType::HQQ3
- // | IsqType::HQQ2
- // | IsqType::HQQ1
+ | IsqType::HQQ4
+ | IsqType::F8E4M3 // | IsqType::HQQ3
+ // | IsqType::HQQ2
+ // | IsqType::HQQ1
) {
- return Err("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`".to_string());
+ return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
}
}
Ok(tp)
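// --- Illustrative usage (not part of the patch) -------------------------------------
// With the arm added above, the spelling "fp8" now parses to `IsqType::F8E4M3`; the
// existing GGUF/HQQ spellings are unchanged. Assumes `parse_isq_value` remains publicly
// re-exported from `mistralrs_core` as before.
fn choose_isq() -> Result<(), String> {
    let _fp8 = mistralrs_core::parse_isq_value("fp8")?;
    let _q4k = mistralrs_core::parse_isq_value("q4k")?;
    Ok(())
}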
@@ -108,6 +115,15 @@ impl FromStr for IsqOrganization {
}
}
+pub struct UqffFullSer<'a> {
+ pub tokenizer: &'a Tokenizer,
+    pub template_filename: &'a Option<PathBuf>,
+ pub generation_config: Option<&'a PathBuf>,
+ pub config: String,
+    pub processor_filename: &'a Option<PathBuf>,
+    pub preprocessor_filename: &'a Option<PathBuf>,
+}
+
pub trait IsqModel {
/// Corresponds to `IsqOrganization::Default`
#[allow(clippy::type_complexity)]
@@ -130,7 +146,19 @@ pub trait IsqModel {
self.get_layers()
}
+ /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers`].
+ fn residual_tensors(&self) -> Vec<(String, Tensor)>;
+
+ /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers_moe_experts_only`].
+    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
+ None
+ }
+
/// Quantize the model in-situ.
+ ///
+    /// This function will also create a UQFF file or, if the model supports it (i.e., residual tensors are
+    /// returned), a full serialization.
+ #[allow(clippy::too_many_arguments)]
fn quantize(
&mut self,
dtype: Option<IsqType>,
@@ -139,6 +167,7 @@ pub trait IsqModel {
silent: bool,
organization: IsqOrganization,
write_artifacts: Option<&PathBuf>,
+ full_ser: UqffFullSer<'_>,
) -> candle_core::Result<()> {
{
let (mut tensors, mapper) = match organization {
@@ -275,10 +304,7 @@ pub trait IsqModel {
);
if !serialized.extension().is_some_and(|ext| ext == "uqff") {
- candle_core::bail!(
- "UQFF output path extension must be {:?}",
- serialized.extension().as_ref().unwrap()
- );
+ candle_core::bail!("UQFF output path extension must be `.uqff`",);
}
let bar = ProgressBar::new(total_tensors as u64);
@@ -331,7 +357,99 @@ pub trait IsqModel {
}
});
+ let parent = serialized
+ .parent()
+ .context("Target UQFF path must have a filename!")?;
+
+ std::fs::create_dir_all(parent)?;
+
safetensors::serialize_to_file(quantized_values?, &None, serialized)?;
+
+ let residual = match organization {
+ IsqOrganization::Default => self.residual_tensors(),
+ IsqOrganization::MoeExpertsOnly => self
+ .residual_tensors_moe_experts_only()
+ .unwrap_or(self.residual_tensors()),
+ };
+
+ let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
+ let config_out = parent.join("config.json");
+ let tokenizer_out = parent.join("tokenizer.json");
+ let tokenizer_cfg_out = parent.join("tokenizer_config.json");
+ let gen_cfg_out = parent.join("generation_config.json");
+ let processor_out = parent.join("processor_config.json");
+ let preprocessor_out = parent.join("preprocessor_config.json");
+
+ info!(
+ "Serializing {} residual tensors to `{}`.",
+ residual.len(),
+ residual_out.display()
+ );
+
+ safetensors::serialize_to_file(residual, &None, &residual_out)?;
+
+ let UqffFullSer {
+ tokenizer,
+ template_filename,
+ generation_config,
+ config,
+ processor_filename,
+ preprocessor_filename,
+ } = full_ser;
+
+ info!("Serializing configuration to `{}`.", config_out.display());
+
+ std::fs::write(config_out, config)?;
+
+ info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
+
+ serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
+ .map_err(candle_core::Error::msg)?;
+
+ if let Some(template_filename) = template_filename {
+ info!(
+ "Serializing tokenizer config to `{}`.",
+ tokenizer_cfg_out.display()
+ );
+
+ let template =
+ std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
+ std::fs::write(&tokenizer_cfg_out, template)
+ .map_err(candle_core::Error::msg)?;
+ }
+
+ if let Some(generation_config) = generation_config {
+ info!(
+ "Serializing generation config to `{}`.",
+ gen_cfg_out.display()
+ );
+
+ let cfg =
+ std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
+ std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
+ }
+
+ if let Some(processor_config) = processor_filename {
+ info!(
+ "Serializing processor config to `{}`.",
+ processor_out.display()
+ );
+
+ let cfg =
+ std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
+ std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
+ }
+
+ if let Some(preprocessor_config) = preprocessor_filename {
+ info!(
+ "Serializing preprocessor config to `{}`.",
+ preprocessor_out.display()
+ );
+
+ let cfg =
+ std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
+ std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
+ }
}
}
@@ -412,7 +530,97 @@ pub trait IsqModel {
.collect::<candle_core::Result<Vec<_>>>()
};
- safetensors::serialize_to_file(quantized_values?, &None, serialized)?;
+ let parent = serialized
+ .parent()
+ .context("Target UQFF path must have a filename!")?;
+
+ std::fs::create_dir_all(parent)?;
+
+ let residual = match organization {
+ IsqOrganization::Default => self.residual_tensors(),
+ IsqOrganization::MoeExpertsOnly => self
+ .residual_tensors_moe_experts_only()
+ .unwrap_or(self.residual_tensors()),
+ };
+
+ let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
+ let config_out = parent.join("config.json");
+ let tokenizer_out = parent.join("tokenizer.json");
+ let tokenizer_cfg_out = parent.join("tokenizer_config.json");
+ let gen_cfg_out = parent.join("generation_config.json");
+ let processor_out = parent.join("processor_config.json");
+ let preprocessor_out = parent.join("preprocessor_config.json");
+
+ info!(
+ "Serializing {} residual tensors to `{}`.",
+ residual.len(),
+ residual_out.display()
+ );
+
+ safetensors::serialize_to_file(residual, &None, &residual_out)?;
+
+ let UqffFullSer {
+ tokenizer,
+ template_filename,
+ generation_config,
+ config,
+ processor_filename,
+ preprocessor_filename,
+ } = full_ser;
+
+ info!("Serializing configuration to `{}`.", config_out.display());
+
+ std::fs::write(config_out, config)?;
+
+ info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
+
+ serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
+ .map_err(candle_core::Error::msg)?;
+
+ if let Some(template_filename) = template_filename {
+ info!(
+ "Serializing tokenizer config to `{}`.",
+ tokenizer_cfg_out.display()
+ );
+
+ let template =
+ std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
+ std::fs::write(&tokenizer_cfg_out, template)
+ .map_err(candle_core::Error::msg)?;
+ }
+
+ if let Some(generation_config) = generation_config {
+ info!(
+ "Serializing generation config to `{}`.",
+ gen_cfg_out.display()
+ );
+
+ let cfg =
+ std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
+ std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
+ }
+
+ if let Some(processor_config) = processor_filename {
+ info!(
+ "Serializing processor config to `{}`.",
+ processor_out.display()
+ );
+
+ let cfg =
+ std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
+ std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
+ }
+
+ if let Some(preprocessor_config) = preprocessor_filename {
+ info!(
+ "Serializing preprocessor config to `{}`.",
+ preprocessor_out.display()
+ );
+
+ let cfg =
+ std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
+ std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
+ }
}
}
let delta = Instant::now().duration_since(t_start).as_secs_f32();
@@ -511,6 +719,9 @@ pub trait IsqModel {
QuantizedSerdeType::Hqq => {
HqqLayer::deserialize(Cow::from(artifact), &devices[i])?
}
+ QuantizedSerdeType::Fp8 => {
+ FP8Linear::deserialize(Cow::from(artifact), &devices[i])?
+ }
};
*tensor = deserialized;
}
@@ -537,6 +748,9 @@ pub trait IsqModel {
QuantizedSerdeType::Hqq => {
HqqLayer::deserialize(Cow::from(artifact), &devices[i])?
}
+ QuantizedSerdeType::Fp8 => {
+ FP8Linear::deserialize(Cow::from(artifact), &devices[i])?
+ }
};
*tensor = deserialized;
}
@@ -568,11 +782,3 @@ pub(crate) trait IsqModelLoader {
self.isq_layer_regexes(config)
}
}
-
-serde_default_fn!(bool, word_emb_default, false);
-
-#[derive(Deserialize)]
-pub(crate) struct WordEmbeddingsShim {
- #[serde(default = "word_emb_default")]
- pub(crate) tie_word_embeddings: bool,
-}
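// --- Reference sketch (not part of the patch) ---------------------------------------
// With the serialization code above, writing `out/model.uqff` also writes the residual
// (non-ISQ) tensors and the model's support files next to it. The names below mirror
// the paths constructed in `quantize`; the optional files are only written when present.
use std::path::{Path, PathBuf};

fn expected_uqff_siblings(uqff_path: &Path) -> Vec<PathBuf> {
    let parent = uqff_path
        .parent()
        .expect("the UQFF path must have a parent directory");
    [
        "residual.safetensors",     // UQFF_RESIDUAL_SAFETENSORS
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",    // only if a chat template file was provided
        "generation_config.json",   // optional
        "processor_config.json",    // optional (vision pipelines)
        "preprocessor_config.json", // optional (vision pipelines)
    ]
    .iter()
    .map(|name| parent.join(name))
    .collect()
}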
diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs
index 72eace924..f6c6fefd2 100644
--- a/mistralrs-core/src/pipeline/loaders/mod.rs
+++ b/mistralrs-core/src/pipeline/loaders/mod.rs
@@ -236,17 +236,17 @@ impl fmt::Display for TokenSource {
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
#[default]
- #[strum(to_string = "normal (no quant, no adapters)")]
+ #[strum(to_string = "normal (no adapters)")]
Normal,
- #[strum(to_string = "quantized from {quant} (no adapters)")]
- Quantized { quant: QuantizationKind },
+ #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
+ GgufQuantized { quant: QuantizationKind },
- #[strum(to_string = "{adapter}, (no quant)")]
+ #[strum(to_string = "{adapter}")]
Adapter { adapter: AdapterKind },
- #[strum(to_string = "{adapter}, quantized from {quant}")]
- AdapterQuantized {
+ #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
+ GgufAdapter {
adapter: AdapterKind,
quant: QuantizationKind,
},
@@ -311,7 +311,7 @@ impl ModelKind {
match self {
Normal | Adapter { .. } => vec![None],
- Quantized { quant } | AdapterQuantized { quant, .. } => vec![Some(*quant)],
+ GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
Speculative { target, draft } => {
let t = *target.clone();
let d = *draft.clone();
@@ -335,8 +335,8 @@ impl ModelKind {
use ModelKind::*;
match self {
- Normal | Quantized { .. } => vec![None],
- Adapter { adapter } | AdapterQuantized { adapter, .. } => vec![Some(*adapter)],
+ Normal | GgufQuantized { .. } => vec![None],
+ Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
Speculative { target, draft } => {
let t = *target.clone();
let d = *draft.clone();
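// --- Migration sketch (not part of the patch) ---------------------------------------
// Downstream matches adapt to the renames above: `Quantized` becomes `GgufQuantized`
// and `AdapterQuantized` becomes `GgufAdapter`. Assumes `ModelKind` stays re-exported
// at the crate root.
fn uses_gguf(kind: &mistralrs_core::ModelKind) -> bool {
    use mistralrs_core::ModelKind;
    matches!(kind, ModelKind::GgufQuantized { .. } | ModelKind::GgufAdapter { .. })
}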
diff --git a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs
index 21d93ca5e..b5ee6d336 100644
--- a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs
+++ b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs
@@ -7,11 +7,11 @@ use std::{
use crate::{
amoe::AnyMoeBaseModelMixin,
device_map::DeviceMapper,
- layers::{Llama3RopeConfig, PhiRopeScalingConfig},
+ layers::{Activation, Llama3RopeConfig, PhiRopeScalingConfig},
lora::{LoraConfig, Ordering},
paged_attention::{AttentionImplementation, ModelConfigMetadata},
pipeline::{
- isq::{IsqModelLoader, WordEmbeddingsShim},
+ isq::IsqModelLoader,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
Cache, IsqModel,
},
@@ -21,7 +21,7 @@ use crate::{
};
use anyhow::Result;
use candle_core::{Device, Tensor};
-use candle_nn::{Activation, VarBuilder};
+use candle_nn::VarBuilder;
use mistralrs_quant::QuantizedConfig;
#[cfg(feature = "pyo3_macros")]
@@ -391,35 +391,19 @@ impl NormalModelLoader for MistralLoader {
}
impl IsqModelLoader for MistralLoader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -535,35 +519,19 @@ impl NormalModelLoader for GemmaLoader {
}
impl IsqModelLoader for GemmaLoader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -673,35 +641,19 @@ impl NormalModelLoader for LlamaLoader {
}
impl IsqModelLoader for LlamaLoader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -807,40 +759,20 @@ impl NormalModelLoader for MixtralLoader {
}
impl IsqModelLoader for MixtralLoader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
- )?);
- // Experts
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$",
- )?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // Experts
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
+ ])
}
}
@@ -947,30 +879,18 @@ impl NormalModelLoader for Phi2Loader {
}
impl IsqModelLoader for Phi2Loader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
+ ])
}
}
@@ -979,7 +899,7 @@ impl IsqModelLoader for Phi2Loader {
#[derive(Deserialize)]
struct Phi3BasicConfig {
vocab_size: usize,
- hidden_act: candle_nn::Activation,
+ hidden_act: Activation,
hidden_size: usize,
intermediate_size: usize,
num_hidden_layers: usize,
@@ -1083,28 +1003,17 @@ impl NormalModelLoader for Phi3Loader {
}
impl IsqModelLoader for Phi3Loader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // MLP
+            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -1199,35 +1108,19 @@ impl NormalModelLoader for Qwen2Loader {
}
impl IsqModelLoader for Qwen2Loader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -1348,35 +1241,19 @@ impl NormalModelLoader for Gemma2Loader {
}
impl IsqModelLoader for Gemma2Loader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -1390,7 +1267,7 @@ struct Starcoder2BasicConfig {
num_hidden_layers: usize,
num_attention_heads: usize,
num_key_value_heads: usize,
- hidden_act: candle_nn::Activation,
+ hidden_act: Activation,
max_position_embeddings: usize,
norm_epsilon: f64,
rope_theta: f64,
@@ -1482,30 +1359,18 @@ impl NormalModelLoader for Starcoder2Loader {
}
impl IsqModelLoader for Starcoder2Loader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.c_fc\.(weight|bias)$")?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -1514,7 +1379,7 @@ impl IsqModelLoader for Starcoder2Loader {
#[derive(Deserialize)]
struct Phi3_5MoEBasicConfig {
vocab_size: usize,
- hidden_act: candle_nn::Activation,
+ hidden_act: Activation,
hidden_size: usize,
intermediate_size: usize,
num_hidden_layers: usize,
@@ -1622,56 +1487,28 @@ impl NormalModelLoader for Phi3_5MoELoader {
}
impl IsqModelLoader for Phi3_5MoELoader {
-    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$",
- )?);
- Ok(regexes)
- }
-
-    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
-        let mut regexes = Vec::new();
-        if serde_json::from_str::<WordEmbeddingsShim>(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$",
- )?);
- Ok(regexes)
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
+ ])
+ }
+
+    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
+ ])
}
}
diff --git a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs
index efe84b2c3..a41ca568b 100644
--- a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs
+++ b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs
@@ -15,7 +15,7 @@ use serde::Deserialize;
use super::NormalLoadingMetadata;
use crate::amoe::AnyMoeBaseModelMixin;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
-use crate::pipeline::isq::{IsqModelLoader, WordEmbeddingsShim};
+use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{Cache, IsqModel, Processor, ProcessorCreator};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
@@ -157,28 +157,16 @@ impl VisionModelLoader for Phi3VLoader {
}
impl IsqModelLoader for Phi3VLoader {
- fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
- let mut regexes = Vec::new();
- if serde_json::from_str::(config)?.tie_word_embeddings {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- Ok(regexes)
+ fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.gate__up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
@@ -240,7 +228,6 @@ impl VisionModelLoader for Idefics2Loader {
impl IsqModelLoader for Idefics2Loader {
fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
Ok(vec![
- // Tie weights is unsupported for this model
Regex::new(r"lm_head\.(weight|bias)$")?,
// Attention
Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
@@ -310,7 +297,6 @@ impl VisionModelLoader for LLaVANextLoader {
impl IsqModelLoader for LLaVANextLoader {
fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
Ok(vec![
- // Tie weights is unsupported for this model
Regex::new(r"lm_head\.(weight|bias)$")?,
// Attention
Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
@@ -380,7 +366,6 @@ impl VisionModelLoader for LLaVALoader {
impl IsqModelLoader for LLaVALoader {
fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
Ok(vec![
- // Tie weights is unsupported for this model
Regex::new(r"lm_head\.(weight|bias)$")?,
// Attention
Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
@@ -448,37 +433,18 @@ impl VisionModelLoader for VLlamaLoader {
}
impl IsqModelLoader for VLlamaLoader {
- fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
- let mut regexes = Vec::new();
- if serde_json::from_str::(config)?
- .text_config
- .tie_word_embeddings
- {
- regexes.push(Regex::new(r"(embed_tokens|lm_head)\.(weight|bias)$")?);
- } else {
- regexes.push(Regex::new(r"lm_head\.(weight|bias)$")?);
- }
- // Attention
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
- )?);
- // MLP
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$",
- )?);
- regexes.push(Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?);
- regexes.push(Regex::new(
- r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$",
- )?);
- Ok(regexes)
+ fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+ Ok(vec![
+ Regex::new(r"lm_head\.(weight|bias)$")?,
+ // Attention
+ Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+ // MLP
+ Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
+ Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+ ])
}
}
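
The same shape of implementation repeats across the vision loaders, differing only in which projections exist (the fused qkv_proj/gate_up_proj in Phi3V versus the split q_proj/k_proj/v_proj and gate_proj/up_proj/down_proj in VLlama). A hypothetical helper, not part of this change, that derives such per-layer regexes from a projection list might look like:

    use regex::Regex;

    fn layer_regexes(projections: &[&str]) -> Result<Vec<Regex>, regex::Error> {
        projections
            .iter()
            .map(|proj| {
                // Escape the projection path so the '.' separators match literally.
                Regex::new(&format!(
                    r"layers\.(\d+)\.{}\.(weight|bias)$",
                    regex::escape(proj)
                ))
            })
            .collect()
    }

    // e.g. layer_regexes(&["self_attn.qkv_proj", "self_attn.o_proj", "mlp.gate_up_proj", "mlp.down_proj"])
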
diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs
index a74af8efc..7f3b8420b 100644
--- a/mistralrs-core/src/pipeline/macros.rs
+++ b/mistralrs-core/src/pipeline/macros.rs
@@ -56,7 +56,16 @@ macro_rules! api_get_file {
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths {
- ($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filename:expr, $silent:expr) => {{
+ (
+ $path_name:ident,
+ $token_source:expr,
+ $revision:expr,
+ $this:expr,
+ $quantized_model_id:expr,
+ $quantized_filename:expr,
+ $silent:expr,
+ $loading_uqff:expr
+ ) => {{
let api = ApiBuilder::new()
.with_progress(!$silent)
.with_token(get_token($token_source)?)
@@ -84,6 +93,7 @@ macro_rules! get_paths {
&$quantized_filename,
&api,
&model_id,
+ $loading_uqff,
)?;
let XLoraPaths {
adapter_configs,
@@ -169,54 +179,49 @@ macro_rules! get_paths {
#[doc(hidden)]
#[macro_export]
-macro_rules! get_write_uqff_paths {
+macro_rules! get_uqff_paths {
($from_uqff:expr, $this:expr, $silent:expr) => {{
- if !$from_uqff.exists() {
- // Assume it's a HF model id
- let path = $from_uqff.to_string_lossy().to_string();
- let parts = path.rsplitn(2, '/').collect::<Vec<_>>();
-
- if parts.len() != 2 {
- anyhow::bail!("ISQ artifact load path `{path}` not found locally must have format `<model id>/<file>`");
- }
-
- let file = parts[0];
- let model_id = parts[1];
+ let api = ApiBuilder::new()
+ .with_progress(!$silent)
+ .with_token(get_token(
+ &$this
+ .token_source
+ .read()
+ .expect("Failed to read token source")
+ .clone()
+ .unwrap_or(TokenSource::None),
+ )?)
+ .build()?;
+ let revision = $this
+ .revision
+ .read()
+ .expect("Failed to read revision")
+ .clone()
+ .unwrap_or("main".to_string());
+ let api = api.repo(Repo::with_revision(
+ $this.model_id.to_string(),
+ RepoType::Model,
+ revision.clone(),
+ ));
- let api = ApiBuilder::new()
- .with_progress(!$silent)
- .with_token(get_token(
- &$this
- .token_source
- .read()
- .expect("Failed to read token source")
- .clone()
- .unwrap_or(TokenSource::None),
- )?)
- .build()?;
- let revision = $this
- .revision
- .read()
- .expect("Failed to read revision")
- .clone()
- .unwrap_or("main".to_string());
- let api = api.repo(Repo::with_revision(
- model_id.to_string(),
- RepoType::Model,
- revision.clone(),
- ));
+ let file = $from_uqff.display().to_string();
- api_get_file!(api, file, Path::new(model_id))
- } else {
- $from_uqff
- }
+ api_get_file!(api, &file, Path::new(&$this.model_id))
}};
}
#[doc(hidden)]
#[macro_export]
macro_rules! get_paths_gguf {
- ($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filenames:expr, $silent:expr) => {{
+ (
+ $path_name:ident,
+ $token_source:expr,
+ $revision:expr,
+ $this:expr,
+ $quantized_model_id:expr,
+ $quantized_filenames:expr,
+ $silent:expr
+ ) => {{
let api = ApiBuilder::new()
.with_progress(!$silent)
.with_token(get_token($token_source)?)
@@ -258,6 +263,7 @@ macro_rules! get_paths_gguf {
&Some($quantized_filenames),
&api,
&model_id,
+ false, // Never loading UQFF
)?;
let XLoraPaths {
@@ -345,7 +351,21 @@ macro_rules! get_paths_gguf {
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
- ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $loading_uqff:expr, $real_device:expr, $attention_mechanism:expr, $is_moqe:expr) => {{
+ (
+ $paths:expr,
+ $dtype:expr,
+ $device:expr,
+ $config:expr,
+ $loader:expr,
+ $use_flash_attn:expr,
+ $silent:expr,
+ $mapper:expr,
+ $loading_isq:expr,
+ $loading_uqff:expr,
+ $real_device:expr,
+ $attention_mechanism:expr,
+ $is_moqe:expr
+ ) => {{
let regexes = if $loading_isq && $loading_uqff {
// Dummy weights for the layers which will be overwritten...
Some(std::sync::Arc::new(if $is_moqe {
@@ -384,7 +404,20 @@ macro_rules! normal_model_loader {
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
- ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $loading_uqff:expr, $real_device:expr, $attention_mechanism:expr) => {{
+ (
+ $paths:expr,
+ $dtype:expr,
+ $device:expr,
+ $config:expr,
+ $loader:expr,
+ $use_flash_attn:expr,
+ $silent:expr,
+ $mapper:expr,
+ $loading_isq:expr,
+ $loading_uqff:expr,
+ $real_device:expr,
+ $attention_mechanism:expr
+ ) => {{
let regexes = if $loading_isq && $loading_uqff {
// Dummy weights for the layers which will be overwritten...
Some(std::sync::Arc::new($loader.isq_layer_regexes(&$config)?))
@@ -419,7 +452,18 @@ macro_rules! vision_normal_model_loader {
#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
- ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
+ (
+ $paths:expr,
+ $dtype:expr,
+ $device:expr,
+ $config:expr,
+ $loader:expr,
+ $use_flash_attn:expr,
+ $silent:expr,
+ $mapper:expr,
+ $loading_isq:expr,
+ $real_device:expr
+ ) => {{
let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
safetensors_paths.push($paths.get_classifier_path().as_ref().unwrap());
let vb = from_mmaped_safetensors(
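
The net effect of the get_uqff_paths! rework above is that the UQFF artifact is always resolved against the pipeline's own model repo and revision on the Hub, instead of being parsed out of a `<model id>/<file>` style path. A rough standalone equivalent using the hf-hub crate (token handling and the local-file short-circuit of api_get_file! omitted; the function name and the "model.uqff" filename are assumptions for illustration):

    use hf_hub::api::sync::ApiBuilder;
    use hf_hub::{Repo, RepoType};

    fn fetch_uqff(model_id: &str, revision: &str, file: &str) -> anyhow::Result<std::path::PathBuf> {
        let api = ApiBuilder::new().with_progress(true).build()?;
        let repo = api.repo(Repo::with_revision(
            model_id.to_string(),
            RepoType::Model,
            revision.to_string(),
        ));
        // Downloads (or reuses a cached copy of) the requested file from the model repo.
        Ok(repo.get(file)?)
    }

    // Usage sketch: fetch_uqff("some-org/some-model", "main", "model.uqff")
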
diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs
index 49c22a277..7ef391988 100644
--- a/mistralrs-core/src/pipeline/normal.rs
+++ b/mistralrs-core/src/pipeline/normal.rs
@@ -18,6 +18,7 @@ use crate::amoe::AnyMoeExpertType;
use crate::lora::Ordering;
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
+use crate::pipeline::isq::UqffFullSer;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::{get_chat_template, Cache};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
@@ -28,9 +29,9 @@ use crate::utils::tokenizer::get_tokenizer;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::xlora_models::NonGranularState;
use crate::{
- api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_write_uqff_paths,
- lora_model_loader, normal_model_loader, xlora_model_loader, DeviceMapMetadata,
- PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
+ api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
+ normal_model_loader, xlora_model_loader, DeviceMapMetadata, PagedAttentionConfig, Pipeline,
+ Topology, TryIntoDType,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
@@ -59,6 +60,10 @@ pub struct NormalPipeline {
topology: Option<Topology>,
silent: bool,
organization: IsqOrganization,
+ // For full UQFF serialization
+ template_filename: Option<PathBuf>,
+ generation_config: Option<PathBuf>,
+ config: String,
}
/// A loader for a "normal" (non-quantized) model.
@@ -75,6 +80,7 @@ pub struct NormalLoader {
tgt_non_granular_index: Option<usize>,
token_source: RwLock<Option<TokenSource>>,