585 lines
20 KiB
Python
585 lines
20 KiB
Python
"""
|
||
JE-Skin DevKit — Python gRPC Sensor Server
|
||
|
||
提供两个服务:
|
||
1. SensorPush (streaming) — 接收实时传感器帧
|
||
2. ExportProcessor (unary) — 处理导出的 CSV 文件:梯度过滤、xlsx 转换
|
||
|
||
安装依赖:
|
||
pip install grpcio grpcio-tools openpyxl
|
||
|
||
生成 gRPC 代码:
|
||
python -m grpc_tools.protoc -I../src-tauri/proto --python_out=. --grpc_python_out=. ../src-tauri/proto/sensor_stream.proto
|
||
|
||
启动:
|
||
python sensor_server.py [--port 50051]
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import csv
|
||
import os
|
||
import signal
|
||
import statistics
|
||
import sys
|
||
import time
|
||
from concurrent import futures
|
||
from pathlib import Path
|
||
import grpc
|
||
import sensor_stream_pb2
|
||
import sensor_stream_pb2_grpc
|
||
|
||
# ── 梯度过滤逻辑(来自用户的 main.py) ─────────────────────────
|
||
|
||
|
||
def load_rows(path: Path) -> list[list[str]]:
    """Read *path* as CSV and return every non-empty row.

    The file is decoded with ``utf-8-sig`` (a leading BOM is stripped) and
    lines that parse to an empty row are discarded.
    """
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        # filter(None, ...) drops falsy rows, i.e. the [] produced by blank lines.
        return list(filter(None, csv.reader(handle)))
|
||
|
||
|
||
def row_sum(row: list[str]) -> float:
    """Return the sum of all non-blank cells after the first column.

    The first column is excluded (presumably an index/timestamp column —
    confirm against the export format). Blank or whitespace-only cells are
    skipped; any other non-numeric cell raises ValueError from ``float``.
    """
    return sum(map(float, filter(str.strip, row[1:])))
|
||
|
||
|
||
def find_threshold(sum_values: list[float]) -> float:
    """Split row sums at the largest gap between neighbouring sorted values.

    The values are sorted, the widest gap between consecutive entries is
    located (the first one wins on ties), and the midpoint of that gap is
    returned.

    Raises:
        ValueError: if fewer than two values are given.
    """
    if len(sum_values) < 2:
        raise ValueError("At least two rows are required.")

    ordered = sorted(sum_values)
    gaps = [upper - lower for lower, upper in zip(ordered, ordered[1:])]
    widest = max(range(len(gaps)), key=gaps.__getitem__)
    return (ordered[widest] + ordered[widest + 1]) / 2.0
|
||
|
||
|
||
def extract_press_groups(
    rows: list[list[str]], sum_values: list[float], threshold: float
) -> tuple[list[list[str]], list[float]]:
    """Keep rows whose sum is >= *threshold* and average each consecutive run.

    Args:
        rows: Parsed CSV rows, aligned index-by-index with *sum_values*.
        sum_values: Per-row sums.
        threshold: Inclusive lower bound for a row to be kept.

    Returns:
        ``(kept_rows, run_means)``. ``kept_rows`` holds the retained rows in
        their original order. ``run_means`` holds the mean sum of every
        maximal run of consecutive retained rows.
    """
    kept_rows: list[list[str]] = []
    run_means: list[float] = []
    run: list[float] = []

    def close_run() -> None:
        # Emit the mean of the current run (if any) and start a new one.
        if run:
            run_means.append(statistics.fmean(run))
            run.clear()

    for row, total in zip(rows, sum_values):
        if total >= threshold:
            kept_rows.append(row)
            run.append(total)
        else:
            close_run()
    close_run()

    return kept_rows, run_means
|
||
|
||
|
||
def write_csv(path: Path, rows: list[list[str]]) -> Path:
    """Write *rows* to ``<stem>_filtered.csv`` beside *path* and return that path.

    The file is encoded as ``utf-8-sig`` (UTF-8 with a BOM) and any existing
    file of that name is overwritten.
    """
    target = path.with_name(path.stem + "_filtered.csv")
    with open(target, "w", newline="", encoding="utf-8-sig") as handle:
        writer = csv.writer(handle)
        writer.writerows(rows)
    return target
|
||
|
||
|
||
def write_xlsx(path: Path, rows: list[list[str]], stats: dict) -> Path:
    """Write the filtered rows and a statistics sheet to ``<stem>_filtered.xlsx``.

    Sheet "Filtered Data" receives *rows*. Every cell that parses as a finite
    number is stored as a float; everything else is kept as text. Sheet
    "Statistics" lists source_file, rows_total, rows_kept, groups_used,
    mean_value, threshold and process_time from *stats*, with defaults for
    missing keys.

    Args:
        path: Source CSV path; the workbook is written next to it.
        rows: Rows of string cells.
        stats: Summary values for the Statistics sheet.

    Returns:
        Path of the written workbook.

    Raises:
        RuntimeError: if openpyxl is not installed.
    """
    try:
        import openpyxl
        from openpyxl.styles import Font, PatternFill
    except ImportError as err:
        raise RuntimeError("openpyxl is required for xlsx output. Install it with: pip install openpyxl") from err
    import math

    def to_cell(text: str):
        # Let float() decide what is numeric. A character whitelist accepts
        # strings like "2024-01-01" or "1.2.3" that float() then rejects.
        try:
            number = float(text)
        except ValueError:
            return text
        # Excel has no NaN/Inf, so keep such strings as text.
        return number if math.isfinite(number) else text

    wb = openpyxl.Workbook()

    # Sheet 1: filtered data
    ws_data = wb.active
    ws_data.title = "Filtered Data"
    for row in rows:
        ws_data.append([to_cell(c) for c in row])

    # Sheet 2: statistics
    ws_stats = wb.create_sheet("Statistics")
    header_font = Font(bold=True, size=11)
    header_fill = PatternFill(start_color="E0E0E0", end_color="E0E0E0", fill_type="solid")

    ws_stats.append(["Parameter", "Value"])
    for ref in ("A1", "B1"):
        ws_stats[ref].font = header_font
        ws_stats[ref].fill = header_fill

    stats_rows = [
        ("Source File", stats.get("source_file", "")),
        ("Total Rows", stats.get("rows_total", 0)),
        ("Filtered Rows", stats.get("rows_kept", 0)),
        ("Groups Used", stats.get("groups_used", 0)),
        ("Mean Value", f"{stats.get('mean_value', 0):.3f}"),
        ("Threshold", f"{stats.get('threshold', 0):.3f}"),
        ("Process Time", stats.get("process_time", "")),
    ]
    for label, value in stats_rows:
        ws_stats.append([label, value])

    ws_stats.column_dimensions["A"].width = 18
    ws_stats.column_dimensions["B"].width = 30

    out = path.with_name(f"{path.stem}_filtered.xlsx")
    wb.save(str(out))
    return out
|
||
|
||
|
||
def process_csv(csv_path: str, save_as_xlsx: bool) -> dict:
    """Apply the gradient (largest-gap) filter to a CSV file and write the result.

    Processing steps:

    * load the rows;
    * sum each row with row_sum();
    * split at the threshold from find_threshold();
    * keep the rows at or above it and average each consecutive run of kept
      rows (extract_press_groups()).

    Output:
        * ``save_as_xlsx=True``: ``<stem>_filtered.xlsx`` is written. The
          source CSV is then deleted (best effort; failures are logged).
        * ``save_as_xlsx=False``: the filtered CSV replaces the source file.
          If the replacement fails, ``<stem>_filtered.csv`` is returned
          instead.

    A summary row is then appended to devkit_analysis_results.xlsx. The
    outputs already exist at that point, so a failure there is reported in
    ``message`` rather than failing the call.

    Returns:
        dict with ok, output_path, groups_used, mean_value, threshold,
        rows_total, rows_kept and message.

    Raises:
        FileNotFoundError: *csv_path* is not a file.
        ValueError: empty file, fewer than two rows, non-numeric cells, or no
            rows above the threshold.
        RuntimeError: xlsx output requested but openpyxl is not installed.
    """
    path = Path(csv_path)
    if not path.is_file():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    rows = load_rows(path)
    if not rows:
        raise ValueError("CSV file is empty.")

    sum_values = [row_sum(r) for r in rows]
    threshold = find_threshold(sum_values)
    filtered_rows, group_means = extract_press_groups(rows, sum_values, threshold)

    if not filtered_rows:
        raise ValueError("No large press-down data was found.")

    overall_mean = statistics.fmean(group_means)

    stats = {
        "source_file": path.name,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "process_time": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

    if save_as_xlsx:
        output_path = write_xlsx(path, filtered_rows, stats)
        # The data now lives in the workbook; removing the source CSV is best effort.
        try:
            path.unlink()
        except OSError as err:
            print(f"[ExportProcessor] Could not delete source CSV {path}: {err}")
    else:
        output_path = write_csv(path, filtered_rows)
        # Swap the filtered file into place in one step. Path.replace
        # overwrites the destination (also on Windows), so the source is never
        # deleted before the filtered copy exists under its name.
        try:
            output_path.replace(path)
            output_path = path
        except OSError as err:
            print(f"[ExportProcessor] Could not replace {path} with {output_path}: {err}")

    message = "OK"
    # Append a summary row to the aggregate xlsx log (secondary output).
    try:
        _append_analysis_log(csv_path, stats)
    except Exception as err:  # e.g. log workbook locked by Excel or corrupted
        print(f"[ExportProcessor] Analysis log update failed: {err}")
        message = f"OK (analysis log not updated: {err})"

    return {
        "ok": True,
        "output_path": str(output_path),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "message": message,
    }
|
||
|
||
|
||
def _append_analysis_log(source_csv: str, stats: dict):
|
||
"""将处理结果追加到 devkit_analysis_results.xlsx"""
|
||
try:
|
||
import openpyxl
|
||
except ImportError:
|
||
return # openpyxl 不可用时跳过
|
||
|
||
log_path = Path(source_csv).parent / "devkit_analysis_results.xlsx"
|
||
|
||
if log_path.exists():
|
||
wb = openpyxl.load_workbook(str(log_path))
|
||
ws = wb.active
|
||
else:
|
||
wb = openpyxl.Workbook()
|
||
ws = wb.active
|
||
ws.title = "Analysis Log"
|
||
ws.append(["Time", "Source File", "Total Rows", "Kept Rows",
|
||
"Groups", "Mean Value", "Threshold", "Output File"])
|
||
|
||
ws.append([
|
||
stats.get("process_time", ""),
|
||
stats.get("source_file", ""),
|
||
stats.get("rows_total", 0),
|
||
stats.get("rows_kept", 0),
|
||
stats.get("groups_used", 0),
|
||
round(stats.get("mean_value", 0), 3),
|
||
round(stats.get("threshold", 0), 3),
|
||
f"{Path(stats.get('source_file', '')).stem}_filtered",
|
||
])
|
||
|
||
wb.save(str(log_path))
|
||
|
||
|
||
# ── gRPC 服务实现 ────────────────────────────────────────────────
|
||
|
||
|
||
class SensorPushServicer(sensor_stream_pb2_grpc.SensorPushServicer):
    """Receive a stream of real-time sensor frames and stream back PZT angles.

    Every Upload stream is logged to a fresh ``sensor_log_<timestamp>.csv``
    in the current working directory.

    NOTE(review): serve() registers a single instance. Per-stream state lives
    on ``self``, and get_pzt_angle() keeps module-level state, so concurrent
    Upload streams would interfere with each other. Presumably only one
    client connects at a time — confirm.
    """

    _csv_path = None  # class attribute: path of the CSV file currently being written

    def __init__(self):
        self.frame_count = 0
        self.last_report_time = time.time()
        self.last_angle = None
        self._csv_file = None
        self._csv_writer = None

    def _open_csv(self):
        """Open a new timestamped CSV log and write its header row."""
        ts = time.strftime("%Y%m%d_%H%M%S")
        SensorPushServicer._csv_path = os.path.join(os.getcwd(), f"sensor_log_{ts}.csv")
        self._csv_file = open(SensorPushServicer._csv_path, "w", newline="", encoding="utf-8-sig")
        self._csv_writer = csv.writer(self._csv_file)
        header = ["seq", "timestamp_ms", "dts_ms", "angle", "magnitude", "state", "cop_x", "cop_y", "base_x", "base_y", "resultant_force"] + [f"ch{i}" for i in range(SENSOR_ROWS * SENSOR_COLS)]
        self._csv_writer.writerow(header)
        self._csv_file.flush()
        print(f"[SensorPush] CSV logging to: {SensorPushServicer._csv_path}")

    def _close_csv(self):
        """Close the CSV log if one is open. Safe to call more than once."""
        if self._csv_file:
            self._csv_file.close()
            print(f"[SensorPush] CSV saved: {SensorPushServicer._csv_path}")
        self._csv_file = None
        self._csv_writer = None

    def Upload(self, request_iterator, context):
        """Consume sensor frames and yield one PztAngleResponse per frame.

        For each frame whose matrix has SENSOR_ROWS * SENSOR_COLS values, the
        PZT angle is computed with get_pzt_angle(). Otherwise the response is
        marked ok=False. Every frame is appended to the CSV log.

        The log is closed in a ``finally`` clause, so it is released and
        flushed however the stream ends: normal end, client cancellation, or
        an error raised by the request iterator.
        """
        print("[SensorPush] Client connected, waiting for frames...")
        reset_baseline()
        self.last_angle = None
        self.frame_count = 0
        # Measure FPS from the start of this stream, not from server start-up.
        self.last_report_time = time.time()
        self._open_csv()

        try:
            for frame in request_iterator:
                self.frame_count += 1
                angle = 0.0
                magnitude = 0.0
                state = 0
                ok = True
                message = "OK"
                cop_x = cop_y = base_x = base_y = 0.0
                total_press = 0.0
                threshold = 0.0
                if len(frame.matrix) == SENSOR_ROWS * SENSOR_COLS:
                    try:
                        angle, magnitude, state, cop_x, cop_y, base_x, base_y, total_press, threshold = get_pzt_angle(frame.matrix)
                        if threshold is None:
                            # get_pzt_angle() reports None while its dynamic
                            # threshold is still being calibrated. Use 0.0 so
                            # the ":.2f" format below (and the float response
                            # field) do not fail and flag the frame as an error.
                            threshold = 0.0
                        self.last_angle = angle
                        print(f"devkit: angle={angle:.2f}, magnitude={magnitude:.4f}, state={state}, cop_x={cop_x:.4f}, cop_y={cop_y:.4f}, base_x={base_x:.4f}, base_y={base_y:.4f}, total_press={total_press:.2f}, thresh={threshold:.2f}")
                    except Exception as e:
                        ok = False
                        message = str(e)
                        print(f"[SensorPush] PZT compute error on frame #{frame.seq}: {e}")
                else:
                    ok = False
                    message = f"Invalid matrix length: {len(frame.matrix)}"
                    print(f"[Recv #{frame.seq}] INVALID len={len(frame.matrix)}")

                # Append this frame to the CSV log; flush every 10 frames.
                if self._csv_writer:
                    row = [frame.seq, frame.timestamp_ms, frame.dts_ms,
                           f"{angle:.4f}", f"{magnitude:.4f}", state,
                           f"{cop_x:.4f}", f"{cop_y:.4f}", f"{base_x:.4f}", f"{base_y:.4f}",
                           frame.resultant_force]
                    row += list(frame.matrix)
                    self._csv_writer.writerow(row)
                    if self.frame_count % 10 == 0:
                        self._csv_file.flush()

                yield sensor_stream_pb2.PztAngleResponse(
                    seq=frame.seq,
                    timestamp_ms=frame.timestamp_ms,
                    angle=angle,
                    dts_ms=frame.dts_ms,
                    ok=ok,
                    message=message,
                    magnitude=magnitude,
                    state=state,
                    cop_x=cop_x,
                    cop_y=cop_y,
                    base_x=base_x,
                    base_y=base_y,
                    total_press=total_press,
                    threshold=threshold,
                )

                # Periodic throughput report every 100 frames.
                if self.frame_count % 100 == 0:
                    now = time.time()
                    elapsed = now - self.last_report_time
                    fps = 100 / elapsed if elapsed > 0 else 0
                    self.last_report_time = now
                    angle_text = (
                        f"{self.last_angle:.2f}"
                        if self.last_angle is not None
                        else "n/a"
                    )
                    print(
                        f"[SensorPush] Frame #{frame.seq} | "
                        f"{frame.rows}x{frame.cols} | "
                        f"angle={angle_text} | "
                        f"force={frame.resultant_force:.1f} | "
                        f"total={self.frame_count} | ~{fps:.1f} fps"
                    )
        finally:
            self._close_csv()
            print(f"[SensorPush] Stream ended. Total: {self.frame_count}")
|
||
|
||
|
||
class ExportProcessorServicer(sensor_stream_pb2_grpc.ExportProcessorServicer):
    """Unary service that post-processes an exported CSV file with process_csv()."""

    def ProcessFile(self, request, context):
        """Filter ``request.csv_path`` (optionally as xlsx) and report the outcome.

        Any Exception raised while processing or building the response is
        turned into a response with ``ok=False`` and the exception text as
        ``message``.
        """
        csv_path, save_as_xlsx = request.csv_path, request.save_as_xlsx

        print(f"[ExportProcessor] Processing: {csv_path} (xlsx={save_as_xlsx})")

        try:
            result = process_csv(csv_path, save_as_xlsx)
            fields = {
                key: result[key]
                for key in ("ok", "output_path", "groups_used", "mean_value",
                            "threshold", "rows_total", "rows_kept", "message")
            }
            return sensor_stream_pb2.ProcessResponse(**fields)
        except Exception as exc:
            print(f"[ExportProcessor] Error: {exc}")
            return sensor_stream_pb2.ProcessResponse(ok=False, output_path="", message=str(exc))
|
||
|
||
|
||
# ── 启动 ────────────────────────────────────────────────────────
|
||
|
||
|
||
def serve(port: int, host: str = "0.0.0.0", max_workers: int = 4):
    """Start the gRPC server and block until it has shut down.

    SIGINT/SIGTERM trigger a graceful stop. In-flight RPCs get up to 5 s to
    finish, and this function returns once the server has terminated.

    Args:
        port: TCP port to listen on.
        host: Interface to bind. The default "0.0.0.0" accepts connections on
            all interfaces; ProcessFile reads and deletes files by path, so
            "127.0.0.1" is the safer choice when only local clients are needed.
        max_workers: Size of the thread pool that runs RPC handlers.

    Raises:
        RuntimeError: if the listen address cannot be bound.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    sensor_stream_pb2_grpc.add_SensorPushServicer_to_server(SensorPushServicer(), server)
    sensor_stream_pb2_grpc.add_ExportProcessorServicer_to_server(ExportProcessorServicer(), server)

    listen_addr = f"{host}:{port}"
    # Older grpcio versions signal a bind failure by returning 0 instead of raising.
    if server.add_insecure_port(listen_addr) == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")
    server.start()

    print(f"[DevKit Server] gRPC listening on {listen_addr}")
    print("[DevKit Server] Services: SensorPush (streaming), ExportProcessor (unary)")

    def shutdown(signum, frame):
        print("\n[DevKit Server] Shutting down...")
        # stop() is non-blocking. wait_for_termination() below returns once
        # the grace period has elapsed or all RPCs have completed, so the grace
        # period is honoured instead of being cut short by sys.exit().
        server.stop(grace=5)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    server.wait_for_termination()
|
||
|
||
|
||
import numpy as np
from collections import deque

# ===================== Algorithm parameters =====================
COP_INIT_MEDIAN_FRAMES = 1   # frames whose median CoP becomes the initial contact point
NOISE_COLLECT_FRAMES = 10    # frame totals collected to calibrate the dynamic threshold
THRESH_K = 5                 # dynamic threshold = THRESH_K * mean(collected totals)
SENSOR_ROWS, SENSOR_COLS = 12, 7

# ===================== Post-init ("second settle") refinement parameters =====================
POST_INIT_WINDOW_CNT = 60000     # frames after initialisation during which refinement may occur
POST_INIT_STABLE_CNT = 100       # consecutive stable frames required to adopt a refined baseline
POST_INIT_STABLE_THRESH = 0.1    # max CoP movement (sensor-grid index units) still counted as stable

# ===================== Mutable module-level state =====================
# NOTE: plain globals with no locking -- NOT thread-safe. Updated by
# compute_pressure_direction() and cleared by reset_cop_state().
first_contact_CoP_x = first_contact_CoP_y = None
contact_initialized = False

# Candidate CoP samples used to build the initial contact point.
cop_init_x_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)
cop_init_y_buf = deque(maxlen=COP_INIT_MEDIAN_FRAMES)

# Dynamic pressure threshold (None until calibrated).
noise_sum_buf = deque(maxlen=NOISE_COLLECT_FRAMES)
dynamic_thresh = None

# Post-init refinement state.
post_init_frame_cnt = 0
post_stable_cnt = 0
post_refined_flag = False
post_cand_x = post_cand_y = None
|
||
|
||
|
||
# ===================== Reset CoP state =====================
def reset_cop_state():
    """Forget the contact baseline so it is rebuilt from the next pressed frames.

    Clears the initial-contact CoP, its candidate buffers and all post-init
    refinement counters. The dynamic pressure threshold (``dynamic_thresh``
    and ``noise_sum_buf``) is not reset here.
    """
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y

    for buf in (cop_init_x_buf, cop_init_y_buf):
        buf.clear()

    first_contact_CoP_x = first_contact_CoP_y = None
    post_cand_x = post_cand_y = None
    contact_initialized = post_refined_flag = False
    post_init_frame_cnt = post_stable_cnt = 0
|
||
|
||
|
||
# ===================== Centre-of-pressure (CoP) computation =====================
def compute_pressure_direction(raw_frame):
    """Compute one frame's centre of pressure and its offset from the contact baseline.

    Stateful: reads and updates module-level globals (dynamic threshold,
    initial-contact CoP, post-init refinement counters). Results therefore
    depend on the frames seen since the last reset_cop_state().

    Args:
        raw_frame: SENSOR_ROWS * SENSOR_COLS pressure values. They are
            flattened and reshaped to (SENSOR_ROWS, SENSOR_COLS) in row-major
            order.

    Returns:
        A 14-tuple ``(cop_x, cop_y, 0, rows-1, 0, cols-1, delta_x, delta_y,
        base_x, base_y, magnitude, state, total_pressure, dynamic_thresh)``:

        * cop_x / cop_y: pressure-weighted column / row index.
        * delta_x = cop_x - base_x and delta_y = base_y - cop_y (the row axis
          is inverted); both are 0 on the frame that initialises the baseline.
        * magnitude = hypot(delta_x, delta_y).
        * state: 0 = no baseline yet, 1 = baseline set and refinement still
          possible, 2 = baseline final.
        * dynamic_thresh: None until NOISE_COLLECT_FRAMES frame totals have
          been collected.

        On low pressure (total == 0, or below a calibrated threshold) an
        all-zero tuple is returned instead. Its total_pressure slot is 0.0,
        not the measured sum, and dynamic_thresh is the current value.
    """
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global post_init_frame_cnt, post_stable_cnt, post_refined_flag
    global post_cand_x, post_cand_y
    global noise_sum_buf, dynamic_thresh

    rows, cols = SENSOR_ROWS, SENSOR_COLS
    frame_flat = np.asarray(raw_frame, dtype=np.float32).flatten()
    frame2d = frame_flat.reshape(rows, cols)

    total_pressure = np.sum(frame2d)

    # Dynamic threshold: until calibrated, collect every frame's total. Once
    # NOISE_COLLECT_FRAMES totals are available, fix the threshold at
    # THRESH_K * their mean. Frames seen during calibration are included
    # whether or not they are pressed.
    if dynamic_thresh is None:
        noise_sum_buf.append(total_pressure)
        if len(noise_sum_buf) >= NOISE_COLLECT_FRAMES:
            sums = np.array(noise_sum_buf)
            dynamic_thresh = THRESH_K * float(np.mean(sums))

    # Low pressure: drop the contact baseline (only once the threshold is
    # calibrated) and report an all-zero result.
    if total_pressure == 0 or (dynamic_thresh is not None and total_pressure < dynamic_thresh):
        if contact_initialized and dynamic_thresh is not None:
            reset_cop_state()
        return 0.0, 0.0, 0, rows-1, 0, cols-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0.0, dynamic_thresh

    # CoP = pressure-weighted mean of column (x) and row (y) indices.
    x_grid = np.tile(np.arange(cols), (rows, 1))
    y_grid = np.repeat(np.arange(rows), cols).reshape(rows, cols)
    cop_x = np.sum(frame2d * x_grid) / total_pressure
    cop_y = np.sum(frame2d * y_grid) / total_pressure

    # Defaults used while the baseline is being initialised: zero offset and
    # the current CoP reported as the base.
    delta_CoP_x = 0.0
    delta_CoP_y = 0.0
    base_x = cop_x
    base_y = cop_y

    # ============ Initial contact point (median of the first pressed frames) ============
    if not contact_initialized:
        cop_init_x_buf.append(cop_x)
        cop_init_y_buf.append(cop_y)

        if len(cop_init_x_buf) >= COP_INIT_MEDIAN_FRAMES:
            first_contact_CoP_x = float(np.median(cop_init_x_buf))
            first_contact_CoP_y = float(np.median(cop_init_y_buf))
            contact_initialized = True
            cop_init_x_buf.clear()
            cop_init_y_buf.clear()

    # ========== Offset from the baseline ==========
    else:
        # Post-init ("second settle") refinement: within the first
        # POST_INIT_WINDOW_CNT frames after initialisation, if the CoP stays
        # within POST_INIT_STABLE_THRESH of a candidate point for
        # POST_INIT_STABLE_CNT consecutive frames, that candidate becomes
        # the new baseline. Refinement ends on success or when the window
        # expires.
        post_init_frame_cnt += 1
        if not post_refined_flag and post_init_frame_cnt <= POST_INIT_WINDOW_CNT:
            if post_cand_x is not None:
                dist_val = np.hypot(cop_x - post_cand_x, cop_y - post_cand_y)
                if dist_val <= POST_INIT_STABLE_THRESH:
                    post_stable_cnt += 1
                else:
                    # CoP moved: restart the stability count at the new position.
                    post_cand_x = cop_x
                    post_cand_y = cop_y
                    post_stable_cnt = 1
            else:
                post_cand_x = cop_x
                post_cand_y = cop_y
                post_stable_cnt = 1

            if post_stable_cnt >= POST_INIT_STABLE_CNT:
                first_contact_CoP_x = post_cand_x
                first_contact_CoP_y = post_cand_y
                post_refined_flag = True
        else:
            # Already refined or window expired: the baseline is final.
            post_refined_flag = True

        # Row axis is inverted for y (base minus current).
        delta_CoP_x = cop_x - first_contact_CoP_x
        delta_CoP_y = first_contact_CoP_y - cop_y
        base_x = first_contact_CoP_x
        base_y = first_contact_CoP_y

    magnitude = np.hypot(delta_CoP_x, delta_CoP_y)
    if not contact_initialized:
        state = 0
    elif not post_refined_flag:
        state = 1
    else:
        state = 2

    return (cop_x, cop_y,
            0, rows-1, 0, cols-1,
            delta_CoP_x, delta_CoP_y,
            base_x, base_y,
            magnitude, state,
            total_pressure, dynamic_thresh)
|
||
|
||
|
||
# ===================== Angle computation core =====================
def compute_vector_angle(x: float, y: float) -> tuple[float, float]:
    """Return ``(angle_deg, magnitude)`` of the 2-D vector (x, y).

    The angle is measured counter-clockwise from +x and mapped from
    arctan2's (-180, 180] range into 0..360 degrees. A tiny bias is added to
    x before arctan2, so x == 0 (including -0.0) is treated as slightly
    positive.
    """
    magnitude = np.hypot(x, y)
    degrees = np.degrees(np.arctan2(y, x + 1e-8))
    return (degrees + 360 if degrees < 0 else degrees), magnitude
|
||
|
||
def compute_PZT_angle(Px: float, Py: float) -> tuple[float, float]:
    """Return ``(angle_deg, magnitude)`` of the PZT offset vector (Px, Py).

    Delegates to compute_vector_angle().
    """
    angle_deg, magnitude = compute_vector_angle(Px, Py)
    return angle_deg, magnitude
|
||
|
||
|
||
# ===================== Core entry point =====================
def get_pzt_angle(adc_data):
    """Feed one ADC frame through the CoP tracker and return the PZT angle.

    Args:
        adc_data: Sequence of SENSOR_ROWS * SENSOR_COLS pressure samples.

    Returns:
        ``(pzt_angle, magnitude, state, cop_x, cop_y, base_x, base_y,
        total_press, threshold)``. ``pzt_angle`` is the direction in degrees
        (0..360) of the CoP offset from the contact baseline. ``state`` is an
        int. The other values are passed through from
        compute_pressure_direction(). ``threshold`` is None until the dynamic
        threshold has been calibrated.

    Raises:
        ValueError: if *adc_data* does not hold exactly SENSOR_ROWS * SENSOR_COLS values.
    """
    # Derive the expected length from the grid constants so validation can
    # never disagree with the reshape done in compute_pressure_direction().
    expected = SENSOR_ROWS * SENSOR_COLS
    if len(adc_data) != expected:
        raise ValueError(f"ADC数据长度必须为{expected}")

    (cop_x, cop_y,
     _, _, _, _,
     dx, dy,
     base_x, base_y,
     magnitude, state,
     total_press, threshold) = compute_pressure_direction(adc_data)

    pzt_angle, _ = compute_PZT_angle(dx, dy)
    return pzt_angle, magnitude, int(state), cop_x, cop_y, base_x, base_y, total_press, threshold
|
||
|
||
|
||
# ===================== Baseline reset (calibration) =====================
def reset_baseline():
    """Reset the CoP contact baseline so it is re-initialised from the next pressed frames.

    Delegates to reset_cop_state(). The dynamic pressure threshold is left as is.
    """
    reset_cop_state()
    return None
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the listen port and run the server.
    cli = argparse.ArgumentParser(description="JE-Skin DevKit gRPC Server")
    cli.add_argument(
        "--port",
        type=int,
        default=50051,
        help="gRPC listen port (default: 50051)",
    )
    serve(cli.parse_args().port)
|