Files
prompts-core/service/prompt/prompt_compose_service.go
2026-06-10 16:32:42 +08:00

312 lines
9.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package prompt
import (
"context"
"errors"
"fmt"
"prompts-core/service/session"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ComposeMessages is the main entry point of the prompt-composition flow.
//
// It resolves the chat model and the target model through GetModelMessage,
// validates the user form against the target model through validateUserForm,
// and then hands everything to handleBuild. Errors from the first two steps
// are returned unchanged.
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
	// Resolve both models up front; the later steps need them.
	chatModel, aiModel, modelErr := GetModelMessage(ctx, req)
	if modelErr != nil {
		return nil, modelErr
	}

	// Reject the request early when the user form does not pass validation.
	if formErr := validateUserForm(req, aiModel); formErr != nil {
		return nil, formErr
	}

	return handleBuild(ctx, req, chatModel, aiModel)
}
// GetModelMessage loads the two model configurations a compose request needs.
//
// The first return value is the current user's chat model (looked up by
// creator with IsChatModel set to 1); the second is the model named by
// req.ModelName (looked up by tenant, creator and model name). When either
// lookup fails or yields nil, a fixed message is returned and the underlying
// lookup error is not included in it.
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
	}

	// Chat model: filtered by creator only.
	chatQuery := &gateway.AsynchModel{
		SQLBaseDO:   beans.SQLBaseDO{Creator: user.UserName},
		IsChatModel: 1,
	}
	chatModel, chatErr := gateway.GetModelConfig(ctx, chatQuery)
	if chatModel == nil || chatErr != nil {
		return nil, nil, errors.New("当前没有对话模型,请添加")
	}

	// Target model: filtered by tenant, creator and the requested model name.
	targetQuery := &gateway.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{TenantId: user.TenantId, Creator: user.UserName},
		ModelName: req.ModelName,
	}
	targetModel, targetErr := gateway.GetModelConfig(ctx, targetQuery)
	if targetModel == nil || targetErr != nil {
		return nil, nil, errors.New("需要构建的模型不存在")
	}

	return chatModel, targetModel, nil
}
// validateUserForm checks that req.UserForm fits the token window described by
// model.TokenConfig.
//
// An empty form is always accepted. A failure of the check itself is wrapped
// and returned; a form that does not fit produces an error reporting the
// number of exceeding tokens and the available window.
func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error {
	// Nothing to measure when no form was submitted.
	if len(req.UserForm) == 0 {
		return nil
	}

	fits, overflow, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
	switch {
	case err != nil:
		return fmt.Errorf("校验用户表单失败: %w", err)
	case fits:
		return nil
	}

	// The window is only looked up when it is needed for the error message.
	window := util.GetAvailableWindow(model.TokenConfig)
	return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
		overflow, window)
}
// handleBuild runs the shared build pipeline for every build type.
//
// Steps:
//  1. Run ProcessUserFormBatches on the request (its second return value is
//     not used here).
//  2. Build the inference request according to req.BuildType. Only the prompt
//     build type receives the batch-processed request; the node and struct
//     build types are built from the original req.
//  3. Create the gateway task and require a non-empty task ID.
//  4. Persist a pending ComposeTask record carrying the request payload.
//
// It returns the gateway task ID on success. Every failing step returns an
// error wrapped with the step's context; an unknown build type and an empty
// task ID yield plain errors.
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
	// 1) Batch-process the user form.
	processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
	if err != nil {
		return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
	}

	// 2) Build the inference request for the requested build type.
	ir := NewPromptIR()
	var taskReq map[string]any
	switch req.BuildType {
	case public.BuildTypePrompt:
		taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
	case public.BuildTypeNode:
		taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
	case public.BuildTypeStruct:
		taskReq, err = buildStructTypeRequest(ctx, req, chatModel, ir)
	default:
		return nil, errors.New("不支持的构建类型")
	}
	if err != nil {
		return nil, fmt.Errorf("构建推理请求失败: %w", err)
	}

	// 3) Create the task on the gateway.
	taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
	if err != nil {
		return nil, fmt.Errorf("创建网关任务失败: %w", err)
	}
	if taskID == "" {
		return nil, errors.New("网关未返回taskId")
	}

	// 4) Persist the task record so the callback can find it later.
	record := &entity.ComposeTask{
		TaskId:         taskID,
		ModelName:      req.ModelName,
		SkillName:      req.SkillName,
		BuildType:      req.BuildType,
		CallbackUrl:    req.CallbackUrl,
		RequestPayload: gconv.Map(req),
		Status:         public.ComposeStatusPending,
	}
	if _, err = dao.ComposeTask.Insert(ctx, record); err != nil {
		// Wrapped like every other step so the failing stage is identifiable.
		return nil, fmt.Errorf("保存任务记录失败: %w", err)
	}

	return &dto.ComposeMessagesRes{TaskId: taskID}, nil
}
// Callback handles a gateway callback for a compose task.
//
// It logs the callback, loads the task identified by req.TaskId and then
// dispatches on req.State: 3 goes to handleCallbackFailed, 2 goes to
// handleCallbackSuccess, and every other state is ignored (nil is returned).
//
// An error is returned when the task lookup fails, when a state-2 or state-3
// callback refers to a task that was not found, or when the selected handler
// fails.
func Callback(ctx context.Context, req *dto.CallbackReq) error {
	g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)

	// 1) Load the task this callback refers to.
	composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
	if err != nil {
		return fmt.Errorf("查询任务失败: %w", err)
	}

	// 2) Pick the handler for the reported state; other states are a no-op.
	var handle func(context.Context, *dto.CallbackReq, *entity.ComposeTask) error
	switch req.State {
	case 3:
		handle = handleCallbackFailed
	case 2:
		handle = handleCallbackSuccess
	default:
		return nil
	}

	// 3) Both handlers dereference the task, so a lookup that found nothing
	// must be rejected here instead of panicking further down.
	if composeTask == nil {
		return fmt.Errorf("任务不存在: taskId=%v", req.TaskId)
	}

	return handle(ctx, req, composeTask)
}
// handleCallbackFailed records a failed gateway callback.
//
// The task row identified by req.TaskId is updated with the failed status, the
// gateway error message and state, and the file/result fields carried by the
// callback. When the task has a callback URL, the in-memory composeTask is
// marked failed and sent through gateway.SendCallback with 0 as the last
// argument; the result of that send is discarded. The return value is the
// error from the database update, which is reported even though the callback
// is still attempted.
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
	failedRecord := &entity.ComposeTask{
		TaskId:       req.TaskId,
		Status:       public.ComposeStatusFailed,
		ErrorMessage: req.ErrorMsg,
		GatewayState: req.State,
		OssFile:      req.OssFile,
		FileType:     req.FileType,
		ResultJson:   req.Messages,
	}
	_, updateErr := dao.ComposeTask.Update(ctx, failedRecord)

	// Without a callback URL there is nobody to notify.
	if composeTask.CallbackUrl == "" {
		return updateErr
	}

	composeTask.Status = public.ComposeStatusFailed
	composeTask.ErrorMessage = req.ErrorMsg
	// Notification is best-effort: its outcome does not affect the result.
	_ = gateway.SendCallback(ctx, composeTask, 0)

	return updateErr
}
// handleCallbackSuccess processes a successful gateway callback.
//
// Steps:
//  1. Load the model configuration for the task's creator and model name.
//  2. Load the enabled provider protocol for the model's operator.
//  3. For inference models with both sessionId and nodeId in the stored
//     request payload: fetch the history messages and store the current
//     round's user text as a new ComposeSession row.
//  4. Merge the stored request payload with the callback messages.
//  5. Inject the history into the merged messages when there is any.
//  6. Update the task row to the success status with the final messages.
//  7. When the task has a callback URL, notify it with the final messages and
//     the ID of the session row inserted in step 3 (0 when none was inserted).
//
// It returns an error when the model lookup fails or finds nothing, or when
// the task update fails. Protocol lookup, history lookup, session insert and
// the outgoing callback are best-effort: their errors are discarded.
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
	// 1) Load the model configuration.
	model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
		ModelName: composeTask.ModelName,
	})
	if err != nil {
		return fmt.Errorf("查询模型失败: %w", err)
	}
	if model == nil {
		// GetModelMessage treats a nil model with a nil error as possible, so
		// guard here too: every later step dereferences model.
		return fmt.Errorf("查询模型失败: 模型不存在 modelName=%v", composeTask.ModelName)
	}

	// 2) Load the provider protocol. The error is ignored on purpose: a nil
	// protocol makes InjectHistory return the messages unchanged.
	protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
		ProviderName: model.OperatorName,
		Status:       1,
	})

	// 3) Fetch the history and store the current round.
	payload := composeTask.RequestPayload
	sessionId := gconv.String(payload["sessionId"])
	nodeId := gconv.String(payload["nodeId"])
	var history []dto.FlatMessage
	var epicycleId int64
	if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
		// 3.1) Fetch the history; a failure simply means no history is injected.
		h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
			SessionId: sessionId,
			NodeId:    nodeId,
		})
		if h != nil {
			history = h.Messages
		}
		// 3.2) Store the current round now so the next history query returns
		// it. The insert error is ignored; epicycleId keeps whatever Insert
		// returned.
		if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
			epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
				NodeId:         nodeId,
				SessionId:      sessionId,
				RequestContent: userMsg,
			})
		}
	}

	// 4) Merge the additional structures from the request payload.
	messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)

	// 5) Inject the history.
	if len(history) > 0 {
		messages = InjectHistory(messages, history, protocol)
	}

	// 6) Persist the final result.
	_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
		TaskId:       req.TaskId,
		Status:       public.ComposeStatusSuccess,
		GatewayState: req.State,
		OssFile:      req.OssFile,
		FileType:     req.FileType,
		ResultJson:   messages,
	})
	if err != nil {
		return err
	}

	// 7) Notify the business side; the outcome of the send is not reported.
	if composeTask.CallbackUrl != "" {
		composeTask.Status = public.ComposeStatusSuccess
		composeTask.ResultJson = messages
		_ = gateway.SendCallback(ctx, composeTask, epicycleId)
	}
	return nil
}
// InjectHistory rebuilds the message list of the first round in roundsData so
// that it contains the session history.
//
// The new list is assembled by walking protocol.MergeOrder:
//   - "system" appends the first round's messages whose role is "system";
//   - "history" appends one {"role", "content"} map per history entry, but
//     only when protocol.Capabilities["support_history"] is truthy;
//   - "user" appends the first round's messages whose role is "user".
//
// Messages with any other role, and parts not named in MergeOrder, are left
// out of the result. Afterwards every role that has an entry in
// protocol.RoleMapping is rewritten to its mapped value. The original message
// maps are reused, so their "role" field is modified in place.
//
// roundsData is mutated and returned. It is returned untouched when protocol
// is nil, history is empty, or roundsData does not have the expected shape
// (a non-empty "rounds" list whose first element is a map with a "messages"
// list); elements of "messages" that are not maps are skipped.
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
	if protocol == nil || len(history) == 0 {
		return roundsData
	}

	// 1) Locate the first round's messages. Checked assertions are used so
	// that an unexpected payload shape is a no-op instead of a panic.
	rounds, ok := roundsData["rounds"].([]any)
	if !ok || len(rounds) == 0 {
		return roundsData
	}
	firstRound, ok := rounds[0].(map[string]any)
	if !ok {
		return roundsData
	}
	original, ok := firstRound["messages"].([]any)
	if !ok {
		return roundsData
	}

	// 2) Assemble the new list in the order given by MergeOrder.
	merged := make([]any, 0, len(original)+len(history))
	// appendByRole appends the original messages carrying the given role.
	appendByRole := func(role string) {
		for _, item := range original {
			msg, isMap := item.(map[string]any)
			if isMap && gconv.String(msg["role"]) == role {
				merged = append(merged, msg)
			}
		}
	}
	for _, part := range protocol.MergeOrder {
		switch part {
		case "system":
			appendByRole("system")
		case "history":
			if gconv.Bool(protocol.Capabilities["support_history"]) {
				for _, entry := range history {
					merged = append(merged, map[string]any{
						"role":    entry.Role,
						"content": entry.Content, // plain string, passed through as-is
					})
				}
			}
		case "user":
			appendByRole("user")
		}
	}

	// 3) Apply the provider's role mapping.
	if len(protocol.RoleMapping) > 0 {
		for _, item := range merged {
			msg, isMap := item.(map[string]any)
			if !isMap {
				continue
			}
			if mapped, found := protocol.RoleMapping[gconv.String(msg["role"])]; found {
				msg["role"] = mapped
			}
		}
	}

	// 4) Write the result back into the original structure.
	firstRound["messages"] = merged
	return roundsData
}
// GetComposeTask returns the stored state of the compose task identified by
// taskID: its task ID, status, error message and result messages.
//
// An error is returned when the lookup fails or when no task record was found.
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
	record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
		TaskId: taskID,
	})
	if err != nil {
		return nil, fmt.Errorf("查询任务失败: %w", err)
	}
	if record == nil {
		// A lookup that found nothing must not be dereferenced below.
		return nil, fmt.Errorf("任务不存在: taskId=%s", taskID)
	}

	return &dto.GetComposeTaskRes{
		TaskId:       record.TaskId,
		Status:       record.Status,
		ErrorMessage: record.ErrorMessage,
		Messages:     record.ResultJson,
	}, nil
}