Files
JE-Skin/devkit/sensor_server.py

468 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
JE-Skin DevKit — Python gRPC Sensor Server
提供两个服务:
1. SensorPush (streaming) — 接收实时传感器帧
2. ExportProcessor (unary) — 处理导出的 CSV 文件梯度过滤、xlsx 转换
安装依赖:
pip install grpcio grpcio-tools openpyxl
生成 gRPC 代码:
python -m grpc_tools.protoc -I../src-tauri/proto --python_out=. --grpc_python_out=. ../src-tauri/proto/sensor_stream.proto
启动:
python sensor_server.py [--port 50051]
"""
from __future__ import annotations
import argparse
import csv
import os
import signal
import statistics
import sys
import time
from concurrent import futures
from pathlib import Path
import grpc
import sensor_stream_pb2
import sensor_stream_pb2_grpc
# ── Gradient-filter logic (adapted from the user's main.py) ─────────────────
def load_rows(path: Path) -> list[list[str]]:
    """Read *path* as CSV and return every non-empty row.

    The file is decoded as UTF-8 with an optional BOM (``utf-8-sig``). Blank
    lines, which ``csv.reader`` yields as empty lists, are dropped.
    """
    collected: list[list[str]] = []
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        for record in csv.reader(handle):
            if record:
                collected.append(record)
    return collected
def row_sum(row: list[str]) -> float:
    """Return the sum of the numeric cells of *row*, ignoring its first column.

    The first column is skipped (presumably an index/timestamp column — confirm
    against the export format). Blank or whitespace-only cells are ignored;
    any other cell must parse with ``float`` or ``ValueError`` propagates.
    """
    numeric_cells = [float(cell) for cell in row[1:] if cell.strip()]
    return sum(numeric_cells)
def find_threshold(sum_values: list[float]) -> float:
    """Split *sum_values* at their widest gap and return the gap's midpoint.

    The values are sorted and the pair of neighbours with the largest
    difference is located (the first such pair wins on ties). The returned
    threshold lies halfway between them.

    Raises:
        ValueError: if fewer than two values are given.
    """
    if len(sum_values) < 2:
        raise ValueError("At least two rows are required.")
    ordered = sorted(sum_values)
    best_low, best_high = ordered[0], ordered[1]
    for low, high in zip(ordered, ordered[1:]):
        # Strict '>' keeps the earliest pair when gaps are equal.
        if high - low > best_high - best_low:
            best_low, best_high = low, high
    return (best_low + best_high) / 2.0
def extract_press_groups(
    rows: list[list[str]], sum_values: list[float], threshold: float
) -> tuple[list[list[str]], list[float]]:
    """Keep rows whose sum reaches *threshold* and average each contiguous run.

    ``rows`` and ``sum_values`` are walked in lockstep. A row is kept when its
    sum is ``>= threshold``. Consecutive kept rows form one press group, and
    the mean of each group's sums is recorded when the group ends (at the next
    rejected row or at the end of the input).

    Returns:
        ``(kept_rows, group_means)``.
    """
    kept_rows: list[list[str]] = []
    run_means: list[float] = []
    active_run: list[float] = []

    def close_run() -> None:
        # Record the finished group (if any) and start a new one.
        if active_run:
            run_means.append(statistics.fmean(active_run))
            active_run.clear()

    for record, total in zip(rows, sum_values):
        if total >= threshold:
            kept_rows.append(record)
            active_run.append(total)
        else:
            close_run()
    close_run()
    return kept_rows, run_means
def write_csv(path: Path, rows: list[list[str]]) -> Path:
    """Write *rows* to ``<stem>_filtered.csv`` beside *path* and return that path.

    The file is encoded as UTF-8 with BOM (``utf-8-sig``), the same encoding
    ``load_rows`` reads.
    """
    target = path.with_name(path.stem + "_filtered.csv")
    with target.open("w", encoding="utf-8-sig", newline="") as handle:
        writer = csv.writer(handle)
        for record in rows:
            writer.writerow(record)
    return target
def _xlsx_cell(value: str):
    """Return *value* as a float when it parses as a finite number, else unchanged.

    Used for the data sheet so numeric CSV cells become numeric xlsx cells.
    Anything ``float`` rejects (labels, "1.2.3", superscript digits, ...) and
    non-finite values ("nan", "inf") are kept as text rather than raising.
    """
    import math

    try:
        number = float(value)
    except ValueError:
        return value
    return number if math.isfinite(number) else value


def write_xlsx(path: Path, rows: list[list[str]], stats: dict) -> Path:
    """Write filtered rows and processing statistics to ``<stem>_filtered.xlsx``.

    The workbook has two sheets:
      * "Filtered Data": one row per CSV row, with numeric cells stored as numbers.
      * "Statistics": parameter/value pairs taken from *stats*. Mean and
        threshold are written as text with three decimals.

    Args:
        path: source CSV path; the workbook is written next to it.
        rows: filtered CSV rows (strings).
        stats: summary dict as built by ``process_csv``; missing keys fall
            back to empty/zero values.

    Returns:
        Path of the written workbook.

    Raises:
        RuntimeError: if openpyxl is not installed.
    """
    try:
        import openpyxl
        from openpyxl.styles import Font, PatternFill
    except ImportError as exc:
        raise RuntimeError(
            "openpyxl is required for xlsx output. Install it with: pip install openpyxl"
        ) from exc

    wb = openpyxl.Workbook()

    # Sheet 1: filtered data.
    ws_data = wb.active
    ws_data.title = "Filtered Data"
    for row in rows:
        # The previous digit-only check let strings such as "1.2.3" or "²"
        # through to float() and crashed the whole export. Convert defensively.
        ws_data.append([_xlsx_cell(c) for c in row])

    # Sheet 2: statistics, with a bold grey header row.
    ws_stats = wb.create_sheet("Statistics")
    header_font = Font(bold=True, size=11)
    header_fill = PatternFill(start_color="E0E0E0", end_color="E0E0E0", fill_type="solid")
    ws_stats.append(["Parameter", "Value"])
    for ref in ("A1", "B1"):
        ws_stats[ref].font = header_font
        ws_stats[ref].fill = header_fill

    stats_rows = [
        ("Source File", stats.get("source_file", "")),
        ("Total Rows", stats.get("rows_total", 0)),
        ("Filtered Rows", stats.get("rows_kept", 0)),
        ("Groups Used", stats.get("groups_used", 0)),
        ("Mean Value", f"{stats.get('mean_value', 0):.3f}"),
        ("Threshold", f"{stats.get('threshold', 0):.3f}"),
        ("Process Time", stats.get("process_time", "")),
    ]
    for label, value in stats_rows:
        ws_stats.append([label, value])
    ws_stats.column_dimensions["A"].width = 18
    ws_stats.column_dimensions["B"].width = 30

    out = path.with_name(f"{path.stem}_filtered.xlsx")
    wb.save(str(out))
    return out
def process_csv(csv_path: str, save_as_xlsx: bool) -> dict:
    """Gradient-filter an exported CSV and return summary statistics.

    Steps:
      1. Load the rows and sum each row (first column excluded).
      2. Split at the widest gap in the sorted sums.
      3. Keep the rows at or above that threshold, and average each contiguous
         run of kept rows.

    Output:
      * ``save_as_xlsx=True``: writes ``<stem>_filtered.xlsx`` and deletes the
        source CSV (best effort; a failed delete only prints a warning).
      * otherwise: writes the filtered CSV and replaces the source CSV with it.
        ``Path.replace`` is used so the source is never left missing.

    A summary row is then appended to ``devkit_analysis_results.xlsx`` in the
    same directory. This is best effort: the source file has already been
    rewritten at that point, so a log failure only prints a warning.

    Returns:
        dict with keys ok, output_path, groups_used, mean_value, threshold,
        rows_total, rows_kept, message.

    Raises:
        FileNotFoundError: if *csv_path* is not a file.
        ValueError: if the CSV is empty, has fewer than two rows, has a
            non-numeric cell outside the first column, or no row reaches the
            threshold. NOTE(review): a header row would trigger the
            non-numeric case; confirm exports have no header.
        RuntimeError: if xlsx output is requested but openpyxl is missing.
    """
    path = Path(csv_path)
    if not path.is_file():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    rows = load_rows(path)
    if not rows:
        raise ValueError("CSV file is empty.")

    sum_values = [row_sum(r) for r in rows]
    threshold = find_threshold(sum_values)
    filtered_rows, group_means = extract_press_groups(rows, sum_values, threshold)
    if not filtered_rows:
        raise ValueError("No large press-down data was found.")
    # group_means is non-empty whenever filtered_rows is.
    overall_mean = statistics.fmean(group_means)

    stats = {
        "source_file": path.name,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "process_time": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

    if save_as_xlsx:
        output_path = write_xlsx(path, filtered_rows, stats)
        # The xlsx supersedes the raw CSV; failing to delete it is not fatal.
        try:
            path.unlink()
        except OSError as exc:
            print(f"[ExportProcessor] Warning: could not delete source CSV {path}: {exc}")
    else:
        filtered_path = write_csv(path, filtered_rows)
        # Path.replace (os.replace) overwrites the source in one step. The old
        # unlink-then-rename could lose the source if the rename failed.
        try:
            filtered_path.replace(path)
            output_path = path
        except OSError as exc:
            print(f"[ExportProcessor] Warning: could not replace {path} with {filtered_path}: {exc}")
            output_path = filtered_path

    # Lets the analysis log record the file that actually holds the result.
    stats["output_file"] = output_path.name
    try:
        _append_analysis_log(csv_path, stats)
    except Exception as exc:  # auxiliary log; the export itself has succeeded
        print(f"[ExportProcessor] Warning: could not update analysis log: {exc}")

    return {
        "ok": True,
        "output_path": str(output_path),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "message": "OK",
    }
def _append_analysis_log(source_csv: str, stats: dict):
    """Append one summary row to ``devkit_analysis_results.xlsx``.

    The log workbook is kept in the directory of *source_csv*. On first use it
    is created with an "Analysis Log" sheet and a header row. Does nothing
    when openpyxl is not installed.

    The "Output File" column uses ``stats["output_file"]`` when present.
    Otherwise it falls back to ``<source stem>_filtered``.
    """
    try:
        import openpyxl
    except ImportError:
        return  # the log is optional when openpyxl is unavailable

    log_path = Path(source_csv).parent / "devkit_analysis_results.xlsx"
    if log_path.exists():
        wb = openpyxl.load_workbook(str(log_path))
        # Prefer the log sheet even if the user left another sheet active.
        ws = wb["Analysis Log"] if "Analysis Log" in wb.sheetnames else wb.active
    else:
        wb = openpyxl.Workbook()
        ws = wb.active
        ws.title = "Analysis Log"
        ws.append(["Time", "Source File", "Total Rows", "Kept Rows",
                   "Groups", "Mean Value", "Threshold", "Output File"])

    # In CSV mode the filtered output is renamed back to the source name, so
    # the derived "<stem>_filtered" name would be wrong there; use the real
    # name when the caller supplies it.
    output_file = stats.get("output_file") or f"{Path(stats.get('source_file', '')).stem}_filtered"
    ws.append([
        stats.get("process_time", ""),
        stats.get("source_file", ""),
        stats.get("rows_total", 0),
        stats.get("rows_kept", 0),
        stats.get("groups_used", 0),
        round(stats.get("mean_value", 0), 3),
        round(stats.get("threshold", 0), 3),
        output_file,
    ])
    wb.save(str(log_path))
# ── gRPC service implementations ──────────────────────────────────────────
class SensorPushServicer(sensor_stream_pb2_grpc.SensorPushServicer):
    """Receive real-time sensor frames (client stream) and stream back PZT angles.

    Attributes:
        frame_count: frames received across all streams since construction.
        last_report_time: start time of the current throughput window.
        last_angle: last successfully computed angle of the current stream.

    NOTE(review): baseline and CoP tracking live in module-level globals that
    every stream shares, while ``serve`` runs handlers on a 4-thread pool.
    Two concurrent Upload streams would therefore corrupt each other's state.
    Confirm that only one client streams at a time.
    """

    def __init__(self):
        self.frame_count = 0
        self.last_report_time = time.time()
        self.last_angle = None

    def Upload(self, request_iterator, context):
        """Yield one ``PztAngleResponse`` per incoming frame.

        Each stream begins by resetting the baseline and CoP state. A frame is
        answered with ``ok=False``, angle 0.0 and an error message when:
          * its matrix length is not ``SENSOR_ROWS * SENSOR_COLS``, or
          * the angle computation raises.
        In both cases the stream continues rather than aborting.
        """
        print("[SensorPush] Client connected, waiting for frames...")
        reset_baseline()
        self.last_angle = None
        # Count and time each stream separately. With only the cumulative
        # counter, the "first frames" log fired for the first stream only, and
        # the first fps window included idle time before the stream started.
        stream_frames = 0
        self.last_report_time = time.time()
        expected_len = SENSOR_ROWS * SENSOR_COLS
        try:
            for frame in request_iterator:
                self.frame_count += 1
                stream_frames += 1
                angle = 0.0
                ok = True
                message = "OK"
                if len(frame.matrix) != expected_len:
                    ok = False
                    message = f"Invalid matrix length: {len(frame.matrix)}"
                else:
                    try:
                        angle = get_pzt_angle(frame.matrix)
                    except Exception as e:  # report per frame, keep the stream alive
                        ok = False
                        message = str(e)
                        print(f"[SensorPush] PZT compute error on frame #{frame.seq}: {e}")
                    else:
                        self.last_angle = angle
                        if stream_frames <= 10 or stream_frames % 30 == 0:
                            print(
                                f"[SensorPush] PZT angle frame #{frame.seq} "
                                f"dts={frame.dts_ms} angle={angle:.2f}"
                            )
                yield sensor_stream_pb2.PztAngleResponse(
                    seq=frame.seq,
                    timestamp_ms=frame.timestamp_ms,
                    angle=angle,
                    dts_ms=frame.dts_ms,
                    ok=ok,
                    message=message,
                )
                if stream_frames % 100 == 0:
                    now = time.time()
                    elapsed = now - self.last_report_time
                    fps = 100 / elapsed if elapsed > 0 else 0
                    self.last_report_time = now
                    angle_text = (
                        f"{self.last_angle:.2f}"
                        if self.last_angle is not None
                        else "n/a"
                    )
                    print(
                        f"[SensorPush] Frame #{frame.seq} | "
                        f"{frame.rows}x{frame.cols} | "
                        f"angle={angle_text} | "
                        f"force={frame.resultant_force:.1f} | "
                        f"total={self.frame_count} | ~{fps:.1f} fps"
                    )
        finally:
            # Also runs if the response generator is closed early
            # (e.g. the stream is cancelled).
            print(f"[SensorPush] Stream ended. Frames: {stream_frames}, total: {self.frame_count}")
class ExportProcessorServicer(sensor_stream_pb2_grpc.ExportProcessorServicer):
    """Unary service that post-processes an exported CSV file.

    Delegates to ``process_csv`` and copies its result dict into a
    ``ProcessResponse``. Any exception, including one raised while building
    the response, is returned to the client as ``ok=False`` with the
    exception text as ``message``.
    """

    def ProcessFile(self, request, context):
        """Handle one ProcessFile RPC without letting exceptions escape."""
        csv_path, save_as_xlsx = request.csv_path, request.save_as_xlsx
        print(f"[ExportProcessor] Processing: {csv_path} (xlsx={save_as_xlsx})")
        response_fields = (
            "ok",
            "output_path",
            "groups_used",
            "mean_value",
            "threshold",
            "rows_total",
            "rows_kept",
            "message",
        )
        try:
            outcome = process_csv(csv_path, save_as_xlsx)
            return sensor_stream_pb2.ProcessResponse(
                **{field: outcome[field] for field in response_fields}
            )
        except Exception as e:
            print(f"[ExportProcessor] Error: {e}")
            return sensor_stream_pb2.ProcessResponse(
                ok=False,
                output_path="",
                message=str(e),
            )
# ── Startup ────────────────────────────────────────────────────────────────
def serve(port: int):
    """Start the gRPC server on all interfaces and block until it terminates.

    Registers the SensorPush and ExportProcessor services on a 4-worker
    thread pool. On SIGINT/SIGTERM, in-flight RPCs get up to 5 seconds to
    finish before the process exits.

    Raises:
        RuntimeError: if the port cannot be bound. Depending on the grpcio
            version this comes from grpc itself or from the zero-port check.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    sensor_stream_pb2_grpc.add_SensorPushServicer_to_server(SensorPushServicer(), server)
    sensor_stream_pb2_grpc.add_ExportProcessorServicer_to_server(ExportProcessorServicer(), server)
    listen_addr = f"0.0.0.0:{port}"
    bound_port = server.add_insecure_port(listen_addr)
    # Some grpcio versions signal a bind failure by returning 0 instead of raising.
    if bound_port == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")
    server.start()
    print(f"[DevKit Server] gRPC listening on 0.0.0.0:{bound_port}")
    print("[DevKit Server] Services: SensorPush (streaming), ExportProcessor (unary)")

    def shutdown(signum, frame):
        print("\n[DevKit Server] Shutting down...")
        # stop() returns a threading.Event immediately. Without waiting on it,
        # sys.exit ran at once and the 5 s grace period was never honored.
        server.stop(grace=5).wait()
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)
    server.wait_for_termination()
import numpy as np
import threading
# ===================== Algorithm parameters =====================
# A frame whose baseline-subtracted total pressure is below this value counts
# as a low-pressure frame.
TOTAL_PRESSURE_LOW_THRESHOLD = 500
# Consecutive low-pressure frames after which the CoP reference is reset.
COP_STABILITY_FRAMES_REQUIRED = 5
# Sensor grid geometry; each frame carries SENSOR_ROWS * SENSOR_COLS values.
SENSOR_ROWS, SENSOR_COLS = 12, 7

# ===================== Shared mutable state =====================
# Baseline frame (captured from the first frame after a reset), guarded by
# first_frame_lock.
first_frame = None
first_frame_lock = threading.Lock()
# Reference CoP and low-pressure counter used by compute_pressure_direction.
# NOTE(review): unlike first_frame, these are not protected by any lock.
first_contact_CoP_x = first_contact_CoP_y = None
contact_initialized = False
total_pressure_low_counter = 0
# ===================== Baseline subtraction =====================
def subtract_baseline(current_frame):
    """Return *current_frame* minus the stored baseline, with negatives clipped to 0.

    The frame is copied into a flat float32 array. If no baseline is stored
    (first call, or after ``reset_baseline``), this frame becomes the baseline,
    so the result for it is all zeros. The baseline is read and written while
    holding ``first_frame_lock``.
    """
    global first_frame
    sample = np.array(current_frame, dtype=np.float32).reshape(-1)
    with first_frame_lock:
        if first_frame is None:
            first_frame = sample.copy()
        baseline = first_frame
    return np.maximum(sample - baseline, 0)
# ===================== CoP state reset =====================
def reset_cop_state():
    """Forget the reference CoP and clear the low-pressure frame counter.

    After this call, the next contact processed by
    ``compute_pressure_direction`` establishes a new reference point.
    """
    global first_contact_CoP_x, first_contact_CoP_y
    global contact_initialized, total_pressure_low_counter
    first_contact_CoP_x = first_contact_CoP_y = None
    contact_initialized = False
    total_pressure_low_counter = 0
# ===================== Center-of-pressure (CoP) computation =====================
def compute_pressure_direction(baseline_subtracted_frame):
    """Return the displacement of the current CoP from the reference contact CoP.

    The frame is flattened and reshaped to ``(SENSOR_ROWS, SENSOR_COLS)``. The
    CoP is the pressure-weighted mean column index (x) and row index (y), so
    the displacement is in sensor-cell units.

    State handling (module globals):
      * ``total_pressure_low_counter`` counts consecutive frames whose total
        pressure is below ``TOTAL_PRESSURE_LOW_THRESHOLD``. Once it reaches
        ``COP_STABILITY_FRAMES_REQUIRED``, the CoP state is reset and
        ``(0.0, 0.0)`` is returned.
      * A frame with zero total pressure returns ``(0.0, 0.0)``.
      * The first frame that reaches the CoP computation after a reset becomes
        the reference; it returns ``(0.0, 0.0)``.

    Returns:
        ``(delta_x, delta_y)``: current CoP minus the reference CoP.

    NOTE(review): the reference can be set by a frame whose total pressure is
    non-zero but still below the low threshold (e.g. residual noise after
    baseline subtraction); confirm that this is intended.
    NOTE(review): the CoP globals are mutated without a lock.
    """
    global first_contact_CoP_x, first_contact_CoP_y, contact_initialized
    global total_pressure_low_counter
    rows, cols = SENSOR_ROWS, SENSOR_COLS
    frame_flat = np.asarray(baseline_subtracted_frame, dtype=np.float32).flatten()
    frame2d = frame_flat.reshape(rows, cols)
    total_pressure = np.sum(frame2d)
    # Track consecutive low-pressure frames; any higher-pressure frame resets the count.
    if total_pressure < TOTAL_PRESSURE_LOW_THRESHOLD:
        total_pressure_low_counter += 1
    else:
        total_pressure_low_counter = 0
    # Enough consecutive low frames: treat the contact as released.
    if total_pressure_low_counter >= COP_STABILITY_FRAMES_REQUIRED:
        reset_cop_state()
        return 0.0, 0.0
    # Avoid dividing by zero below.
    if total_pressure == 0:
        return 0.0, 0.0
    # Per-cell column index (x) and row index (y) grids.
    x_grid = np.tile(np.arange(cols), (rows, 1))
    y_grid = np.repeat(np.arange(rows), cols).reshape(rows, cols)
    cop_x = np.sum(frame2d * x_grid) / total_pressure
    cop_y = np.sum(frame2d * y_grid) / total_pressure
    delta_CoP_x = 0.0
    delta_CoP_y = 0.0
    if not contact_initialized:
        # This frame becomes the reference, so its displacement is zero.
        first_contact_CoP_x = cop_x
        first_contact_CoP_y = cop_y
        contact_initialized = True
    else:
        delta_CoP_x = cop_x - first_contact_CoP_x
        delta_CoP_y = cop_y - first_contact_CoP_y
    return delta_CoP_x, delta_CoP_y
# ===================== Angle computation =====================
def compute_vector_angle(x: float, y: float) -> tuple[float, float]:
    """Return ``(angle_deg, magnitude)`` for the vector ``(x, y)``.

    The angle is measured from the +x axis toward +y, in degrees. Negative
    results are shifted up by 360, so the angle is non-negative. A tiny
    epsilon is added to ``x`` before ``arctan2``.
    """
    magnitude = np.hypot(x, y)
    heading = np.degrees(np.arctan2(y, x + 1e-8))
    return (heading + 360 if heading < 0 else heading), magnitude
def compute_PZT_angle(Px: float, Py: float) -> tuple[float, float]:
    """Return ``(angle_deg, magnitude)`` of the shift ``(Px, Py)`` with y negated.

    ``Py`` is flipped before the angle is taken. Presumably this is because
    row indices grow downward while the angle convention expects +y up;
    confirm against the sensor layout.
    """
    flipped_y = -Py
    return compute_vector_angle(Px, flipped_y)
# ===================== Main entry point =====================
def get_pzt_angle(adc_data):
    """Compute the PZT angle (degrees) for one raw ADC frame.

    Pipeline:
      1. Baseline subtraction.
      2. CoP displacement from the reference contact point.
      3. Angle of that displacement, with y flipped.

    Args:
        adc_data: sequence of exactly ``SENSOR_ROWS * SENSOR_COLS`` raw values.

    Raises:
        ValueError: if ``adc_data`` has the wrong length.
    """
    # Derived from the grid constants rather than hard-coded (was 84), so it
    # always matches compute_pressure_direction's reshape and the length
    # check in SensorPushServicer.Upload.
    expected_len = SENSOR_ROWS * SENSOR_COLS
    if len(adc_data) != expected_len:
        raise ValueError(f"ADC数据长度必须为{expected_len}")
    baseline_subtracted = subtract_baseline(adc_data)
    dx, dy = compute_pressure_direction(baseline_subtracted)
    pzt_angle, _ = compute_PZT_angle(dx, dy)
    return pzt_angle
# ===================== Baseline reset (calibration) =====================
def reset_baseline():
    """Discard the stored baseline frame and the CoP reference state.

    The next frame passed to ``subtract_baseline`` becomes the new baseline.
    The baseline is cleared while holding ``first_frame_lock``.
    """
    global first_frame
    with first_frame_lock:
        first_frame = None
    reset_cop_state()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="JE-Skin DevKit gRPC Server")
parser.add_argument("--port", type=int, default=50051, help="gRPC listen port (default: 50051)")
args = parser.parse_args()
serve(args.port)