diff --git a/common/util/mapping.go b/common/util/mapping.go index d229b4c..85547dd 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -19,17 +19,20 @@ import ( tgjson "github.com/tidwall/gjson" ) -// ParseAndValidate 解析并校验结果 +// ParseAndValidate 解析模型响应,并返回标准格式 func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) { - // 1) 解析 content 字符串为 rounds 数组 - contentVal, ok := raw[model.ResponseBody] - if !ok { - return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody) - } - contentStr, ok := contentVal.(string) - if !ok || strings.TrimSpace(contentStr) == "" { - return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody) + contentStr := gconv.String(raw[entity.ResponseBody]) + if strings.TrimSpace(contentStr) == "" { + return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody) } + + contentStr = strings.Map(func(r rune) rune { + if r < 32 && r != ' ' { + return -1 + } + return r + }, contentStr) + var arr []any if err := json.Unmarshal([]byte(contentStr), &arr); err != nil { return raw, fmt.Errorf("JSON解析失败: %w", err) @@ -38,17 +41,11 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[ return raw, fmt.Errorf("解析后数组为空") } - // 2) 校验必填字段 - if len(model.RequiredFields) > 0 { + for _, field := range model.RequiredFields { for i, r := range arr { - round, ok := r.(map[string]any) - if !ok { - continue - } - for _, field := range model.RequiredFields { - if gjson.New(round).Get(field).IsNil() { - return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field) - } + round, _ := r.(map[string]any) + if round != nil && gjson.New(round).Get(field).IsNil() { + return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field) } } } diff --git a/dao/model_gateway_models_dao.go b/dao/model_gateway_models_dao.go index be0c423..7217b79 100644 --- a/dao/model_gateway_models_dao.go +++ b/dao/model_gateway_models_dao.go @@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa Where(entity.ModelGatewayModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName). + Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel). Fields(fields).One() if err != nil { return nil, err @@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) { sql := ` SELECT DISTINCT ON (model_name) * -FROM asynch_models +FROM ` + public.TableNameModel + ` WHERE deleted_at IS NULL AND (? = '' OR model_name LIKE ?) ` diff --git a/dao/model_gateway_task_dao.go b/dao/model_gateway_task_dao.go index 7d9c209..0031c4e 100644 --- a/dao/model_gateway_task_dao.go +++ b/dao/model_gateway_task_dao.go @@ -7,7 +7,6 @@ import ( "model-gateway/model/entity" "gitea.redpowerfuture.com/red-future/common/db/gfdb" - "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/util/gconv" ) @@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in // ClaimByID 按主键抢占,返回抢占后的任务 func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) { + // 1) 先查任务 var task entity.ModelGatewayTask - err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { - r, err := tx.Model(public.TableNameTask). - Where(entity.ModelGatewayTaskCol.Id, id). - Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). - Limit(1). - LockUpdate(). - One() - if err != nil { - return err - } - if r.IsEmpty() { - return fmt.Errorf("任务已被抢占或不存在: id=%d", id) - } - if err := r.Struct(&task); err != nil { - return err - } - _, err = tx.Model(public.TableNameTask). - Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}). - Where(entity.ModelGatewayTaskCol.Id, id). - OmitEmpty(). - Update() - return err - }) + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + Where(entity.ModelGatewayTaskCol.Id, id). + Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). + One() if err != nil { return nil, err } + if r.IsEmpty() { + return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id) + } + if err = r.Struct(&task); err != nil { + return nil, err + } + + // 2) 改为执行中 + _, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}). + Where(entity.ModelGatewayTaskCol.Id, id). + Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发 + OmitEmpty(). + Update() + if err != nil { + return nil, err + } + return &task, nil } diff --git a/model/entity/model_gateway_model.go b/model/entity/model_gateway_model.go index c7f6ec1..afdeb3e 100644 --- a/model/entity/model_gateway_model.go +++ b/model/entity/model_gateway_model.go @@ -4,99 +4,98 @@ import "gitea.redpowerfuture.com/red-future/common/beans" type modelGatewayModelCol struct { beans.SQLBaseCol - ModelName string - ModelType string - BaseURL string - HttpMethod string - HeadMsg string - FormJSON string - RequestMapping string - ResponseMapping string - ResponseBody string - ResponseTokenField string - RequiredFields string - IsPrivate string - IsChatModel string - CallMode string - ApiKey string - Enabled string - MaxConcurrency string - TimeoutSeconds string - RetryTimes string - AutoCleanSeconds string - IsOwner string - OperatorName string - TokenConfig string - ExtendMapping string - QueryConfig string - StreamConfig string - FirstFrame string - LastFrame string - MaxTokens string + ModelName string + ModelType string + BaseURL string + HttpMethod string + HeadMsg string + FormJSON string + RequestMapping string + ResponseMapping string + RequiredFields string + IsPrivate string + IsChatModel string + CallMode string + ApiKey string + Enabled string + MaxConcurrency string + TimeoutSeconds string + RetryTimes string + AutoCleanSeconds string + IsOwner string + OperatorName string + TokenConfig string + ExtendMapping string + QueryConfig string + StreamConfig string + FirstFrame string + LastFrame string + MaxTokens string } var ModelGatewayModelCol = modelGatewayModelCol{ - SQLBaseCol: beans.DefSQLBaseCol, - ModelName: "model_name", - ModelType: "model_type", - BaseURL: "base_url", - HttpMethod: "http_method", - HeadMsg: "head_msg", - FormJSON: "form_json", - RequestMapping: "request_mapping", - ResponseMapping: "response_mapping", - ResponseBody: "response_body", - ResponseTokenField: "response_token_field", - RequiredFields: "required_fields", - IsPrivate: "is_private", - IsChatModel: "is_chat_model", - CallMode: "call_mode", - ApiKey: "api_key", - Enabled: "enabled", - MaxConcurrency: "max_concurrency", - TimeoutSeconds: "timeout_seconds", - RetryTimes: "retry_times", - AutoCleanSeconds: "auto_clean_seconds", - IsOwner: "is_owner", - OperatorName: "operator_name", - TokenConfig: "token_config", - ExtendMapping: "extend_mapping", - QueryConfig: "query_config", - StreamConfig: "stream_config", - FirstFrame: "first_frame", - LastFrame: "last_frame", - MaxTokens: "max_tokens", + SQLBaseCol: beans.DefSQLBaseCol, + ModelName: "model_name", + ModelType: "model_type", + BaseURL: "base_url", + HttpMethod: "http_method", + HeadMsg: "head_msg", + FormJSON: "form_json", + RequestMapping: "request_mapping", + ResponseMapping: "response_mapping", + RequiredFields: "required_fields", + IsPrivate: "is_private", + IsChatModel: "is_chat_model", + CallMode: "call_mode", + ApiKey: "api_key", + Enabled: "enabled", + MaxConcurrency: "max_concurrency", + TimeoutSeconds: "timeout_seconds", + RetryTimes: "retry_times", + AutoCleanSeconds: "auto_clean_seconds", + IsOwner: "is_owner", + OperatorName: "operator_name", + TokenConfig: "token_config", + ExtendMapping: "extend_mapping", + QueryConfig: "query_config", + StreamConfig: "stream_config", + FirstFrame: "first_frame", + LastFrame: "last_frame", + MaxTokens: "max_tokens", } type ModelGatewayModel struct { - beans.SQLBaseDO `orm:",inline"` - ModelName string `orm:"model_name" json:"modelName"` - ModelType int `orm:"model_type" json:"modelType"` - BaseURL string `orm:"base_url" json:"baseUrl"` - HttpMethod string `orm:"http_method" json:"httpMethod"` - HeadMsg map[string]any `orm:"head_msg" json:"headMsg"` - Form []map[string]any `orm:"form_json" json:"form"` - RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` - ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` - ResponseBody string `orm:"response_body" json:"responseBody"` - ResponseTokenField string `orm:"response_token_field" json:"tokenField"` - RequiredFields []string `orm:"required_fields" json:"requiredFields"` - IsPrivate *int `orm:"is_private" json:"isPrivate"` - IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` - CallMode *int `orm:"call_mode" json:"callMode"` - ApiKey string `orm:"api_key" json:"apiKey"` - Enabled *int `orm:"enabled" json:"enabled"` - MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` - TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` - RetryTimes int `orm:"retry_times" json:"retryTimes"` - AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` - IsOwner *int `orm:"is_owner" json:"isOwner"` - OperatorName string `orm:"operator_name" json:"operatorName"` - TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` - ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` - QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` - StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"` - FirstFrame string `orm:"first_frame" json:"firstFrame"` - LastFrame string `orm:"last_frame" json:"lastFrame"` - MaxTokens int `orm:"max_tokens" json:"maxTokens"` + beans.SQLBaseDO `orm:",inline"` + ModelName string `orm:"model_name" json:"modelName"` + ModelType int `orm:"model_type" json:"modelType"` + BaseURL string `orm:"base_url" json:"baseUrl"` + HttpMethod string `orm:"http_method" json:"httpMethod"` + HeadMsg map[string]any `orm:"head_msg" json:"headMsg"` + Form []map[string]any `orm:"form_json" json:"form"` + RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` + ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` + RequiredFields []string `orm:"required_fields" json:"requiredFields"` + IsPrivate *int `orm:"is_private" json:"isPrivate"` + IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` + CallMode *int `orm:"call_mode" json:"callMode"` + ApiKey string `orm:"api_key" json:"apiKey"` + Enabled *int `orm:"enabled" json:"enabled"` + MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` + TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` + RetryTimes int `orm:"retry_times" json:"retryTimes"` + AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` + IsOwner *int `orm:"is_owner" json:"isOwner"` + OperatorName string `orm:"operator_name" json:"operatorName"` + TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` + ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` + QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` + StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"` + FirstFrame string `orm:"first_frame" json:"firstFrame"` + LastFrame string `orm:"last_frame" json:"lastFrame"` + MaxTokens int `orm:"max_tokens" json:"maxTokens"` } + +const ( + ResponseBody = "content" //返回主体(必填) + TotalTokens = "total_tokens" //总token数 +) diff --git a/service/model/model_service.go b/service/model/model_service.go index b28b9ff..080b406 100644 --- a/service/model/model_service.go +++ b/service/model/model_service.go @@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM ModelName: req.ModelName, IsChatModel: req.IsChatModel, }) - if err != nil { + if err != nil || model == nil { return nil, err } return &dto.GetModelRes{ diff --git a/service/task/task_service.go b/service/task/task_service.go index fe1eba7..6e5f710 100644 --- a/service/task/task_service.go +++ b/service/task/task_service.go @@ -107,13 +107,27 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * }, }) - // 5) 获取任务信息 - task, err := dao.ModelGatewayTask.ClaimByID(ctx, id) + // 5) 抢占任务:改为执行中 + rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ + SQLBaseDO: beans.SQLBaseDO{Id: id}, + State: public.TaskStatusRunning, + }) + if err != nil { + return nil, err + } + if rows == 0 { + return nil, fmt.Errorf("任务不存在: id=%d", id) + } + + // 6) 查询任务信息 + task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{ + SQLBaseDO: beans.SQLBaseDO{Id: id}, + }) if err != nil { return nil, err } - // 5) 创建成功后立即异步尝试执行当前任务 + // 7) 创建成功后立即异步尝试执行当前任务 go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) return &dto.CreateTaskRes{TaskID: taskID}, nil diff --git a/service/task/worker.go b/service/task/worker.go index de31f83..2236c36 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -67,25 +67,39 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa // ============================================ // 2) 调用模型 // ============================================ - switch { - case model.CallMode != nil && *model.CallMode == public.CallModeStream: - rawBytes, streamErr := w.callModelStream(ctx, task, model, body) - if streamErr != nil { - w.failTask(ctx, task, startTime, streamErr.Error()) + for attempt := 0; ; attempt++ { + if attempt > 0 { + g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID) + time.Sleep(time.Duration(attempt) * time.Second) + } + + switch { + case model.CallMode != nil && *model.CallMode == public.CallModeStream: + rawBytes, streamErr := w.callModelStream(ctx, task, model, body) + if streamErr != nil { + err = streamErr + g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err) + continue + } + result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) + case model.CallMode != nil && *model.CallMode == public.CallModeAsync: + result, err = w.callModel(ctx, task, model, body) + if err == nil { + result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg) + } + default: + result, err = w.callModel(ctx, task, model, body) + } + + if err == nil { + break + } + + if !strings.Contains(err.Error(), "Timeout") { + w.failTask(ctx, task, startTime, err.Error()) return } - result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) - case model.CallMode != nil && *model.CallMode == public.CallModeAsync: - result, err = w.callModel(ctx, task, model, body) - if err == nil { - result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg) - } - default: - result, err = w.callModel(ctx, task, model, body) - } - if err != nil { - w.failTask(ctx, task, startTime, err.Error()) - return + g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err) } // ============================================ @@ -205,7 +219,7 @@ func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGate return nil, err } // 2. 拿到 task_id - taskID := gjson.New(body).Get(model.ResponseBody).String() + taskID := gjson.New(body).Get(entity.ResponseBody).String() // 3. 创建等待通道 ch := make(chan asyncResult, 1) @@ -294,6 +308,8 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa // parseAndRetry 解析模型返回结果,并重试 func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { + var lastErr error + for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) @@ -302,6 +318,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta // 1) 响应映射 mapped, err := util.MapResponsePayload(model.ResponseMapping, body) if err != nil { + lastErr = err g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { return nil, fmt.Errorf("响应映射重试耗尽: %w", err) @@ -309,10 +326,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta continue } - // 2) 先存 token 到数据库,防止后续失败丢失 - if _, ok := mapped[model.ResponseTokenField]; ok { - task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField]) - _, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ + // 2) 存 token + if _, ok := mapped[entity.TotalTokens]; ok { + task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens]) + _, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ SQLBaseDO: beans.SQLBaseDO{Id: task.Id}, ExpendTokens: task.ExpendTokens, }) @@ -326,9 +343,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta if err == nil { return parsed, nil } + lastErr = err case public.BuildTypeStruct: - parsed = util.ParseStructResult(mapped, model.ResponseBody) - return parsed, nil + return util.ParseStructResult(mapped, entity.ResponseBody), nil default: return mapped, nil } @@ -336,22 +353,22 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { - return nil, fmt.Errorf("JSON解析重试耗尽: %w", err) + return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr) } - // 4) 重新调模型(直接调,不走缓存) + // 4) 拼接错误信息到请求体,重调模型 task.RetryCount++ _, _ = dao.ModelGatewayTask.Update(ctx, task) - rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body) + body = injectErrorMessage(task.RequestPayload.Body, lastErr) + rawData, callErr := InvokeModel(ctx, model, body) if callErr != nil { g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) continue } - // 5) 解析原始响应,覆盖 body 进入下一轮 var rawResp map[string]any - if err = json.Unmarshal(rawData, &rawResp); err != nil { + if err := json.Unmarshal(rawData, &rawResp); err != nil { g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) continue } @@ -361,6 +378,44 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta return body, nil } +// injectErrorMessage 将错误信息拼接到 user 消息中 +func injectErrorMessage(payload map[string]any, err error) map[string]any { + if err == nil { + return payload + } + + messages, _ := payload["messages"].([]any) + if len(messages) == 0 { + return payload + } + + errMsg := fmt.Sprintf("\n\n【上一轮输出错误,请修正】%s", err.Error()) + + // 找到最后一个 role=user 的消息,追加错误提示 + for i := len(messages) - 1; i >= 0; i-- { + msg, ok := messages[i].(map[string]any) + if !ok { + continue + } + if gconv.String(msg["role"]) != "user" { + continue + } + + switch c := msg["content"].(type) { + case string: + msg["content"] = c + errMsg + case []any: + msg["content"] = append(c, map[string]any{ + "type": "text", + "text": errMsg, + }) + } + break + } + + return payload +} + // InvokeModel 调用模型服务,返回二进制结果 // modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key) func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {