Files
JE-Skin/devkit/sensor_server.py

323 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
JE-Skin DevKit — Python gRPC Sensor Server
提供两个服务:
1. SensorPush (streaming) — 接收实时传感器帧
2. ExportProcessor (unary) — 处理导出的 CSV 文件(梯度过滤、xlsx 转换)
安装依赖:
pip install grpcio grpcio-tools openpyxl
生成 gRPC 代码:
python -m grpc_tools.protoc -I../src-tauri/proto --python_out=. --grpc_python_out=. ../src-tauri/proto/sensor_stream.proto
启动:
python sensor_server.py [--port 50051]
"""
from __future__ import annotations
import argparse
import csv
import math
import os
import signal
import statistics
import sys
import threading
import time
from concurrent import futures
from pathlib import Path
import grpc
import sensor_stream_pb2
import sensor_stream_pb2_grpc
# ── 梯度过滤逻辑(来自用户的 main.py) ─────────────────────────
def load_rows(path: Path) -> list[list[str]]:
    """Read *path* as a UTF-8 CSV and return its non-empty rows.

    A leading byte-order mark is tolerated (``utf-8-sig``). Rows that the csv
    reader yields as empty lists (blank lines) are dropped.
    """
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        # filter(None, ...) keeps only truthy rows, i.e. non-empty lists.
        return list(filter(None, csv.reader(handle)))
def row_sum(row: list[str]) -> float:
    """Return the sum of the row's cells, skipping the first column.

    Blank or whitespace-only cells are ignored. Raises ValueError if a
    remaining cell is not a valid ``float`` literal. An empty selection sums
    to ``0``.
    """
    # filter(str.strip, ...) keeps cells whose stripped text is non-empty.
    return sum(map(float, filter(str.strip, row[1:])))
def find_threshold(sum_values: list[float]) -> float:
    """Return the midpoint of the widest gap between adjacent sorted values.

    When several gaps tie for widest, the lowest one is used, because ``max``
    keeps the first maximum it sees.

    Raises:
        ValueError: if fewer than two values are supplied.
    """
    if len(sum_values) < 2:
        raise ValueError("At least two rows are required.")
    ordered = sorted(sum_values)
    neighbours = zip(ordered, ordered[1:])
    lower, upper = max(neighbours, key=lambda pair: pair[1] - pair[0])
    return (lower + upper) / 2.0
def extract_press_groups(
    rows: list[list[str]], sum_values: list[float], threshold: float
) -> tuple[list[list[str]], list[float]]:
    """Keep rows whose sum reaches *threshold* and average each consecutive run.

    *rows* and *sum_values* are walked in lockstep. Every row with
    ``total >= threshold`` is kept. A maximal run of such consecutive rows
    forms one group, and the mean of each group's totals is reported.

    Returns:
        ``(kept_rows, group_means)`` in input order.
    """
    kept_rows: list[list[str]] = []
    means: list[float] = []
    run: list[float] = []

    def close_run() -> None:
        # Finish the current run (if any) by recording its mean.
        if run:
            means.append(statistics.fmean(run))
            run.clear()

    for row, total in zip(rows, sum_values):
        if total >= threshold:
            kept_rows.append(row)
            run.append(total)
        else:
            close_run()
    close_run()
    return kept_rows, means
def write_csv(path: Path, rows: list[list[str]]) -> Path:
    """Save *rows* as ``<stem>_filtered.csv`` beside *path* and return that path.

    The file is written as UTF-8 with a byte-order mark (``utf-8-sig``).
    """
    destination = path.with_name(path.stem + "_filtered.csv")
    with destination.open("w", encoding="utf-8-sig", newline="") as stream:
        writer = csv.writer(stream)
        for record in rows:
            writer.writerow(record)
    return destination
def _coerce_cell(value: str) -> float | str:
    """Return *value* as a float if it parses as a finite number, else unchanged.

    ``float`` itself is the parser. Any literal it accepts (signs, exponents,
    surrounding whitespace) becomes a number. Text it rejects (e.g. a date
    such as ``"2024-01-01"``, ``"1.2.3"``, ``""``) stays a string instead of
    raising. ``nan``/``inf`` spellings also stay text, so no non-finite value
    is written into the workbook.
    """
    try:
        number = float(value)
    except ValueError:
        return value
    return number if math.isfinite(number) else value


def write_xlsx(path: Path, rows: list[list[str]], stats: dict) -> Path:
    """Write filtered rows and a statistics summary to ``<stem>_filtered.xlsx``.

    Sheet "Filtered Data" receives *rows*; cells that parse as finite numbers
    are stored as numbers. Sheet "Statistics" receives the summary fields
    read from *stats*, with a bold, shaded header row.

    Args:
        path: Source CSV path; the workbook is created beside it.
        rows: Filtered CSV rows.
        stats: Summary dict as built by ``process_csv``. Missing keys fall
            back to defaults.

    Returns:
        Path of the saved workbook.

    Raises:
        RuntimeError: if openpyxl is not installed.
    """
    try:
        import openpyxl
    except ImportError as exc:
        raise RuntimeError("openpyxl is required for xlsx output. Install it with: pip install openpyxl") from exc
    from openpyxl.styles import Font, PatternFill

    wb = openpyxl.Workbook()

    # Sheet 1: the filtered rows.
    ws_data = wb.active
    ws_data.title = "Filtered Data"
    for row in rows:
        ws_data.append([_coerce_cell(c) for c in row])

    # Sheet 2: summary statistics.
    ws_stats = wb.create_sheet("Statistics")
    header_font = Font(bold=True, size=11)
    header_fill = PatternFill(start_color="E0E0E0", end_color="E0E0E0", fill_type="solid")
    ws_stats.append(["Parameter", "Value"])
    for ref in ("A1", "B1"):
        ws_stats[ref].font = header_font
        ws_stats[ref].fill = header_fill

    stats_rows = [
        ("Source File", stats.get("source_file", "")),
        ("Total Rows", stats.get("rows_total", 0)),
        ("Filtered Rows", stats.get("rows_kept", 0)),
        ("Groups Used", stats.get("groups_used", 0)),
        ("Mean Value", f"{stats.get('mean_value', 0):.3f}"),
        ("Threshold", f"{stats.get('threshold', 0):.3f}"),
        ("Process Time", stats.get("process_time", "")),
    ]
    for label, value in stats_rows:
        ws_stats.append([label, value])
    ws_stats.column_dimensions["A"].width = 18
    ws_stats.column_dimensions["B"].width = 30

    out = path.with_name(f"{path.stem}_filtered.xlsx")
    wb.save(str(out))
    return out
def process_csv(csv_path: str, save_as_xlsx: bool) -> dict:
    """Run the gradient filter on *csv_path* and return a result summary.

    Steps:

    1. Load the CSV.
    2. Sum each row, excluding the first column.
    3. Split the sums at the widest gap (``find_threshold``).
    4. Keep the rows at or above the threshold.
    5. Write the output:

       * ``save_as_xlsx=True``: write ``<stem>_filtered.xlsx``, then delete
         the source CSV (best effort; a failure is printed as a warning).
       * ``save_as_xlsx=False``: write ``<stem>_filtered.csv``, then move it
         over the source CSV in one step (best effort; on failure the source
         is left untouched and the ``_filtered`` path is returned).

    Finally, one row is appended to ``devkit_analysis_results.xlsx`` next to
    the source file.

    Returns:
        dict with keys ok, output_path, groups_used, mean_value, threshold,
        rows_total, rows_kept, message.

    Raises:
        FileNotFoundError: if *csv_path* is not a file.
        ValueError: if the CSV is empty or has fewer than two rows, if a cell
            outside the first column is not numeric, or if no row is kept.
        RuntimeError: if xlsx output is requested without openpyxl.
    """
    path = Path(csv_path)
    if not path.is_file():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    rows = load_rows(path)
    if not rows:
        raise ValueError("CSV file is empty.")

    sum_values = [row_sum(r) for r in rows]
    threshold = find_threshold(sum_values)
    filtered_rows, group_means = extract_press_groups(rows, sum_values, threshold)
    if not filtered_rows:
        raise ValueError("No large press-down data was found.")
    overall_mean = statistics.fmean(group_means)

    stats = {
        "source_file": path.name,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "process_time": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

    if save_as_xlsx:
        output_path = write_xlsx(path, filtered_rows, stats)
        # The workbook supersedes the source CSV; removal is best effort.
        try:
            path.unlink()
        except OSError as exc:
            print(f"[ExportProcessor] Warning: could not delete source CSV {path}: {exc}")
    else:
        output_path = write_csv(path, filtered_rows)
        # Path.replace overwrites the target in a single step (also on Windows).
        # Unlike unlink()+rename(), a failure can never leave the source deleted
        # while the output still carries its "_filtered" name.
        try:
            output_path.replace(path)
            output_path = path
        except OSError as exc:
            print(f"[ExportProcessor] Warning: could not replace {path} with {output_path}: {exc}")

    # Lets the analysis log name the file that was actually produced.
    stats["output_file"] = output_path.name
    _append_analysis_log(csv_path, stats)

    return {
        "ok": True,
        "output_path": str(output_path),
        "groups_used": len(group_means),
        "mean_value": overall_mean,
        "threshold": threshold,
        "rows_total": len(rows),
        "rows_kept": len(filtered_rows),
        "message": "OK",
    }
def _append_analysis_log(source_csv: str, stats: dict) -> None:
    """Append one summary row to ``devkit_analysis_results.xlsx`` (best effort).

    The workbook lives in the source CSV's directory. On first use it is
    created with a header row. Nothing is written when openpyxl is missing.
    A failure to read or save the workbook is printed as a warning instead of
    being raised.

    Args:
        source_csv: Path of the processed CSV; only its directory is used.
        stats: Summary dict from ``process_csv``. An optional ``output_file``
            key names the produced file; when absent, ``<stem>_filtered`` is
            logged.
    """
    try:
        import openpyxl
    except ImportError:
        return  # openpyxl unavailable: skip logging

    log_path = Path(source_csv).parent / "devkit_analysis_results.xlsx"
    output_name = stats.get("output_file") or f"{Path(stats.get('source_file', '')).stem}_filtered"
    try:
        if log_path.exists():
            wb = openpyxl.load_workbook(str(log_path))
            ws = wb.active
        else:
            wb = openpyxl.Workbook()
            ws = wb.active
            ws.title = "Analysis Log"
            ws.append(["Time", "Source File", "Total Rows", "Kept Rows",
                       "Groups", "Mean Value", "Threshold", "Output File"])
        ws.append([
            stats.get("process_time", ""),
            stats.get("source_file", ""),
            stats.get("rows_total", 0),
            stats.get("rows_kept", 0),
            stats.get("groups_used", 0),
            round(stats.get("mean_value", 0), 3),
            round(stats.get("threshold", 0), 3),
            output_name,
        ])
        wb.save(str(log_path))
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort boundary
        # process_csv calls this after the source CSV has already been
        # rewritten or deleted. A workbook that is locked (e.g. open in Excel)
        # or corrupt must not turn that completed export into a failure.
        print(f"[ExportProcessor] Warning: could not update {log_path}: {exc}")
# ── gRPC 服务实现 ────────────────────────────────────────────────
class SensorPushServicer(sensor_stream_pb2_grpc.SensorPushServicer):
    """Receive real-time sensor frames over the client-streaming Upload RPC.

    A single instance serves every Upload call, and the server runs RPCs on a
    thread pool. Per-stream counters therefore live in locals; only the
    lifetime total is kept on the instance, guarded by a lock.
    """

    # Emit a throughput log line every this many frames of a stream.
    REPORT_INTERVAL = 100

    def __init__(self):
        # Frames received across all streams since this servicer was created.
        self.frame_count = 0
        # Time of the most recent throughput report from any stream.
        self.last_report_time = time.time()
        self._lock = threading.Lock()

    def Upload(self, request_iterator, context):
        """Consume one frame stream and acknowledge the frames it carried.

        Returns:
            UploadResponse whose ``frames_received`` counts only this
            stream's frames.
        """
        print("[SensorPush] Client connected, waiting for frames...")
        stream_frames = 0
        # Measure fps from the start of this stream, not from server start-up.
        window_start = time.time()
        for frame in request_iterator:
            stream_frames += 1
            with self._lock:
                self.frame_count += 1
                lifetime_total = self.frame_count
            if stream_frames % self.REPORT_INTERVAL == 0:
                now = time.time()
                elapsed = now - window_start
                fps = self.REPORT_INTERVAL / elapsed if elapsed > 0 else 0
                window_start = now
                self.last_report_time = now
                print(
                    f"[SensorPush] Frame #{frame.seq} | "
                    f"{frame.rows}x{frame.cols} | "
                    f"force={frame.resultant_force:.1f} | "
                    f"stream={stream_frames} total={lifetime_total} | ~{fps:.1f} fps"
                )
        print(f"[SensorPush] Stream ended. Total: {stream_frames}")
        return sensor_stream_pb2.UploadResponse(
            ok=True,
            frames_received=stream_frames,
            message=f"Processed {stream_frames} frames",
        )
class ExportProcessorServicer(sensor_stream_pb2_grpc.ExportProcessorServicer):
    """Unary RPC that post-processes an exported CSV file via ``process_csv``.

    Any exception raised while processing is returned to the client as
    ``ok=False`` with the exception text in ``message``, rather than surfacing
    as an RPC error.
    """

    # Keys of the process_csv result that are copied into ProcessResponse.
    _RESPONSE_FIELDS = (
        "ok",
        "output_path",
        "groups_used",
        "mean_value",
        "threshold",
        "rows_total",
        "rows_kept",
        "message",
    )

    def ProcessFile(self, request, context):
        target = request.csv_path
        as_xlsx = request.save_as_xlsx
        print(f"[ExportProcessor] Processing: {target} (xlsx={as_xlsx})")
        try:
            outcome = process_csv(target, as_xlsx)
            fields = {key: outcome[key] for key in self._RESPONSE_FIELDS}
            return sensor_stream_pb2.ProcessResponse(**fields)
        except Exception as err:
            print(f"[ExportProcessor] Error: {err}")
            return sensor_stream_pb2.ProcessResponse(
                ok=False,
                output_path="",
                message=str(err),
            )
# ── 启动 ────────────────────────────────────────────────────────
def serve(port: int, host: str = "0.0.0.0"):
    """Start the gRPC server and block until it has been stopped.

    Registers SensorPush and ExportProcessor on a 4-worker thread pool and
    listens without TLS on ``host:port``. SIGINT/SIGTERM request a graceful
    stop: in-flight RPCs get up to 5 s to finish. This function then returns.

    Args:
        port: TCP port to listen on.
        host: Host name or IPv4 address to bind. The default binds every
            interface. That exposes ExportProcessor, which rewrites or deletes
            files at client-supplied paths, to any host that can reach this
            machine. Pass ``"127.0.0.1"`` when only local clients are expected.

    Raises:
        RuntimeError: if the listen address cannot be bound.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    sensor_stream_pb2_grpc.add_SensorPushServicer_to_server(SensorPushServicer(), server)
    sensor_stream_pb2_grpc.add_ExportProcessorServicer_to_server(ExportProcessorServicer(), server)
    listen_addr = f"{host}:{port}"
    # Some grpcio versions report a bind failure by returning 0 instead of raising.
    bound_port = server.add_insecure_port(listen_addr)
    if bound_port == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")
    server.start()
    print(f"[DevKit Server] gRPC listening on {host}:{bound_port}")
    print("[DevKit Server] Services: SensorPush (streaming), ExportProcessor (unary)")

    def shutdown(signum, frame):
        print("\n[DevKit Server] Shutting down...")
        # stop() does not block. wait_for_termination() below returns once the
        # grace period has elapsed or every in-flight RPC has completed. That
        # honours the grace period, which an immediate sys.exit() would not.
        server.stop(grace=5)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)
    server.wait_for_termination()
if __name__ == "__main__":
    # Command-line entry point: parse --port, then run the server until stopped.
    cli = argparse.ArgumentParser(description="JE-Skin DevKit gRPC Server")
    cli.add_argument(
        "--port",
        type=int,
        default=50051,
        help="gRPC listen port (default: 50051)",
    )
    options = cli.parse_args()
    serve(options.port)