添加机器学习,调用尚未完善.

This commit is contained in:
lennlouisgeek
2026-04-13 02:01:48 +08:00
parent 49c3e0736c
commit 26e9c41750
8 changed files with 680 additions and 2 deletions

316
src-tauri/Cargo.lock generated
View File

@@ -14,6 +14,8 @@ dependencies = [
"fern", "fern",
"humantime", "humantime",
"log", "log",
"ndarray",
"ort",
"serde", "serde",
"serde_json", "serde_json",
"tauri", "tauri",
@@ -247,6 +249,12 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
[[package]] [[package]]
name = "bit-set" name = "bit-set"
version = "0.8.0" version = "0.8.0"
@@ -556,7 +564,7 @@ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"core-foundation", "core-foundation",
"core-graphics-types", "core-graphics-types",
"foreign-types", "foreign-types 0.5.0",
"libc", "libc",
] ]
@@ -734,6 +742,16 @@ dependencies = [
"syn 2.0.117", "syn 2.0.117",
] ]
[[package]]
name = "der"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b"
dependencies = [
"pem-rfc7468",
"zeroize",
]
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.5.8" version = "0.5.8"
@@ -1071,6 +1089,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared 0.1.1",
]
[[package]] [[package]]
name = "foreign-types" name = "foreign-types"
version = "0.5.0" version = "0.5.0"
@@ -1078,7 +1105,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [ dependencies = [
"foreign-types-macros", "foreign-types-macros",
"foreign-types-shared", "foreign-types-shared 0.3.1",
] ]
[[package]] [[package]]
@@ -1092,6 +1119,12 @@ dependencies = [
"syn 2.0.117", "syn 2.0.117",
] ]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]] [[package]]
name = "foreign-types-shared" name = "foreign-types-shared"
version = "0.3.1" version = "0.3.1"
@@ -1576,6 +1609,12 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hmac-sha256"
version = "1.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec9d92d097f4749b64e8cc33d924d9f40a2d4eb91402b458014b781f5733d60f"
[[package]] [[package]]
name = "html5ever" name = "html5ever"
version = "0.29.1" version = "0.29.1"
@@ -2106,6 +2145,12 @@ version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "lzma-rust2"
version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69"
[[package]] [[package]]
name = "mac" name = "mac"
version = "0.1.1" version = "0.1.1"
@@ -2163,6 +2208,16 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5"
[[package]]
name = "matrixmultiply"
version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.8.0" version = "2.8.0"
@@ -2240,6 +2295,38 @@ dependencies = [
"windows-sys 0.60.2", "windows-sys 0.60.2",
] ]
[[package]]
name = "native-tls"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]] [[package]]
name = "ndk" name = "ndk"
version = "0.9.0" version = "0.9.0"
@@ -2305,12 +2392,30 @@ version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "num-conv" name = "num-conv"
version = "0.2.0" version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "num-traits" name = "num-traits"
version = "0.2.19" version = "0.2.19"
@@ -2483,6 +2588,50 @@ dependencies = [
"pathdiff", "pathdiff",
] ]
[[package]]
name = "openssl"
version = "0.10.76"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf"
dependencies = [
"bitflags 2.11.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "openssl-probe"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]]
name = "openssl-sys"
version = "0.9.112"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "option-ext" name = "option-ext"
version = "0.2.0" version = "0.2.0"
@@ -2499,6 +2648,30 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "ort"
version = "2.0.0-rc.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7de3af33d24a745ffb8fab904b13478438d1cd52868e6f17735ef6e1f8bf133"
dependencies = [
"ndarray",
"ort-sys",
"smallvec",
"tracing",
"ureq",
]
[[package]]
name = "ort-sys"
version = "2.0.0-rc.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7b497d21a8b6fbb4b5a544f8fadb77e801a09ae0add9e411d31c6f89e3c1e90"
dependencies = [
"hmac-sha256",
"lzma-rust2",
"ureq",
]
[[package]] [[package]]
name = "pango" name = "pango"
version = "0.18.3" version = "0.18.3"
@@ -2559,6 +2732,15 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "pem-rfc7468"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6305423e0e7738146434843d1694d621cce767262b2a86910beab705e4493d9"
dependencies = [
"base64ct",
]
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.2" version = "2.3.2"
@@ -2821,6 +3003,21 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "portable-atomic"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]]
name = "portable-atomic-util"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
dependencies = [
"portable-atomic",
]
[[package]] [[package]]
name = "potential_utf" name = "potential_utf"
version = "0.1.4" version = "0.1.4"
@@ -3046,6 +3243,12 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539"
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.18" version = "0.5.18"
@@ -3177,6 +3380,15 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "rustls-pki-types"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [
"zeroize",
]
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.22" version = "1.0.22"
@@ -3198,6 +3410,15 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "schannel"
version = "0.1.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939"
dependencies = [
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "schemars" name = "schemars"
version = "0.8.22" version = "0.8.22"
@@ -3255,6 +3476,29 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "3.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
dependencies = [
"bitflags 2.11.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "selectors" name = "selectors"
version = "0.24.0" version = "0.24.0"
@@ -3554,6 +3798,17 @@ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.61.2",
] ]
[[package]]
name = "socks"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
dependencies = [
"byteorder",
"libc",
"winapi",
]
[[package]] [[package]]
name = "softbuffer" name = "softbuffer"
version = "0.4.8" version = "0.4.8"
@@ -4507,6 +4762,36 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "ureq"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0"
dependencies = [
"base64 0.22.1",
"der",
"log",
"native-tls",
"percent-encoding",
"rustls-pki-types",
"socks",
"ureq-proto",
"utf8-zero",
"webpki-root-certs",
]
[[package]]
name = "ureq-proto"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c"
dependencies = [
"base64 0.22.1",
"http",
"httparse",
"log",
]
[[package]] [[package]]
name = "url" name = "url"
version = "2.5.8" version = "2.5.8"
@@ -4538,6 +4823,12 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf8-zero"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e"
[[package]] [[package]]
name = "utf8_iter" name = "utf8_iter"
version = "1.0.4" version = "1.0.4"
@@ -4556,6 +4847,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]] [[package]]
name = "version-compare" name = "version-compare"
version = "0.2.1" version = "0.2.1"
@@ -4809,6 +5106,15 @@ dependencies = [
"system-deps", "system-deps",
] ]
[[package]]
name = "webpki-root-certs"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca"
dependencies = [
"rustls-pki-types",
]
[[package]] [[package]]
name = "webview2-com" name = "webview2-com"
version = "0.38.2" version = "0.38.2"
@@ -5600,6 +5906,12 @@ dependencies = [
"synstructure", "synstructure",
] ]
[[package]]
name = "zeroize"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0"
[[package]] [[package]]
name = "zerotrie" name = "zerotrie"
version = "0.2.3" version = "0.2.3"

View File

@@ -33,3 +33,5 @@ humantime = "2.3.0"
csv = "1.4.0" csv = "1.4.0"
chrono = "0.4.44" chrono = "0.4.44"
crc = "3.4.0" crc = "3.4.0"
ort = "2.0.0-rc.12"
ndarray = "0.17.2"

Binary file not shown.

View File

@@ -0,0 +1,103 @@
{
"model_scope": "1d time-series tactile force compensation",
"channel_count": 84,
"feature_names": [
"sensor_sum",
"sensor_mean",
"sensor_max",
"sensor_min",
"sensor_range",
"sensor_std",
"top5_sum",
"baseline_sum",
"delta_from_baseline",
"naive_force_n",
"slope_1",
"slope_5",
"slope_20",
"slope_100",
"mean_5",
"mean_20",
"mean_100",
"mean_300",
"short_minus_mid",
"mid_minus_long",
"age_above_touch"
],
"feature_engineering": {
"current_row_stats": [
"sensor_sum",
"sensor_mean",
"sensor_max",
"sensor_min",
"sensor_range",
"sensor_std",
"top5_sum"
],
"rolling_mean_windows": [
5,
20,
100,
300
],
"slope_windows": [
1,
5,
20,
100
],
"topk_channels": 5,
"baseline_feature": "baseline_sum",
"age_above_touch_feature": "age_above_touch",
"naive_force_feature": "naive_force_n"
},
"inference_pipeline": [
"build 21 numeric features in the order listed by feature_names",
"run classifier.onnx or classifier.cbm",
"if predicted_state == 0 then output 0.0 N",
"otherwise run regressor.onnx or regressor.cbm and clamp to >= 0.0 N"
],
"calibration_points": [
{
"sensor_sum": 74602.73399170733,
"force_n": 0.9800000000000001
},
{
"sensor_sum": 105503.9038227644,
"force_n": 1.9600000000000002
},
{
"sensor_sum": 131459.57643184246,
"force_n": 2.94
},
{
"sensor_sum": 153512.34776297462,
"force_n": 3.9200000000000004
},
{
"sensor_sum": 172041.11212077862,
"force_n": 4.9
},
{
"sensor_sum": 193794.83789260528,
"force_n": 5.88
},
{
"sensor_sum": 218947.72467683573,
"force_n": 7.840000000000001
},
{
"sensor_sum": 240580.4449421614,
"force_n": 9.8
}
],
"training_args": {
"positive_step": 80,
"zero_step": 40,
"classifier_iterations": 500,
"regressor_iterations": 700,
"depth": 6,
"learning_rate": 0.05,
"random_seed": 42
}
}

Binary file not shown.

260
src-tauri/src/estimator.rs Normal file
View File

@@ -0,0 +1,260 @@
use anyhow::{Context, Result};
use ndarray::Array2;
use ort::session::Session;
use ort::value::TensorRef;
use serde::Deserialize;
use std::collections::VecDeque;
use std::fs;
use std::path::{Path, PathBuf};
const CHANNEL_COUNT: usize = 84;
const FEATURE_COUNT: usize = 21;
const HISTORY_WINDOW: usize = 300;
#[derive(Debug, Deserialize)]
struct CalibrationPoint {
sensor_sum: f32,
force_n: f32,
}
#[derive(Debug, Deserialize)]
struct ModelMetadata {
feature_names: Vec<String>,
calibration_points: Vec<CalibrationPoint>,
}
pub struct ForceEstimator {
classifier: Session,
regressor: Session,
metadata: ModelMetadata,
sum_history: VecDeque<f32>,
baseline_buffer: Vec<f32>,
baseline_sum: f32,
baseline_ready: bool,
touched: bool,
age_above_touch: u32,
touch_delta_threshold: f32,
}
impl ForceEstimator {
pub fn new(bundle_dir: impl AsRef<Path>) -> Result<Self> {
let bundle_dir = bundle_dir.as_ref();
let metadata_path = bundle_dir.join("model_metadata.json");
let classifier_path = bundle_dir.join("classifier.onnx");
let regressor_path = bundle_dir.join("regressor.onnx");
let metadata: ModelMetadata = serde_json::from_slice(
&fs::read(&metadata_path)
.with_context(|| format!("failed to read {}", metadata_path.display()))?,
)
.with_context(|| format!("failed to parse {}", metadata_path.display()))?;
if metadata.feature_names.len() != FEATURE_COUNT {
anyhow::bail!(
"expected {} features, got {}",
FEATURE_COUNT,
metadata.feature_names.len()
);
}
let classifier = Session::builder()?
.commit_from_file(&classifier_path)
.with_context(|| format!("failed to load {}", classifier_path.display()))?;
let regressor = Session::builder()?
.commit_from_file(&regressor_path)
.with_context(|| format!("failed to load {}", regressor_path.display()))?;
Ok(Self {
classifier,
regressor,
metadata,
sum_history: VecDeque::with_capacity(HISTORY_WINDOW),
baseline_buffer: Vec::with_capacity(100),
baseline_sum: 0.0,
baseline_ready: false,
touched: false,
age_above_touch: 0,
touch_delta_threshold: 1_000.0,
})
}
pub fn process_frame(&mut self, channels: [f32; CHANNEL_COUNT]) -> Result<f32> {
let sensor_sum: f32 = channels.iter().sum();
// Simple online baseline rule:
// collect the first 100 low-activity frames, then freeze the baseline.
if !self.baseline_ready && !self.touched {
self.baseline_buffer.push(sensor_sum);
if self.baseline_buffer.len() >= 100 {
let total: f32 = self.baseline_buffer.iter().sum();
self.baseline_sum = total / self.baseline_buffer.len() as f32;
self.baseline_ready = true;
}
}
if self.sum_history.len() == HISTORY_WINDOW {
self.sum_history.pop_front();
}
self.sum_history.push_back(sensor_sum);
let touch_threshold = self.baseline_sum + self.touch_delta_threshold;
if self.baseline_ready && sensor_sum >= touch_threshold {
self.touched = true;
self.age_above_touch += 1;
} else {
self.age_above_touch = 0;
}
let features = self.build_features(&channels)?;
let predicted_state = self.predict_state(&features)?;
if predicted_state == 0 {
return Ok(0.0);
}
let predicted_force = self.predict_force_value(&features)?;
Ok(predicted_force.max(0.0))
}
fn build_features(&self, channels: &[f32; CHANNEL_COUNT]) -> Result<[f32; FEATURE_COUNT]> {
let sensor_sum = *self
.sum_history
.back()
.context("cannot build features without at least one frame")?;
let sensor_mean = sensor_sum / CHANNEL_COUNT as f32;
let mut sorted = channels.to_vec();
sorted.sort_by(|a, b| b.partial_cmp(a).unwrap());
let top5_sum: f32 = sorted.iter().take(5).sum();
let sensor_max = *sorted.first().unwrap_or(&0.0);
let sensor_min = channels
.iter()
.fold(f32::INFINITY, |acc, &value| acc.min(value));
let sensor_range = sensor_max - sensor_min;
let sensor_std = population_std(channels, sensor_mean);
let naive_force_n = interpolate_force(sensor_sum, &self.metadata.calibration_points);
let mean_5 = rolling_mean(&self.sum_history, 5);
let mean_20 = rolling_mean(&self.sum_history, 20);
let mean_100 = rolling_mean(&self.sum_history, 100);
let mean_300 = rolling_mean(&self.sum_history, 300);
Ok([
sensor_sum,
sensor_mean,
sensor_max,
sensor_min,
sensor_range,
sensor_std,
top5_sum,
self.baseline_sum,
sensor_sum - self.baseline_sum,
naive_force_n,
slope_from_history(&self.sum_history, 1),
slope_from_history(&self.sum_history, 5),
slope_from_history(&self.sum_history, 20),
slope_from_history(&self.sum_history, 100),
mean_5,
mean_20,
mean_100,
mean_300,
mean_5 - mean_20,
mean_20 - mean_100,
self.age_above_touch as f32,
])
}
fn predict_state(&mut self, features: &[f32; FEATURE_COUNT]) -> Result<i64> {
let input = Array2::from_shape_vec((1, FEATURE_COUNT), features.to_vec())?;
let outputs = self
.classifier
.run(ort::inputs![TensorRef::from_array_view(&input)?]?)?;
let first = outputs
.iter()
.next()
.context("classifier returned no outputs")?;
let tensor = first.try_extract_tensor::<i64>()?;
let value = tensor
.view()
.iter()
.next()
.copied()
.context("classifier output tensor was empty")?;
Ok(value)
}
fn predict_force_value(&mut self, features: &[f32; FEATURE_COUNT]) -> Result<f32> {
let input = Array2::from_shape_vec((1, FEATURE_COUNT), features.to_vec())?;
let outputs = self
.regressor
.run(ort::inputs![TensorRef::from_array_view(&input)?]?)?;
let first = outputs
.iter()
.next()
.context("regressor returned no outputs")?;
let tensor = first.try_extract_tensor::<f32>()?;
let value = tensor
.view()
.iter()
.next()
.copied()
.context("regressor output tensor was empty")?;
Ok(value)
}
}
fn population_std(values: &[f32], mean: f32) -> f32 {
if values.len() <= 1 {
return 0.0;
}
let variance: f32 = values
.iter()
.map(|value| {
let diff = *value - mean;
diff * diff
})
.sum::<f32>()
/ values.len() as f32;
variance.sqrt()
}
fn rolling_mean(history: &VecDeque<f32>, window: usize) -> f32 {
if history.is_empty() {
return 0.0;
}
let count = window.min(history.len());
let sum: f32 = history.iter().rev().take(count).copied().sum();
sum / count as f32
}
fn slope_from_history(history: &VecDeque<f32>, lookback: usize) -> f32 {
if history.is_empty() {
return 0.0;
}
let current = *history.back().unwrap();
let previous_index = history.len().saturating_sub(lookback + 1);
let previous = history.get(previous_index).copied().unwrap_or(current);
let steps = (history.len() - 1).saturating_sub(previous_index).max(1);
(current - previous) / steps as f32
}
fn interpolate_force(sensor_sum: f32, points: &[CalibrationPoint]) -> f32 {
if points.is_empty() {
return 0.0;
}
if sensor_sum <= points[0].sensor_sum {
return points[0].force_n;
}
if sensor_sum >= points[points.len() - 1].sensor_sum {
return points[points.len() - 1].force_n;
}
for pair in points.windows(2) {
let left = &pair[0];
let right = &pair[1];
if sensor_sum >= left.sensor_sum && sensor_sum <= right.sensor_sum {
let ratio = (sensor_sum - left.sensor_sum) / (right.sensor_sum - left.sensor_sum);
return left.force_n + ratio * (right.force_n - left.force_n);
}
}
points[points.len() - 1].force_n
}

View File

@@ -1,4 +1,5 @@
mod commands; mod commands;
pub mod estimator;
pub mod log; pub mod log;
pub mod serial_core; pub mod serial_core;
use commands::serial::SerialConnectionState; use commands::serial::SerialConnectionState;