feat: 切向力算法更新 + AD反解x模块

This commit is contained in:
lenn
2026-05-25 14:44:31 +08:00
parent e52c86ea1a
commit 011bfe2450
35 changed files with 31833 additions and 79 deletions

162
src-tauri/src/ad_solver.rs Normal file
View File

@@ -0,0 +1,162 @@
/// AD值反解x计算器
/// AD = -5.732*x^3 - 131.5*x^2 + 31980*x + 13490 (x <= 6.57)
/// AD = -377.8*x^2 + 26040*x + 51120 (x > 6.57)
const X_BOUNDARY: f64 = 6.57;
/// 二次方程在边界处的AD值
/// 当 x = 6.57 时AD = -377.8*6.57^2 + 26040*6.57 + 51120
const AD_BOUNDARY: f64 = 205895.10;
/// 二次方程求解器
/// -377.8*x^2 + 26040*x + 51120 = ad
/// 返回 x > 6.57 的那个解
fn solve_quadratic(ad: f64) -> Option<f64> {
let a = -377.8;
let b = 26040.0;
let c = 51120.0 - ad;
let discriminant = b * b - 4.0 * a * c;
if discriminant < 0.0 {
return None;
}
let sqrt_d = discriminant.sqrt();
let x1 = (-b + sqrt_d) / (2.0 * a);
let x2 = (-b - sqrt_d) / (2.0 * a);
// 选择 x > 6.57 的解(只可能有一个解满足这个条件)
if x1 > X_BOUNDARY && x1 > 0.0 {
Some(x1)
} else if x2 > X_BOUNDARY && x2 > 0.0 {
Some(x2)
} else {
None
}
}
/// 计算三次多项式的值
/// f(x) = -5.732*x^3 - 131.5*x^2 + 31980*x + 13490
fn cubic_value(x: f64) -> f64 {
-5.732 * x.powi(3) - 131.5 * x.powi(2) + 31980.0 * x + 13490.0
}
/// 使用二分法求解三次方程 (x <= 6.57)
/// 三次方程在 [0, 6.57] 范围内是单调递增的
fn solve_cubic_bisection(ad: f64) -> Option<f64> {
let mut low = 0.0;
let mut high = X_BOUNDARY;
let target = ad;
// 检查目标是否在范围内
let low_ad = cubic_value(low);
let high_ad = cubic_value(high);
if target < low_ad.min(high_ad) || target > low_ad.max(high_ad) {
return None;
}
for _i in 0..100 {
let mid = (low + high) / 2.0;
let mid_ad = cubic_value(mid);
if (high - low).abs() < 1e-10 {
return Some((low + high) / 2.0);
}
if mid_ad > target {
high = mid;
} else {
low = mid;
}
}
Some((low + high) / 2.0)
}
/// 主求解函数根据AD值反解x
pub fn solve_for_x(ad: f64) -> Option<f64> {
// 如果 AD <= 边界值,使用三次方程 (x <= 6.57)
// 如果 AD > 边界值,使用二次方程 (x > 6.57)
if ad <= AD_BOUNDARY {
return solve_cubic_bisection(ad);
}
// AD > 边界值,使用二次方程
solve_quadratic(ad)
}
/// 批量求解,用于验证所有解
pub fn solve_for_x_all(ad: f64) -> Vec<f64> {
let mut results = Vec::new();
// 三次方程解
if let Some(x) = solve_cubic_bisection(ad) {
results.push(x);
}
// 二次方程解
if let Some(x) = solve_quadratic(ad) {
results.push(x);
}
results
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cubic_forward() {
// 测试 x <= 6.57 的正向计算
let x = 5.0;
let ad = cubic_value(x);
println!("x={}, ad={}", x, ad);
let solved = solve_for_x(ad).unwrap();
println!("solved={}", solved);
assert!((solved - x).abs() < 0.01, "x={}, solved={}", x, solved);
}
#[test]
fn test_quadratic_forward() {
// 测试 x > 6.57 的正向计算
let x = 10.0;
let ad = -377.8 * x * x + 26040.0 * x + 51120.0;
let solved = solve_for_x(ad).unwrap();
assert!((solved - x).abs() < 0.01, "x={}, solved={}", x, solved);
}
#[test]
fn test_boundary() {
// 测试边界值
let x = 6.57;
let ad_cubic = cubic_value(x);
let ad_quad = -377.8 * x * x + 26040.0 * x + 51120.0;
println!("x=6.57 时三次方程 AD = {:.2}", ad_cubic);
println!("x=6.57 时二次方程 AD = {:.2}", ad_quad);
println!("边界值 AD_BOUNDARY = {:.2}", AD_BOUNDARY);
// 边界处两个公式应该有显著差异
assert!((ad_cubic - ad_quad).abs() > 100.0);
}
#[test]
fn test_known_values() {
// 测试一些已知值
let test_cases = [
(0.0, cubic_value(0.0)),
(3.0, cubic_value(3.0)),
(6.0, cubic_value(6.0)),
(8.0, -377.8 * 8.0 * 8.0 + 26040.0 * 8.0 + 51120.0),
(15.0, -377.8 * 15.0 * 15.0 + 26040.0 * 15.0 + 51120.0),
];
for (x, ad) in test_cases {
let solved = solve_for_x(ad).unwrap();
assert!((solved - x).abs() < 0.01, "x={}, ad={}, solved={}", x, ad, solved);
}
}
}

View File

@@ -278,10 +278,16 @@ async fn run_grpc_upload(
angle: message.angle,
};
::log::debug!(
"python pzt angle: seq={} dts_ms={} angle={:.2}",
message.seq,
message.dts_ms,
message.angle
"devkit: angle={:.2}, magnitude={:.4}, state={}, cop_x={:.4}, cop_y={:.4}, base_x={:.4}, base_y={:.4}, total_press={:.2}, thresh={:.2}",
message.angle,
message.magnitude,
message.state,
message.cop_x,
message.cop_y,
message.base_x,
message.base_y,
message.total_press,
message.threshold
);
app.emit("devkit_pzt_angle", payload)?;
} else {

View File

@@ -1,3 +1,4 @@
pub mod ad_solver;
mod commands;
mod lan_game;
pub mod log;

View File

@@ -77,11 +77,12 @@ impl TactileACodec {
.chunks_exact(2)
.map(|chunk| {
let raw = u16::from_le_bytes([chunk[0], chunk[1]]) as i32;
if raw < 15 {
0
} else {
raw
}
raw
// if raw < 15 {
// 0
// } else {
// raw
// }
})
.collect::<Vec<i32>>();

View File

@@ -22,7 +22,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{self, Duration, MissedTickBehavior};
use tokio_serial::SerialStream;
use tokio_util::sync::CancellationToken;
use crate::ad_solver::solve_for_x;
const AUTO_SUB_INTERVAL: Duration = Duration::from_nanos(16_666_667);
pub enum PollMode<F> {
@@ -316,10 +316,10 @@ where
{
let pzt_values = vals.iter().map(|value| *value as f32).collect::<Vec<f32>>();
if let Ok(analysis) = pzt_processor.get_pzt_analysis(&pzt_values) {
debug!(
"spatial force: angle={:.2}°, magnitude={:.2}, dx={:.2}, dy={:.2}",
analysis.angle_deg, analysis.magnitude, analysis.planar_x, analysis.planar_y
);
// debug!(
// "spatial force: angle={:.2}°, magnitude={:.2}, dx={:.2}, dy={:.2}",
// analysis.angle_deg, analysis.magnitude, analysis.planar_x, analysis.planar_y
// );
if PztProcessor::should_report(&analysis) {
spatial_force = Some(HudSpatialForce {
angle_deg: analysis.angle_deg,
@@ -333,6 +333,7 @@ where
{
let summary = vals.iter().copied().sum::<i32>();
let force = raw_to_g1(summary as u32);
push_devkit_frame(&app, vals.as_slice(), frame.dts_ms(), force);
}
@@ -358,6 +359,8 @@ fn build_display_values(
) -> Option<Vec<i32>> {
let summary = values.iter().copied().sum::<i32>();
let force = raw_to_g1(summary as u32);
// let force_solve = solve_for_x(summary as f64)?;
// println!("force_solve: {force_solve}");
chart_state.record_summary(force as f32);
chart_state.record_pressure_matrix(values);
chart_state.record_spatial_force(spatial_force);
@@ -417,12 +420,12 @@ fn infer_matrix_shape(len: usize) -> (u32, u32) {
}
fn raw_to_g1(raw: u32) -> f64 {
const X: [u32; 12] = [
0, 84402, 117218, 140176, 159126, 175812, 191484, 208758, 224703, 252448, 302361, 352703,
const X: [u32; 13] = [
0, 16811, 41350, 79241, 94615, 127446, 149559, 175900, 195056, 237852, 267810, 322472, 378511,
];
const Y: [f64; 12] = [
0.0, 160.0, 260.0, 360.0, 460.0, 560.0, 660.0, 760.0, 860.0, 1060.0, 1560.0, 2060.0,
const Y: [f64; 13] = [
0.0, 57.0, 97.0, 197.0, 257.0, 357.0, 457.0, 557.0, 657.0, 857.0, 1057.0, 1557.0, 2057.0,
];
let n = X.len();