Files
prompts-core/service/prompt/prompt_compose_service.go

381 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package prompt
import (
"context"
"encoding/json"
"errors"
"fmt"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
)
// ComposeMessages is the entry point of the prompt-composition flow.
//
// It loads the caller's chat model and the requested target model via
// GetModelMessage, validates req.UserForm against the target model via
// validateUserForm, and then dispatches on req.BuildType:
//   - public.BuildTypePrompt is handled by handlePromptBuild;
//   - public.BuildTypeNode is handled by handleNodeBuild;
//   - any other value yields an error.
//
// Errors from the model lookup and the form validation are returned unchanged.
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
	chatModel, aiModel, err := GetModelMessage(ctx, req)
	if err != nil {
		return nil, err
	}
	if formErr := validateUserForm(req, aiModel); formErr != nil {
		return nil, formErr
	}

	// Prompt build.
	if req.BuildType == public.BuildTypePrompt {
		return handlePromptBuild(ctx, req, chatModel, aiModel)
	}
	// Node build.
	if req.BuildType == public.BuildTypeNode {
		return handleNodeBuild(ctx, req, chatModel, aiModel)
	}
	return nil, errors.New("BuildType 不支持")
}
// GetModelMessage resolves the two models needed by a compose request.
//
// The current user is read from ctx via utils.GetUserInfo; that user's name is
// then used to look up the chat model (getChatModel) and the model named by
// req.ModelName (getAIModel).
//
// Returns (chatModel, aiModel, nil) on success. On any failure both models are
// nil; a user-info failure is wrapped, lookup errors are returned unchanged.
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
	userInfo, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
	}
	owner := userInfo.UserName

	chatModel, chatErr := getChatModel(ctx, owner)
	if chatErr != nil {
		return nil, nil, chatErr
	}

	aiModel, aiErr := getAIModel(ctx, owner, req.ModelName)
	if aiErr != nil {
		return nil, nil, aiErr
	}

	return chatModel, aiModel, nil
}
// validateUserForm checks that req.UserForm fits the model's window.
//
// An empty UserForm is accepted without any check. Otherwise the form and
// model.TokenConfig are handed to util.CheckUserFormWithinWindow:
//   - a checker error is wrapped and returned;
//   - a form reported as not valid yields an error carrying the excess amount
//     and the value of util.GetAvailableWindow(model.TokenConfig);
//   - otherwise nil is returned.
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
	if len(req.UserForm) == 0 {
		return nil
	}

	withinWindow, overflow, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
	switch {
	case err != nil:
		return fmt.Errorf("校验用户表单失败: %w", err)
	case withinWindow:
		return nil
	}

	return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
		overflow, util.GetAvailableWindow(model.TokenConfig))
}
// handlePromptBuild runs the prompt-build flow (BuildType=1).
//
// Steps:
//  1. Load the session history for req.SessionId. A failure here is only
//     logged; the flow continues without history.
//  2. Submit the inference task via callInferenceModel.
//  3. Persist a task record via saveComposeTask.
//
// Returns a response carrying the task ID, or a wrapped error from step 2 or 3.
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
	history, histErr := GetHistoryMessages(ctx, req.SessionId)
	if histErr != nil {
		// History is best-effort: log and fall back to no history.
		g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", histErr)
		history = nil
	}

	taskID, callErr := callInferenceModel(ctx, req, chatModel, aiModel, history)
	if callErr != nil {
		return nil, fmt.Errorf("调用推理模型失败: %w", callErr)
	}

	if saveErr := saveComposeTask(ctx, taskID, req); saveErr != nil {
		return nil, fmt.Errorf("保存任务记录失败: %w", saveErr)
	}

	res := &dto.ComposeMessagesRes{TaskId: taskID}
	return res, nil
}
// handleNodeBuild runs the node-build flow (BuildType=2).
//
// It submits the inference task via callInferenceModel with no session
// history, then persists a task record via saveComposeTask.
//
// Returns a response carrying the task ID, or a wrapped error from either step.
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
	taskID, callErr := callInferenceModel(ctx, req, chatModel, aiModel, nil)
	if callErr != nil {
		return nil, fmt.Errorf("调用推理模型失败: %w", callErr)
	}

	saveErr := saveComposeTask(ctx, taskID, req)
	if saveErr != nil {
		return nil, fmt.Errorf("保存任务记录失败: %w", saveErr)
	}

	res := &dto.ComposeMessagesRes{TaskId: taskID}
	return res, nil
}
// saveComposeTask inserts a compose-task record for a newly created task.
//
// The record copies the model name, skill name, build type and callback URL
// from req, stores the output of util.MustMarshal(req) as the request payload,
// and starts in status public.ComposeStatusPending.
//
// Returns the error reported by dao.ComposeTask.Insert, if any.
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
	record := &entity.ComposeTask{
		TaskId:         taskID,
		ModelName:      req.ModelName,
		SkillName:      req.SkillName,
		BuildType:      req.BuildType,
		CallbackUrl:    req.CallbackUrl,
		RequestPayload: util.MustMarshal(req),
		Status:         public.ComposeStatusPending,
	}
	_, insertErr := dao.ComposeTask.Insert(ctx, record)
	return insertErr
}
// getChatModel looks up the chat model created by userName.
//
// The lookup filters on Creator == userName and IsChatModel pointing at 1.
// Returns a wrapped error if the query fails, and a dedicated error if no
// such model exists.
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
	filter := &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Creator: userName},
		// NOTE(review): new with an expression operand needs a toolchain that
		// supports it (presumably Go 1.26+) — confirm against go.mod.
		IsChatModel: new(1),
	}

	chatModel, err := dao.Model.Get(ctx, filter)
	switch {
	case err != nil:
		return nil, fmt.Errorf("查询聊天模型失败: %w", err)
	case chatModel == nil:
		return nil, errors.New("当前没有对话模型,请添加")
	}
	return chatModel, nil
}
// getAIModel looks up the model named modelName that was created by userName.
//
// Returns a wrapped error if the query fails, and an error naming the model
// if no matching record exists.
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
	filter := &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Creator: userName},
		ModelName: modelName,
	}

	aiModel, err := dao.Model.Get(ctx, filter)
	switch {
	case err != nil:
		return nil, fmt.Errorf("查询AI模型失败: %w", err)
	case aiModel == nil:
		return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
	}
	return aiModel, nil
}
// callInferenceModel builds the inference request and submits it to the gateway.
//
// The request is assembled by buildInferenceRequest from req, the two models
// and the optional history, then passed to gateway.CreateGatewayTask.
//
// Returns the gateway task ID. Fails with a wrapped error if building or
// submitting fails, and with a dedicated error if the gateway returns an
// empty task ID.
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) {
	taskReq, buildErr := buildInferenceRequest(ctx, req, chatModel, idModel, history)
	if buildErr != nil {
		return "", fmt.Errorf("构建推理请求失败: %w", buildErr)
	}

	taskID, gatewayErr := gateway.CreateGatewayTask(ctx, taskReq)
	switch {
	case gatewayErr != nil:
		return "", fmt.Errorf("创建网关任务失败: %w", gatewayErr)
	case taskID == "":
		return "", errors.New("网关未返回taskId")
	}
	return taskID, nil
}
// createDefaultResult wraps data as a single-round result.
//
// A nil data map is replaced by an empty map, so the one entry in Rounds is
// never nil. TotalRounds is always 1.
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
	round := data
	if round == nil {
		round = map[string]any{}
	}

	result := new(dto.MultiRoundResult)
	result.TotalRounds = 1
	result.Rounds = []map[string]any{round}
	return result
}
// Callback processes a gateway notification about a compose task.
//
// The task is looked up by req.TaskId; a query error is wrapped and returned,
// and a missing task is reported as an error. Processing then depends on
// req.State:
//
//   - failed (3): the record is updated to public.ComposeStatusFailed together
//     with the error message and the gateway payload. If the task has a
//     callback URL the business side is notified even when the update failed
//     (the update error is logged and returned afterwards).
//   - success (2): req.Text is parsed according to the task's BuildType, the
//     record is updated to public.ComposeStatusSuccess, and, only if the update
//     succeeded, the business side is notified.
//   - any other state: nothing is changed and nil is returned.
func Callback(ctx context.Context, req *dto.CallbackReq) error {
	// Gateway states acted upon here.
	const (
		gatewayStateSuccess = 2
		gatewayStateFailed  = 3
	)

	g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
		req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))

	composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
		TaskId: req.TaskId,
	})
	if err != nil {
		return fmt.Errorf("查询任务失败: %w", err)
	}
	if composeTask == nil {
		return fmt.Errorf("任务不存在: %s", req.TaskId)
	}

	switch req.State {
	case gatewayStateFailed:
		_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
			TaskId:       req.TaskId,
			Status:       public.ComposeStatusFailed,
			ErrorMessage: req.ErrorMsg,
			GatewayState: req.State,
			OssFile:      req.OssFile,
			FileType:     req.FileType,
			ResultText:   req.Text,
		})
		if err != nil {
			// Log here as the success path does; the error is still returned below.
			g.Log().Errorf(ctx, "[Callback] 更新失败状态失败 taskId=%s err=%v", req.TaskId, err)
		}
		// Notify the business side with the failed values regardless of the
		// update outcome, so a persistence problem does not hide the failure.
		if composeTask.CallbackUrl != "" {
			gateway.SendCallback(ctx, &entity.ComposeTask{
				TaskId:       req.TaskId,
				Status:       public.ComposeStatusFailed,
				ErrorMessage: req.ErrorMsg,
				CallbackUrl:  composeTask.CallbackUrl,
				Messages:     composeTask.Messages,
			})
		}
		return err

	case gatewayStateSuccess:
		// 1. Parse the result text according to the task's build type.
		var messages any
		switch composeTask.BuildType {
		case public.BuildTypePrompt:
			messages = parsePromptResult(req.Text)
		case public.BuildTypeNode:
			messages = parseNodeResult(req.Text)
		default:
			messages = req.Text
		}

		// 2. Persist the successful result.
		_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
			TaskId:       req.TaskId,
			Status:       public.ComposeStatusSuccess,
			Messages:     messages,
			GatewayState: req.State,
			OssFile:      req.OssFile,
			FileType:     req.FileType,
			ResultText:   req.Text,
		})
		if err != nil {
			g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
			return err
		}

		// 3. Notify the business side.
		if composeTask.CallbackUrl != "" {
			gateway.SendCallback(ctx, &entity.ComposeTask{
				TaskId:      req.TaskId,
				Status:      public.ComposeStatusSuccess,
				Messages:    messages,
				CallbackUrl: composeTask.CallbackUrl,
			})
		}
	}

	return nil
}
// parsePromptResult converts the raw text of a prompt build into rounds.
//
// Resolution order:
//  1. raw is not a JSON object: one round {"raw": raw}.
//  2. The object has no non-empty string "content": one round holding the
//     object itself.
//  3. "content" decodes to a non-empty array of objects: one round per element.
//  4. "content" decodes to a non-empty object: one round holding it.
//  5. Otherwise: one round {"content": <content string>}.
func parsePromptResult(raw string) *dto.MultiRoundResult {
	var envelope map[string]any
	if json.Unmarshal([]byte(raw), &envelope) != nil {
		return createDefaultResult(map[string]any{"raw": raw})
	}

	content, isString := envelope["content"].(string)
	if !isString || content == "" {
		return createDefaultResult(envelope)
	}

	var rounds []map[string]any
	if list := tryParseAsMapArray(content); list != nil {
		// Multi-round form: an array of objects.
		rounds = list
	} else if single := tryParseAsMap(content); single != nil {
		// Single-round form: one object.
		rounds = []map[string]any{single}
	} else {
		return createDefaultResult(map[string]any{"content": content})
	}

	return &dto.MultiRoundResult{
		TotalRounds: len(rounds),
		Rounds:      rounds,
	}
}
// tryParseAsMapArray decodes jsonStr as a JSON array of objects.
//
// It returns the decoded slice only when decoding succeeds and the array has
// at least one element; in every other case (invalid JSON, a non-array value,
// an element that is not an object, an empty array, JSON null) it returns nil.
func tryParseAsMapArray(jsonStr string) []map[string]any {
	var items []map[string]any
	err := json.Unmarshal([]byte(jsonStr), &items)
	if err == nil && len(items) > 0 {
		return items
	}
	return nil
}
// tryParseAsMap decodes jsonStr as a single JSON object.
//
// It returns the decoded map only when decoding succeeds and the object has at
// least one key; in every other case (invalid JSON, a non-object value, an
// empty object, JSON null) it returns nil.
func tryParseAsMap(jsonStr string) map[string]any {
	var fields map[string]any
	err := json.Unmarshal([]byte(jsonStr), &fields)
	if err == nil && len(fields) > 0 {
		return fields
	}
	return nil
}
// parseNodeResult converts the raw text of a node build into a single round.
//
// Resolution order:
//  1. raw is not a JSON object: the round is {"raw": raw}.
//  2. The object has a non-empty string "content" that itself decodes to a
//     JSON object: the round is that inner object.
//  3. Otherwise: the round is the outer object.
//
// The result is built through createDefaultResult, so the round is never a nil
// map: JSON null as raw input yields an empty round, and a "content" of "null"
// leaves the outer object in place instead of replacing it with nil.
func parseNodeResult(raw string) *dto.MultiRoundResult {
	var result map[string]any
	if err := json.Unmarshal([]byte(raw), &result); err != nil {
		return createDefaultResult(map[string]any{"raw": raw})
	}

	if contentStr, ok := result["content"].(string); ok && contentStr != "" {
		var inner map[string]any
		// JSON null decodes without error into a nil map; only adopt the
		// inner value when it really is an object.
		if err := json.Unmarshal([]byte(contentStr), &inner); err == nil && inner != nil {
			result = inner
		}
	}

	return createDefaultResult(result)
}
// GetComposeTask returns the stored state of the compose task taskID.
//
// The response carries the task ID, status, error message and the stored
// messages after they pass through parseMessagesForResponse.
//
// Returns a wrapped error if the query fails, and an error naming the task ID
// if no record exists.
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
	query := &entity.ComposeTask{TaskId: taskID}

	record, err := dao.ComposeTask.Get(ctx, query)
	switch {
	case err != nil:
		return nil, fmt.Errorf("查询任务失败: %w", err)
	case record == nil:
		return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
	}

	res := &dto.GetComposeTaskRes{
		TaskId:       record.TaskId,
		Status:       record.Status,
		ErrorMessage: record.ErrorMessage,
		Messages:     parseMessagesForResponse(record.Messages),
	}
	return res, nil
}
// parseMessagesForResponse prepares stored messages for an API response.
//
// When messages is a non-empty string holding valid JSON, the decoded value is
// returned (so the string "null" yields nil). In every other case — a
// non-string value, an empty string, or a string that is not valid JSON —
// messages is returned unchanged.
func parseMessagesForResponse(messages any) any {
	text, isString := messages.(string)
	if isString && text != "" {
		var decoded any
		if json.Unmarshal([]byte(text), &decoded) == nil {
			return decoded
		}
	}
	return messages
}