Files
AIclinicalresearch/r-statistics-service/plumber.R
HaHafeng 52989cd03f feat(ssa): Agent channel UX optimization (Solution B) + Plan-and-Execute architecture design
SSA Agent channel improvements (12 code files, +931/-203 lines):
- Solution B: left/right separation of concerns (gaze guiding + state mutex + time-travel)
- JWT token refresh mechanism (ensureFreshToken) to fix HTTP 401 during pipeline
- Code truncation fix: LLM maxTokens 4000->8000 + CSS max-height 60vh
- Retry streaming code generation with generateCodeStream()
- R Docker structured errors: 20+ pattern matching + format_agent_error + line extraction
- Prompt iron rules: strict output format in CoderAgent System Prompt
- parseCode robustness: XML/Markdown/inference 3-tier matching + length validation
- consoleOutput type defense: handle both array and scalar from R Docker unboxedJSON
- Agent progress bar sync: derive phase from agentExecution.status
- Export report / view code buttons restored for Agent mode
- ExecutingProgress component: real-time timer + dynamic tips + step pulse animation

Architecture design (3 review reports):
- Plan-and-Execute step-by-step execution architecture approved
- Code accumulation strategy (R Docker stays stateless)
- 5 engineering guardrails: XML tags, AST pre-check, defensive prompts, high-fidelity schema, error classification circuit breaker

Docs: update SSA module status v4.1, system status v6.7, deployment changelist
Made-with: Cursor
2026-03-07 22:32:32 +08:00

412 lines
12 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# plumber.R
# SSA-Pro R Statistics Service 入口文件
#
# 安全与性能优化:
# - 生产环境预加载所有工具脚本
# - tool_code 白名单正则校验(防止路径遍历攻击)
library(plumber)
library(jsonlite)
# Runtime configuration: DEV_MODE is on only when the DEV_MODE environment
# variable is exactly "true" (it defaults to "false").
DEV_MODE <- identical(Sys.getenv("DEV_MODE", "false"), "true")

# Shared helper scripts, sourced in this fixed order before any tool is loaded.
invisible(lapply(
  file.path("utils", c(
    "error_codes.R",
    "data_loader.R",
    "guardrails.R",
    "result_formatter.R",
    "block_helpers.R"
  )),
  source
))

# Tool scripts: every *.R file found directly in the tools directory.
tools_dir <- "tools"
tool_files <- list.files(path = tools_dir, pattern = "\\.R$", full.names = TRUE)

# ========== Production preload optimisation ==========
# Tool scripts are loaded once at service start-up so that requests do not
# have to read and parse them from disk each time.
# TOOL_CACHE holds the preloaded tool entry points, keyed by tool name.
TOOL_CACHE <- new.env(parent = emptyenv())
#' Preload every tool script into the in-memory tool cache.
#'
#' Each file listed in `tool_files` is sourced into its own environment (a
#' child of the global environment, so globally sourced helpers stay
#' reachable). When the script itself defines a function named
#' `run_analysis`, that function is stored in `TOOL_CACHE` under the script's
#' base name without extension; otherwise a warning is raised and nothing is
#' cached for that script.
#'
#' @return Called for its side effects: fills `TOOL_CACHE` and logs progress.
preload_tools <- function() {
  message("[Init] 预加载工具脚本...")
  for (f in tool_files) {
    tool_name <- tools::file_path_sans_ext(basename(f))
    # One environment per tool keeps their top-level definitions apart.
    tool_env <- new.env(parent = globalenv())
    source(f, local = tool_env)
    # inherits = FALSE: only a run_analysis defined by THIS script counts.
    # Without it, a run_analysis visible in the global environment would make
    # the check pass and a NULL entry would be written into the cache.
    defines_entry_point <- exists(
      "run_analysis",
      envir = tool_env,
      mode = "function",
      inherits = FALSE
    )
    if (defines_entry_point) {
      TOOL_CACHE[[tool_name]] <- get(
        "run_analysis",
        envir = tool_env,
        mode = "function",
        inherits = FALSE
      )
      message(paste("[Init] 已加载:", tool_name))
    } else {
      warning(paste("[Init] 工具缺少 run_analysis 函数:", tool_name))
    }
  }
  message(paste("[Init] 预加载完成,共", length(ls(TOOL_CACHE)), "个工具"))
}
# Start-up loading strategy.
# - DEV_MODE: the cache is skipped (to allow hot reloading), but every tool
#   script is still sourced once at start-up.
# - Otherwise: all tools are preloaded into the cache.
if (DEV_MODE) {
  message("[Init] DEV_MODE 启用,跳过预加载(支持热重载)")
  for (f in tool_files) {
    source(f)
  }
} else {
  preload_tools()
}
# ========== 安全校验函数 ==========
#' Validate the format of a tool code (guards against path traversal).
#'
#' A valid code starts with an uppercase letter and continues with uppercase
#' letters, digits or underscores only.
#' Valid examples: ST_T_TEST_IND, ST_ANOVA, T_TEST_IND.
#' Invalid examples: ../etc/passwd, ST_TEST;rm -rf.
#'
#' @param tool_code Tool code to check.
#' @return A single `TRUE` when `tool_code` is one non-missing string in the
#'   allowed format, otherwise a single `FALSE`.
validate_tool_code <- function(tool_code) {
  # grepl() returns logical(0) for NULL/empty input and a vector for longer
  # input; either would break `if (!validate_tool_code(x))` in callers, so
  # anything that is not exactly one string is rejected up front.
  if (!is.character(tool_code) || length(tool_code) != 1L || is.na(tool_code)) {
    return(FALSE)
  }
  grepl("^[A-Z][A-Z0-9_]*$", tool_code)
}
#' Convert a tool code into a tool name.
#'
#' Drops a leading "ST_" prefix, if present, and lower-cases the remainder.
#'
#' @param tool_code Tool code, e.g. "ST_T_TEST_IND".
#' @return Tool name, e.g. "t_test_ind".
normalize_tool_name <- function(tool_code) {
  without_prefix <- sub("^ST_", "", tool_code)
  tolower(without_prefix)
}
# ========== API 定义 ==========
#* @apiTitle SSA-Pro R Statistics Service
#* @apiDescription 严谨型统计分析 R 引擎
#* 健康检查
#* @get /health
function() {
  # Liveness probe: reports a fixed "ok" status together with the current
  # time, the service version, the DEV_MODE flag and a tool count.
  # In DEV_MODE the count is the number of tool script files found on disk;
  # otherwise it is the number of entries in the preload cache.
  loaded_count <- if (DEV_MODE) {
    length(tool_files)
  } else {
    length(ls(TOOL_CACHE))
  }
  list(
    status = "ok",
    timestamp = Sys.time(),
    version = "1.0.1",
    dev_mode = DEV_MODE,
    tools_loaded = loaded_count
  )
}
#* 列出已加载的工具
#* @get /api/v1/tools
function() {
  # Lists the available tools. In DEV_MODE the names come from the tool
  # script file names (extension removed); otherwise from the preload cache.
  tools <- if (DEV_MODE) {
    sub("\\.R$", "", basename(tool_files))
  } else {
    ls(TOOL_CACHE)
  }
  list(status = "ok", tools = tools, count = length(tools))
}
#* JIT Guardrails Check
#* @post /api/v1/guardrails/jit
#* @serializer unboxedJSON
function(req) {
  # Just-in-time guardrail check for one tool against one dataset.
  #
  # Reads `tool_code` (required) and `params` from the JSON request body,
  # loads the data via load_input_data(), and returns the fields produced by
  # run_jit_guardrails(): checks, suggested_tool, can_proceed and
  # all_checks_passed.
  # Error responses: E400 when tool_code is missing, E100 when the data cannot
  # be loaded; any other failure is translated by map_r_error().
  tryCatch({
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)

    # Required parameters
    tool_code <- input$tool_code
    params <- input$params
    if (is.null(tool_code)) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Missing tool_code parameter"
      ))
    }

    # Load the data. A failure is reported to the client as a generic E100,
    # so the underlying cause is logged here rather than silently discarded.
    df <- tryCatch(
      load_input_data(input),
      error = function(e) {
        message("[Guardrails] Failed to load data: ", conditionMessage(e))
        NULL
      }
    )
    if (is.null(df)) {
      return(list(
        status = "error",
        error_code = "E100",
        message = "Failed to load data for guardrail checks"
      ))
    }

    # Run the JIT guardrail checks
    result <- run_jit_guardrails(df, tool_code, params)
    list(
      status = "success",
      checks = result$checks,
      suggested_tool = result$suggested_tool,
      can_proceed = result$can_proceed,
      all_checks_passed = result$all_checks_passed
    )
  }, error = function(e) {
    map_r_error(conditionMessage(e))
  })
}
#* Agent 通道:执行任意 R 代码(沙箱模式)
#* @post /api/v1/execute-code
#* @serializer unboxedJSON
function(req) {
  # Agent channel: evaluate caller-supplied R code and return its value, the
  # captured console output and any warnings/messages.
  #
  # Request body (JSON): `code` (required string), `session_id` (optional),
  # `timeout` in seconds (optional, capped at 120).
  # Response: status "success" with `result`, `console_output`, `duration_ms`;
  # or status "error" with the fields derived from format_agent_error().
  #
  # NOTE(review): the evaluation environment is a child of globalenv(), so the
  # submitted code can reach everything loaded in this process. It keeps
  # variable assignments apart but is not a security boundary by itself;
  # isolation presumably comes from the surrounding container - TODO confirm.
  max_timeout_sec <- 120

  tryCatch({
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)
    code <- input$code
    session_id <- input$session_id

    # `code` must be exactly one non-blank string; a JSON number, array or
    # object is rejected the same way as a missing value.
    code_is_usable <- is.character(code) && length(code) == 1L &&
      !is.na(code) && nchar(trimws(code)) > 0
    if (!code_is_usable) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Missing 'code' parameter",
        user_hint = "R 代码不能为空"
      ))
    }

    # Safety limit: at most 120 seconds. A missing, non-numeric, non-finite or
    # non-positive client value falls back to the cap instead of being passed
    # on to setTimeLimit(), where it would not impose a usable limit.
    raw_timeout <- input$timeout
    timeout_sec <- if (is.numeric(raw_timeout) || is.character(raw_timeout)) {
      suppressWarnings(as.numeric(raw_timeout))
    } else {
      NA_real_
    }
    if (length(timeout_sec) != 1L || !is.finite(timeout_sec) || timeout_sec <= 0) {
      timeout_sec <- max_timeout_sec
    }
    timeout_sec <- min(timeout_sec, max_timeout_sec)

    # glue() yields an empty result when an interpolated value is NULL, which
    # would drop the whole log line; use an explicit label instead.
    has_session <- is.atomic(session_id) && length(session_id) == 1L &&
      !is.na(session_id) && nchar(session_id) > 0
    session_label <- if (has_session) as.character(session_id) else "none"
    message(glue::glue("[ExecuteCode] session={session_label}, code_length={nchar(code)}, timeout={timeout_sec}s"))

    sandbox_env <- new.env(parent = globalenv())
    if (has_session) {
      sandbox_env$SESSION_ID <- session_id
    }

    start_time <- proc.time()
    collected_warnings <- list()
    collected_messages <- character(0)

    output_capture <- tryCatch(
      withTimeout(
        {
          captured_output <- utils::capture.output({
            result <- withCallingHandlers(
              eval(parse(text = code), envir = sandbox_env),
              # Warnings and messages are recorded and muffled so that they
              # travel back in the response instead of going to the console.
              warning = function(w) {
                collected_warnings[[length(collected_warnings) + 1]] <<- conditionMessage(w)
                invokeRestart("muffleWarning")
              },
              message = function(m) {
                collected_messages <<- c(collected_messages, conditionMessage(m))
                invokeRestart("muffleMessage")
              }
            )
          })
          list(
            result = result,
            output = captured_output,
            warnings = collected_warnings,
            messages = collected_messages,
            error = NULL
          )
        },
        timeout = timeout_sec,
        onTimeout = "error"
      ),
      error = function(e) {
        error_info <- format_agent_error(e, code, collected_warnings, collected_messages)
        # Always carry a non-NULL error text: a NULL here would make the
        # failure look like a success further down.
        error_text <- error_info$message
        if (is.null(error_text)) {
          error_text <- conditionMessage(e)
        }
        # NOTE(review): text printed before the failure is not available here
        # because capture.output() did not return; only warnings and messages
        # collected so far are reported.
        list(
          result = NULL,
          output = NULL,
          warnings = collected_warnings,
          messages = collected_messages,
          error = error_text,
          error_detail = error_info
        )
      }
    )

    elapsed_ms <- round((proc.time() - start_time)[["elapsed"]] * 1000)

    # Console transcript shared by the error and success responses:
    # printed output first, then warnings, then messages.
    console_lines <- c(
      output_capture$output,
      if (length(output_capture$warnings) > 0) paste0("Warning: ", output_capture$warnings),
      if (length(output_capture$messages) > 0) output_capture$messages
    )

    if (!is.null(output_capture$error)) {
      detail <- output_capture$error_detail
      message(glue::glue("[ExecuteCode] ERROR after {elapsed_ms}ms: {output_capture$error}"))
      return(list(
        status = "error",
        error_code = if (!is.null(detail)) detail$error_code else "E_EXEC",
        error_type = if (!is.null(detail)) detail$error_type else "runtime",
        message = output_capture$error,
        user_hint = if (!is.null(detail)) detail$user_hint else output_capture$error,
        error_line = if (!is.null(detail)) detail$error_line else NULL,
        error_context = if (!is.null(detail)) detail$error_context else NULL,
        console_output = console_lines,
        duration_ms = elapsed_ms
      ))
    }

    message(glue::glue("[ExecuteCode] SUCCESS in {elapsed_ms}ms"))
    final_result <- output_capture$result

    # A list that already carries `report_blocks` is returned as-is; any other
    # value is wrapped as `data` next to an empty `report_blocks` list.
    if (is.list(final_result) && !is.null(final_result$report_blocks)) {
      return(list(
        status = "success",
        result = final_result,
        console_output = console_lines,
        duration_ms = elapsed_ms
      ))
    }

    list(
      status = "success",
      result = list(
        data = final_result,
        report_blocks = list()
      ),
      console_output = console_lines,
      duration_ms = elapsed_ms
    )
  }, error = function(e) {
    message(glue::glue("[ExecuteCode] FATAL ERROR: {conditionMessage(e)}"))
    map_r_error(conditionMessage(e))
  })
}
#' Evaluate an expression under a CPU/elapsed time limit.
#'
#' The limit is installed first and `expr` is evaluated afterwards (arguments
#' are lazy), so the whole evaluation runs under the limit. The limit is
#' removed again when the function exits, whether normally or through an error.
#'
#' @param expr Expression to evaluate.
#' @param timeout Limit in seconds, applied to both CPU and elapsed time.
#' @param onTimeout Accepted for call compatibility; not used by this
#'   implementation.
#' @return The value of `expr`.
withTimeout <- function(expr, timeout = 120, onTimeout = "error") {
  setTimeLimit(cpu = timeout, elapsed = timeout, transient = TRUE)
  on.exit(
    setTimeLimit(cpu = Inf, elapsed = Inf, transient = FALSE),
    add = TRUE
  )
  # Force the promise exactly once. Passing it through eval() would evaluate
  # the resulting VALUE a second time whenever that value is itself a language
  # object (a call or a symbol).
  force(expr)
}
#* 执行统计工具
#* @post /api/v1/skills/<tool_code>
#* @param tool_code:str 工具代码(如 ST_T_TEST_IND
#* @serializer unboxedJSON
function(req, tool_code) {
  # Run one statistical tool, selected by the `tool_code` path parameter, with
  # the parsed JSON request body as its single argument.
  #
  # The tool's own return value is passed through. Error responses: E400 for a
  # malformed tool code, E100 for an unknown tool or a tool script without
  # run_analysis(); any other failure is translated by map_r_error().
  tryCatch({
    # ===== Security check: tool_code whitelist =====
    if (!validate_tool_code(tool_code)) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Invalid tool code format",
        user_hint = "工具代码格式错误,只允许大写字母、数字和下划线"
      ))
    }

    # Parse the request body
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)

    # Log the incoming parameters (debugging aid)
    param_names <- if (!is.null(input$params)) paste(names(input$params), collapse=", ") else "NULL"
    message(glue::glue("[Skill:{tool_code}] params keys: [{param_names}]"))
    if (!is.null(input$params$variables)) {
      message(glue::glue("[Skill:{tool_code}] variables ({length(input$params$variables)}): [{paste(input$params$variables, collapse=', ')}]"))
    }
    if (!is.null(input$params$group_var)) {
      message(glue::glue("[Skill:{tool_code}] group_var: {input$params$group_var}"))
    }

    # Debug mode: report the tool's temporary files in the response
    debug_mode <- isTRUE(input$debug)

    # Every tool exposes the same entry point name
    func_name <- "run_analysis"

    # Normalised tool name and the script file it maps to
    tool_name <- normalize_tool_name(tool_code)
    tool_file <- file.path("tools", paste0(tool_name, ".R"))

    # ===== Loading strategy depends on the environment =====
    if (DEV_MODE) {
      # Development: re-source the script on every request (hot reload)
      if (!file.exists(tool_file)) {
        return(list(
          status = "error",
          error_code = "E100",
          message = paste("Unknown tool:", tool_code),
          user_hint = "请检查工具代码是否正确"
        ))
      }
      # Source into a fresh environment and look the entry point up there
      # only. Sourcing into the global environment would let a run_analysis
      # left behind by a previously sourced tool satisfy the check, and that
      # other tool's function would then be executed.
      tool_env <- new.env(parent = globalenv())
      source(tool_file, local = tool_env)
      if (!exists(func_name, envir = tool_env, mode = "function", inherits = FALSE)) {
        return(list(
          status = "error",
          error_code = "E100",
          message = paste("Tool", tool_code, "does not implement run_analysis()"),
          user_hint = "工具脚本格式错误,请联系管理员"
        ))
      }
      # Run the analysis
      tool_func <- get(func_name, envir = tool_env, mode = "function", inherits = FALSE)
      result <- tool_func(input)
    } else {
      # Production: serve from the preload cache. The entry must exist and be
      # a function; an empty tool name (e.g. from "ST_") is treated as unknown.
      is_cached <- nzchar(tool_name) &&
        exists(tool_name, envir = TOOL_CACHE, mode = "function", inherits = FALSE)
      if (!is_cached) {
        return(list(
          status = "error",
          error_code = "E100",
          message = paste("Unknown tool:", tool_code),
          user_hint = "请检查工具代码是否正确,或联系管理员确认工具已部署"
        ))
      }
      # Fetch the cached function and run it
      cached_func <- TOOL_CACHE[[tool_name]]
      result <- cached_func(input)
    }

    # Debug mode: attach the temporary file paths reported by the tool
    if (debug_mode && !is.null(result$tmp_files)) {
      result$debug_files <- result$tmp_files
      message("[DEBUG] 临时文件已保留: ", paste(result$tmp_files, collapse = ", "))
    }
    result
  }, error = function(e) {
    message(glue::glue("[Skill:{tool_code}] ERROR: {conditionMessage(e)}"))
    map_r_error(conditionMessage(e))
  })
}