# Commit note: Align Agent mode to strict stepwise generation and execution,
# add deterministic and safety hardening, and sync deployment/module
# documentation for the Phase 5A.5/5B/5C rollout.
# - implement strict stepwise execution path and dependency short-circuiting
# - persist step-level errors/results and stream step_* progress events
# - add agent plan params patch route and schema/migration support
# - improve R sanitizer/security checks and step result rendering in workspace
# - update SSA module guide and deployment change checklist
# Made-with: Cursor
# File info: 498 lines, 16 KiB, R
# plumber.R
# Entry point of the SSA-Pro R Statistics Service
#
# Security and performance notes:
# - In production, all tool scripts are preloaded at start-up
# - tool_code is validated against a whitelist regex (prevents path traversal)

library(plumber)
|
||
library(jsonlite)
|
||
|
||
# Environment configuration.
# DEV_MODE is TRUE only when the DEV_MODE environment variable is exactly
# "true"; an unset variable falls back to "false".
DEV_MODE <- isTRUE(Sys.getenv("DEV_MODE", unset = "false") == "true")

# Null-coalescing operator: yields `x` unless it is NULL, in which case `y`.
# Defined here so the execute-code endpoint does not fail when `%||%` is
# otherwise undefined.
`%||%` <- function(x, y) {
  if (!is.null(x)) {
    return(x)
  }
  y
}

# Load the shared helper scripts, in this order.
invisible(lapply(
  c(
    "utils/error_codes.R",
    "utils/data_loader.R",
    "utils/guardrails.R",
    "utils/result_formatter.R",
    "utils/block_helpers.R"
  ),
  source
))

# Tool directory, and the full paths of every file in it whose name ends
# in ".R".
tools_dir <- "tools"
tool_files <- list.files(path = tools_dir, pattern = "\\.R$", full.names = TRUE)

# ========== Production preload optimisation ==========
# Tool scripts are loaded once at service start-up, each into its own
# environment, so that requests do not re-read and re-parse them from disk.

# Cache of preloaded tools (filled by preload_tools()).
TOOL_CACHE <- new.env(parent = emptyenv())

#' Preload every tool script into TOOL_CACHE
#'
#' Sources each path in `tool_files` into its own environment (a child of the
#' global environment). When the script itself defines a `run_analysis`
#' function, that function is stored in `TOOL_CACHE` under the script's base
#' name without extension; otherwise a warning is raised and nothing is
#' cached for that script.
#'
#' An error raised while sourcing a script is not caught here and therefore
#' aborts the preload.
#'
#' @return Called for its side effects (fills `TOOL_CACHE`, emits log
#'   messages); returns the invisible result of the final `message()` call.
preload_tools <- function() {
  message("[Init] 预加载工具脚本...")

  for (script_path in tool_files) {
    tool_name <- tools::file_path_sans_ext(basename(script_path))

    # Each tool gets a private environment so tools cannot overwrite each
    # other's `run_analysis`.
    tool_env <- new.env(parent = globalenv())
    source(script_path, local = tool_env)

    # inherits = FALSE: only a `run_analysis` defined by this very script
    # counts. With inheritance, a `run_analysis` living in the global
    # environment would make the check pass while the script's own binding is
    # missing, and NULL would be cached for the tool.
    defines_entry_point <- exists(
      "run_analysis",
      envir = tool_env, mode = "function", inherits = FALSE
    )

    if (defines_entry_point) {
      assign(
        tool_name,
        get("run_analysis", envir = tool_env, inherits = FALSE),
        envir = TOOL_CACHE
      )
      message(paste("[Init] 已加载:", tool_name))
    } else {
      warning(paste("[Init] 工具缺少 run_analysis 函数:", tool_name))
    }
  }

  message(paste("[Init] 预加载完成,共", length(ls(TOOL_CACHE)), "个工具"))
}

# Start-up tool loading.
# Development: the cache is skipped (tools are re-sourced on every request,
# which allows hot reload), but every tool script is still sourced once here.
# Production: all tools are preloaded into TOOL_CACHE.
if (DEV_MODE) {
  message("[Init] DEV_MODE 启用,跳过预加载(支持热重载)")
  for (f in tool_files) {
    source(f)
  }
} else {
  preload_tools()
}

# ========== Security validation helpers ==========

#' Validate the format of a tool code (guards against path traversal)
#'
#' Accepts only upper-case letters, digits and underscores, starting with a
#' letter. Valid examples: "ST_T_TEST_IND", "ST_ANOVA", "T_TEST_IND".
#' Invalid examples: "../etc/passwd", "ST_TEST;rm -rf".
#'
#' @param tool_code Tool code to check; expected to be a single string.
#' @return A single `TRUE` when `tool_code` is one non-NA string matching the
#'   whitelist pattern, otherwise a single `FALSE` (this includes `NULL`,
#'   zero-length, multi-element and non-character input).
validate_tool_code <- function(tool_code) {
  # Always answer with exactly one logical value so callers can use the
  # result directly inside `if ()`; grepl() alone yields logical(0) for
  # NULL / zero-length input and a vector for multi-element input.
  is_single_string <- is.character(tool_code) &&
    length(tool_code) == 1L &&
    !is.na(tool_code)
  if (!is_single_string) {
    return(FALSE)
  }

  grepl("^[A-Z][A-Z0-9_]*$", tool_code)
}

#' Convert a tool code into a tool name
#'
#' Drops a leading "ST_" prefix (if present) and lower-cases the rest.
#'
#' @param tool_code Tool code, e.g. "ST_T_TEST_IND".
#' @return The tool name, e.g. "t_test_ind".
normalize_tool_name <- function(tool_code) {
  without_prefix <- sub("^ST_", "", tool_code)
  tolower(without_prefix)
}

# ========== API definition ==========

#* @apiTitle SSA-Pro R Statistics Service
#* @apiDescription 严谨型统计分析 R 引擎

#* 健康检查
#* @get /health
function() {
  # Health probe: reports a fixed "ok" status, the current time, the service
  # version, whether DEV_MODE is on, and how many tools are known (script
  # files in DEV_MODE, cached tools otherwise).
  now <- Sys.time()
  loaded_count <- if (DEV_MODE) {
    length(tool_files)
  } else {
    length(ls(TOOL_CACHE))
  }

  list(
    status = "ok",
    timestamp = now,
    version = "1.0.1",
    dev_mode = DEV_MODE,
    tools_loaded = loaded_count
  )
}

#* 列出已加载的工具
#* @get /api/v1/tools
function() {
  # Lists the known tools: in DEV_MODE the base names of the tool script
  # files with the ".R" suffix removed, otherwise the names stored in
  # TOOL_CACHE.
  tools <- if (DEV_MODE) {
    gsub("\\.R$", "", basename(tool_files))
  } else {
    ls(TOOL_CACHE)
  }

  list(status = "ok", tools = tools, count = length(tools))
}

#* 诊断:返回 R 运行时包清单(只读)
#* @get /api/v1/debug/packages
#* @serializer unboxedJSON
function() {
  # Read-only diagnostic: compares a fixed list of required packages with
  # the packages reported by installed.packages() and returns the R version,
  # library paths, counts, the missing required packages and the first 120
  # installed package names in sorted order.
  required_packages <- c(
    "plumber", "jsonlite", "ggplot2", "glue", "dplyr", "tidyr",
    "base64enc", "yaml", "car", "httr", "scales", "gridExtra",
    "gtsummary", "gt", "broom", "meta"
  )

  available_pkgs <- rownames(installed.packages())
  absent_pkgs <- setdiff(required_packages, available_pkgs)
  completeness <- if (length(absent_pkgs) == 0) "complete" else "incomplete"

  list(
    status = "ok",
    r_version = R.version.string,
    dev_mode = DEV_MODE,
    lib_paths = .libPaths(),
    required_count = length(required_packages),
    installed_count = length(available_pkgs),
    missing_required = absent_pkgs,
    required_status = completeness,
    sample_installed = head(sort(available_pkgs), 120)
  )
}

#* JIT Guardrails Check
#* @post /api/v1/guardrails/jit
#* @serializer unboxedJSON
function(req) {
  # Runs the just-in-time guardrail checks for one tool.
  #
  # Request body (JSON): `tool_code` (required), `params`, plus whatever
  # load_input_data() reads from the parsed body to obtain the data.
  #
  # Responses:
  # - E400 when `tool_code` is missing;
  # - E100 when the data cannot be loaded;
  # - otherwise status "success" with the fields returned by
  #   run_jit_guardrails() (checks, suggested_tool, can_proceed,
  #   all_checks_passed);
  # - any other error is converted with map_r_error().
  tryCatch({
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)

    # Required parameters
    tool_code <- input$tool_code
    params <- input$params

    if (is.null(tool_code)) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Missing tool_code parameter"
      ))
    }

    # Load the data. A failure is still reported to the caller as E100, but
    # the underlying reason is logged instead of being silently discarded.
    df <- tryCatch(
      load_input_data(input),
      error = function(e) {
        message("[Guardrails:JIT] data load failed: ", conditionMessage(e))
        NULL
      }
    )

    if (is.null(df)) {
      return(list(
        status = "error",
        error_code = "E100",
        message = "Failed to load data for guardrail checks"
      ))
    }

    # Run the JIT guardrail checks
    result <- run_jit_guardrails(df, tool_code, params)

    list(
      status = "success",
      checks = result$checks,
      suggested_tool = result$suggested_tool,
      can_proceed = result$can_proceed,
      all_checks_passed = result$all_checks_passed
    )
  }, error = function(e) {
    map_r_error(conditionMessage(e))
  })
}

#* Agent 通道:执行任意 R 代码(沙箱模式)
#* @post /api/v1/execute-code
#* @serializer unboxedJSON
function(req) {
  # Agent channel: runs caller-supplied R code and reports the outcome.
  #
  # Request body (JSON):
  # - code:       required, one non-empty string of R source;
  # - session_id: optional, exposed to the code as SESSION_ID;
  # - timeout:    optional seconds; anything missing, non-numeric,
  #               non-positive or above 120 is replaced by 120.
  #
  # Flow: parse the code (syntax errors -> E_SYNTAX), scan it against a
  # regex deny-list (hits -> E_SECURITY), then evaluate the parsed code in a
  # child of the global environment under a time limit while collecting
  # printed output, warnings and messages.
  #
  # NOTE(review): the deny-list and the function overrides below are a
  # best-effort filter, not a real sandbox. They do not cover indirect calls
  # (e.g. do.call()/get()/match.fun() on a function name), and the evaluation
  # environment inherits from the global environment. This endpoint evaluates
  # untrusted code by design and must not be exposed to untrusted networks.
  tryCatch({
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)

    code <- input$code
    session_id <- input$session_id

    # Safety limit: at most 120 seconds. Invalid values fall back to the cap
    # instead of raising an error or disabling the limit.
    max_timeout_sec <- 120
    timeout_sec <- suppressWarnings(tryCatch(
      as.numeric(input$timeout %||% max_timeout_sec),
      error = function(e) NA_real_
    ))
    timeout_is_valid <- length(timeout_sec) == 1L &&
      is.finite(timeout_sec) &&
      timeout_sec > 0 &&
      timeout_sec <= max_timeout_sec
    if (!timeout_is_valid) {
      timeout_sec <- max_timeout_sec
    }

    code_is_usable <- is.character(code) &&
      length(code) == 1L &&
      !is.na(code) &&
      nchar(trimws(code)) > 0
    if (!code_is_usable) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Missing 'code' parameter",
        user_hint = "R 代码不能为空"
      ))
    }

    has_session <- is.atomic(session_id) &&
      length(session_id) == 1L &&
      !is.na(session_id) &&
      nchar(session_id) > 0
    # glue() produces an empty result when an interpolated value is NULL, so
    # log a placeholder-free but non-NULL label.
    session_label <- if (has_session) as.character(session_id) else ""
    message(glue::glue("[ExecuteCode] session={session_label}, code_length={nchar(code)}, timeout={timeout_sec}s"))

    # ── Pre-check 1: syntax ──
    # The parse result is returned from tryCatch() rather than assigned with
    # `<<-`: here `<<-` would skip this function's own frame, leaving the
    # local variable NULL (so nothing would be evaluated) and writing the
    # parsed code into an enclosing environment shared across requests.
    parse_outcome <- tryCatch(
      list(exprs = parse(text = code), error = NULL),
      error = function(e) list(exprs = NULL, error = conditionMessage(e))
    )
    parsed_code <- parse_outcome$exprs
    ast_check <- parse_outcome$error

    if (!is.null(ast_check)) {
      # Parse errors carry a "line:column" position; keep the line number.
      line_match <- regmatches(ast_check, regexpr("\\d+:\\d+", ast_check))
      error_line <- if (length(line_match) > 0) {
        as.integer(sub(":.*", "", line_match))
      } else {
        NULL
      }

      # Up to two lines of context on either side of the failing line.
      code_lines <- strsplit(code, "\n")[[1]]
      context_lines <- NULL
      if (!is.null(error_line) && error_line > 0 && error_line <= length(code_lines)) {
        shown <- max(1, error_line - 2):min(length(code_lines), error_line + 2)
        context_lines <- paste(
          sprintf("%3d| %s", shown, code_lines[shown]),
          collapse = "\n"
        )
      }

      return(list(
        status = "error",
        error_code = "E_SYNTAX",
        error_type = "syntax",
        message = paste0("R 语法错误(代码无法解析): ", ast_check),
        user_hint = "代码存在语法错误,请检查括号/引号是否匹配、运算符是否正确",
        error_line = error_line,
        error_context = context_lines,
        console_output = list(),
        duration_ms = 0
      ))
    }

    # ── Pre-check 2: static security scan (MVP) ──
    # Whole-line comments are stripped first to reduce false positives.
    code_for_scan <- gsub("(?m)^\\s*#.*$", "", code, perl = TRUE)
    forbidden_pattern <- "(^|[^[:alnum:]_\\.])((base::)?system2?|(base::)?eval|(base::)?parse|(base::)?source|file\\.remove|setwd|download\\.file|readLines|writeLines)\\s*\\("
    security_hit <- regexpr(forbidden_pattern, code_for_scan, perl = TRUE, ignore.case = TRUE)
    if (security_hit[1] != -1) {
      hit_text <- regmatches(code_for_scan, security_hit)[1]
      return(list(
        status = "error",
        error_code = "E_SECURITY",
        error_type = "security",
        message = paste0("Security Violation: Detected forbidden function call: ", hit_text),
        user_hint = "代码包含高风险函数调用(如 system/eval/source/file.remove/setwd),已被系统拦截",
        console_output = list(),
        duration_ms = 0
      ))
    }

    # Runtime guard: even if the static scan misses a call, unqualified uses
    # of these functions inside the evaluated code hit the stubs below.
    sandbox_env <- new.env(parent = globalenv())
    sandbox_env$system <- function(...) stop("Security Violation: function 'system' is forbidden.")
    sandbox_env$system2 <- function(...) stop("Security Violation: function 'system2' is forbidden.")
    sandbox_env$eval <- function(...) stop("Security Violation: function 'eval' is forbidden.")
    sandbox_env$source <- function(...) stop("Security Violation: function 'source' is forbidden.")
    sandbox_env$setwd <- function(...) stop("Security Violation: function 'setwd' is forbidden.")
    sandbox_env$file.remove <- function(...) stop("Security Violation: function 'file.remove' is forbidden.")

    if (has_session) {
      sandbox_env$SESSION_ID <- session_id
    }

    start_time <- proc.time()

    # Filled by the calling handlers below. Inside those handler functions
    # `<<-` resolves to these bindings, because the handlers are closures
    # created in this frame.
    collected_warnings <- list()
    collected_messages <- character(0)

    output_capture <- tryCatch(
      withTimeout(
        {
          captured_output <- utils::capture.output({
            result <- withCallingHandlers(
              eval(parsed_code, envir = sandbox_env),
              warning = function(w) {
                collected_warnings[[length(collected_warnings) + 1]] <<- conditionMessage(w)
                invokeRestart("muffleWarning")
              },
              message = function(m) {
                collected_messages <<- c(collected_messages, conditionMessage(m))
                invokeRestart("muffleMessage")
              }
            )
          })

          list(
            result = result,
            output = captured_output,
            warnings = collected_warnings,
            messages = collected_messages,
            error = NULL
          )
        },
        timeout = timeout_sec,
        onTimeout = "error"
      ),
      error = function(e) {
        error_info <- format_agent_error(e, code, collected_warnings, collected_messages)
        list(
          result = NULL,
          output = NULL,
          warnings = collected_warnings,
          messages = collected_messages,
          error = error_info$message,
          error_detail = error_info
        )
      }
    )

    elapsed_ms <- round((proc.time() - start_time)[["elapsed"]] * 1000)

    # Console transcript: printed output, then warnings, then messages.
    console_lines <- c(
      output_capture$output,
      if (length(output_capture$warnings) > 0) paste0("Warning: ", output_capture$warnings),
      if (length(output_capture$messages) > 0) output_capture$messages
    )

    if (!is.null(output_capture$error)) {
      detail <- output_capture$error_detail
      message(glue::glue("[ExecuteCode] ERROR after {elapsed_ms}ms: {output_capture$error}"))

      return(list(
        status = "error",
        error_code = if (!is.null(detail)) detail$error_code else "E_EXEC",
        error_type = if (!is.null(detail)) detail$error_type else "runtime",
        message = output_capture$error,
        user_hint = if (!is.null(detail)) detail$user_hint else output_capture$error,
        error_line = if (!is.null(detail)) detail$error_line else NULL,
        error_context = if (!is.null(detail)) detail$error_context else NULL,
        console_output = console_lines,
        duration_ms = elapsed_ms
      ))
    }

    message(glue::glue("[ExecuteCode] SUCCESS in {elapsed_ms}ms"))

    # A list result that already carries `report_blocks` is returned as-is;
    # anything else is wrapped as `data` with an empty `report_blocks`.
    final_result <- output_capture$result
    result_payload <- if (is.list(final_result) && !is.null(final_result$report_blocks)) {
      final_result
    } else {
      list(
        data = final_result,
        report_blocks = list()
      )
    }

    list(
      status = "success",
      result = result_payload,
      console_output = console_lines,
      duration_ms = elapsed_ms
    )
  }, error = function(e) {
    message(glue::glue("[ExecuteCode] FATAL ERROR: {conditionMessage(e)}"))
    map_r_error(conditionMessage(e))
  })
}

#' Evaluate an expression under a time limit
#'
#' Sets a CPU and elapsed time limit with `setTimeLimit()`, evaluates `expr`,
#' and lifts the limit again on exit.
#'
#' @param expr Expression to evaluate. Its value is additionally passed to
#'   `eval()` in the caller's environment, so a value that is itself a
#'   language object is evaluated once more.
#' @param timeout Time limit in seconds, applied to both CPU and elapsed time.
#' @param onTimeout What to do when the limit is reached: `"error"` (the
#'   default, and the fallback for any unrecognised value) lets the time-limit
#'   error propagate; `"warning"` turns it into a warning and returns `NULL`;
#'   `"silent"` just returns `NULL`. Errors other than the time-limit error
#'   always propagate.
#' @return The value of `expr`, or `NULL` after a timeout handled as
#'   `"warning"` or `"silent"`.
withTimeout <- function(expr, timeout = 120, onTimeout = "error") {
  caller_env <- parent.frame()

  setTimeLimit(cpu = timeout, elapsed = timeout, transient = TRUE)
  on.exit(setTimeLimit(cpu = Inf, elapsed = Inf, transient = FALSE), add = TRUE)

  soft_timeout <- identical(onTimeout, "warning") || identical(onTimeout, "silent")
  if (!soft_timeout) {
    return(eval(expr, envir = caller_env))
  }

  # The time-limit error has no dedicated condition class, so it is
  # recognised by its message (English and the session's translation).
  limit_messages <- unique(c(
    "reached elapsed time limit",
    "reached CPU time limit",
    gettext("reached elapsed time limit", domain = "R"),
    gettext("reached CPU time limit", domain = "R")
  ))

  tryCatch(
    eval(expr, envir = caller_env),
    error = function(e) {
      msg <- conditionMessage(e)
      is_timeout <- any(vapply(
        limit_messages,
        function(pattern) grepl(pattern, msg, fixed = TRUE),
        logical(1)
      ))
      if (!is_timeout) {
        stop(e)
      }
      if (identical(onTimeout, "warning")) {
        warning(msg, call. = FALSE)
      }
      NULL
    }
  )
}

#* 执行统计工具
#* @post /api/v1/skills/<tool_code>
#* @param tool_code:str 工具代码(如 ST_T_TEST_IND)
#* @serializer unboxedJSON
function(req, tool_code) {
  # Runs one statistics tool.
  #
  # `tool_code` comes from the URL path and is checked against the whitelist
  # pattern before anything else. The JSON request body is passed as-is to
  # the tool's `run_analysis()` function.
  #
  # - DEV_MODE: the tool script is re-sourced on every request (hot reload).
  # - Otherwise: the function preloaded into TOOL_CACHE is used.
  #
  # Responses: E400 for a malformed tool code, E100 for an unknown tool or a
  # script without `run_analysis()`, otherwise whatever the tool returns. Any
  # other error is converted with map_r_error().
  tryCatch({
    # ===== Security check: tool_code whitelist =====
    if (!validate_tool_code(tool_code)) {
      return(list(
        status = "error",
        error_code = "E400",
        message = "Invalid tool code format",
        user_hint = "工具代码格式错误,只允许大写字母、数字和下划线"
      ))
    }

    # Parse the request body
    input <- jsonlite::fromJSON(req$postBody, simplifyVector = FALSE)

    # Log the incoming parameters (helps debugging)
    param_names <- if (!is.null(input$params)) paste(names(input$params), collapse=", ") else "NULL"
    message(glue::glue("[Skill:{tool_code}] params keys: [{param_names}]"))
    if (!is.null(input$params$variables)) {
      message(glue::glue("[Skill:{tool_code}] variables ({length(input$params$variables)}): [{paste(input$params$variables, collapse=', ')}]"))
    }
    if (!is.null(input$params$group_var)) {
      message(glue::glue("[Skill:{tool_code}] group_var: {input$params$group_var}"))
    }

    # Debug mode: temp-file paths reported by the tool are echoed back
    debug_mode <- isTRUE(input$debug)

    # Every tool exposes the same entry point name
    func_name <- "run_analysis"

    # Normalised tool name and the script it maps to
    tool_name <- normalize_tool_name(tool_code)
    tool_file <- file.path("tools", paste0(tool_name, ".R"))

    unknown_tool <- function(hint) {
      list(
        status = "error",
        error_code = "E100",
        message = paste("Unknown tool:", tool_code),
        user_hint = hint
      )
    }

    # ===== Loading strategy depends on the environment =====
    if (DEV_MODE) {
      # Development: reload on every request (hot reload)
      if (!file.exists(tool_file)) {
        return(unknown_tool("请检查工具代码是否正确"))
      }

      # Source into a private environment, as the production preload does.
      # Sourcing into the global environment would let a `run_analysis` left
      # behind by a previously loaded tool satisfy the check below, so a
      # script without its own entry point would silently run another tool.
      tool_env <- new.env(parent = globalenv())
      source(tool_file, local = tool_env)

      if (!exists(func_name, envir = tool_env, mode = "function", inherits = FALSE)) {
        return(list(
          status = "error",
          error_code = "E100",
          message = paste("Tool", tool_code, "does not implement run_analysis()"),
          user_hint = "工具脚本格式错误,请联系管理员"
        ))
      }

      # Run the analysis
      entry_point <- get(func_name, envir = tool_env, inherits = FALSE)
      result <- entry_point(input)
    } else {
      # Production: use the preloaded function
      cached_func <- if (exists(tool_name, envir = TOOL_CACHE, inherits = FALSE)) {
        get(tool_name, envir = TOOL_CACHE, inherits = FALSE)
      } else {
        NULL
      }

      # A cache entry that is not a function is treated like a missing tool
      # instead of failing with "attempt to apply non-function".
      if (!is.function(cached_func)) {
        return(unknown_tool("请检查工具代码是否正确,或联系管理员确认工具已部署"))
      }

      result <- cached_func(input)
    }

    # Debug mode: attach the temp-file paths reported by the tool
    if (debug_mode && is.list(result) && !is.null(result$tmp_files)) {
      result$debug_files <- result$tmp_files
      message("[DEBUG] 临时文件已保留: ", paste(result$tmp_files, collapse = ", "))
    }

    result
  }, error = function(e) {
    message(glue::glue("[Skill:{tool_code}] ERROR: {conditionMessage(e)}"))
    map_r_error(conditionMessage(e))
  })
}