Files
model-gateway/service/model_service.go

379 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// modelService groups the model-management operations of this package.
// It carries no state of its own.
type modelService struct{}

// Model is the shared modelService instance that callers use.
var Model = new(modelService)
// Create inserts a new model and returns its ID.
//
// req.IsOwner is overwritten before the insert: 0 when the caller is a super
// admin (gateway.IsSuperAdmin), 1 otherwise.
//
// When *req.IsChatModel == 1, the caller's existing chat model (looked up by
// Creator == current user name and IsChatModel == 1) is first demoted to
// IsChatModel = 0, so that the new row becomes the user's only chat model.
// The demotion and the insert run inside one gfdb transaction so a failed
// insert does not leave the user without a chat model.
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
	// Resolve the caller's role before any write, so an error here cannot
	// leave a half-applied change (e.g. an already-demoted chat model) behind.
	req.IsOwner = gconv.PtrInt(1)
	admin, err := gateway.IsSuperAdmin(ctx)
	if err != nil {
		return nil, err
	}
	if admin {
		req.IsOwner = gconv.PtrInt(0)
	}

	wantChat := !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1
	var user *beans.User
	if wantChat {
		user, err = utils.GetUserInfo(ctx)
		if err != nil {
			return nil, err
		}
	}

	record := &entity.AsynchModel{
		ModelName:            req.ModelName,
		ModelType:            req.ModelType,
		BaseURL:              req.BaseURL,
		HttpMethod:           req.HttpMethod,
		HeadMsg:              req.HeadMsg,
		Form:                 req.Form,
		RequestMapping:       req.RequestMapping,
		ResponseMapping:      req.ResponseMapping,
		ResponseBody:         req.ResponseBody,
		ResponseTokenField:   req.ResponseTokenField,
		IsPrivate:            req.IsPrivate,
		IsChatModel:          req.IsChatModel,
		ApiKey:               req.ApiKey,
		Enabled:              req.Enabled,
		MaxConcurrency:       req.MaxConcurrency,
		QueueLimit:           req.QueueLimit,
		TimeoutSeconds:       req.TimeoutSeconds,
		ExpectedSeconds:      req.ExpectedSeconds,
		RetryTimes:           req.RetryTimes,
		RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
		AutoCleanSeconds:     req.AutoCleanSeconds,
		Remark:               req.Remark,
		IsOwner:              req.IsOwner,
		OperatorName:         req.OperatorName,
		TokenConfig:          req.TokenConfig,
	}

	// The callback's ctx carries the transaction; dao calls must use it.
	err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		if wantChat {
			// Demote the user's current chat model, if any.
			current, err := dao.Model.Get(ctx, &entity.AsynchModel{
				SQLBaseDO: beans.SQLBaseDO{
					Creator: user.UserName,
				},
				IsChatModel: gconv.PtrInt(1),
			})
			if err != nil {
				return err
			}
			if current != nil {
				if _, err := dao.Model.Update(ctx, &entity.AsynchModel{
					SQLBaseDO:   beans.SQLBaseDO{Id: current.Id},
					IsChatModel: gconv.PtrInt(0),
				}); err != nil {
					return err
				}
			}
		}
		id, err := dao.Model.Insert(ctx, record)
		if err != nil {
			return err
		}
		res = &dto.CreateModelRes{ID: id}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return res, nil
}
// Update applies req to the model whose primary key is req.ID.
//
// req.IsOwner is overwritten: 0 for super admins, 1 otherwise.
//   - Super admins update the row in place.
//   - Other callers first load the row across tenants. If it belongs to the
//     super admin (TenantId == 1), the shared row is left untouched and a new
//     row with the requested values is inserted instead; otherwise the row is
//     updated in place. A missing row is reported as an error.
//
// When *req.IsChatModel == 1 and the caller already has a chat model other
// than req.ID, the update is rejected.
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
	// Compare the pointed-to value: a freshly allocated pointer such as
	// gconv.PtrInt(1) is never == req.IsChatModel.
	if req.IsChatModel != nil && *req.IsChatModel == 1 {
		user, err := utils.GetUserInfo(ctx)
		if err != nil {
			return err
		}
		// Look up the current user's chat model.
		current, err := dao.Model.Get(ctx, &entity.AsynchModel{
			SQLBaseDO: beans.SQLBaseDO{
				Creator: user.UserName,
			},
			IsChatModel: gconv.PtrInt(1),
		})
		if err != nil {
			return err
		}
		// Re-saving the model that already is the chat model is fine.
		if current != nil && current.Id != req.ID {
			return errors.New("用户已存在会话模型,不能创建")
		}
	}

	req.IsOwner = gconv.PtrInt(1)
	admin, err := gateway.IsSuperAdmin(ctx)
	if err != nil {
		return err
	}
	if admin {
		req.IsOwner = gconv.PtrInt(0)
		fields := asynchModelFromUpdateReq(req)
		fields.SQLBaseDO.Id = req.ID
		_, err = dao.Model.Update(ctx, fields)
		return err
	}

	// A non-admin editing a super-admin model (TenantId == 1) gets its own
	// copy inserted; any other model is updated in place.
	existing, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
	})
	if err != nil {
		return err
	}
	if existing == nil {
		return errors.New("模型不存在")
	}
	fields := asynchModelFromUpdateReq(req)
	if existing.TenantId == 1 {
		_, err = dao.Model.Insert(ctx, fields)
		return err
	}
	fields.SQLBaseDO.Id = req.ID
	_, err = dao.Model.Update(ctx, fields)
	return err
}

// asynchModelFromUpdateReq copies the editable fields of req into a new
// entity. The primary key is left unset so the result can be passed to
// Insert directly, or to Update after setting SQLBaseDO.Id.
func asynchModelFromUpdateReq(req *dto.UpdateModelReq) *entity.AsynchModel {
	return &entity.AsynchModel{
		ModelName:            req.ModelName,
		ModelType:            req.ModelType,
		BaseURL:              req.BaseURL,
		HttpMethod:           req.HttpMethod,
		HeadMsg:              req.HeadMsg,
		Form:                 req.Form,
		RequestMapping:       req.RequestMapping,
		ResponseMapping:      req.ResponseMapping,
		ResponseBody:         req.ResponseBody,
		ResponseTokenField:   req.ResponseTokenField,
		IsPrivate:            req.IsPrivate,
		IsChatModel:          req.IsChatModel,
		ApiKey:               req.ApiKey,
		Enabled:              req.Enabled,
		MaxConcurrency:       req.MaxConcurrency,
		QueueLimit:           req.QueueLimit,
		TimeoutSeconds:       req.TimeoutSeconds,
		ExpectedSeconds:      req.ExpectedSeconds,
		RetryTimes:           req.RetryTimes,
		RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
		AutoCleanSeconds:     req.AutoCleanSeconds,
		Remark:               req.Remark,
		IsOwner:              req.IsOwner,
		OperatorName:         req.OperatorName,
		TokenConfig:          req.TokenConfig,
	}
}
// Delete removes the model whose primary key is req.ID via dao.Model.Delete.
// The DAO's first result is discarded; only its error is returned.
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
	target := &entity.AsynchModel{SQLBaseDO: beans.SQLBaseDO{Id: req.ID}}
	if _, err := dao.Model.Delete(ctx, target); err != nil {
		return err
	}
	return nil
}
// Get returns the model whose primary key is req.ID.
//
// The JSON columns (Form, RequestMapping, ResponseMapping, ResponseBody and
// TokenConfig) are passed through util.ParseJSONField before the model is
// returned. An error is returned when dao.Model.Get finds no matching row.
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
	model, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
	})
	if err != nil {
		return nil, err
	}
	// dao.Model.Get can report "no row" as a nil model with a nil error;
	// the field accesses below would panic on it.
	if model == nil {
		return nil, errors.New("模型不存在")
	}
	model.Form = util.ParseJSONField(model.Form)
	model.RequestMapping = util.ParseJSONField(model.RequestMapping)
	model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
	model.ResponseBody = util.ParseJSONField(model.ResponseBody)
	model.TokenConfig = util.ParseJSONField(model.TokenConfig)
	return &dto.GetModelRes{
		Model: model,
	}, nil
}
// List returns the models matched by dao.Model.GetByCreatorAndPlatform along
// with the total count that call reports.
//
// Two request fields are overwritten before the query: IsOwner (0 for super
// admins, otherwise 1) and Creator (the caller's user name). The JSON columns
// of every returned row are passed through util.ParseJSONField.
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
	req.IsOwner = gconv.PtrInt(1)
	isAdmin, err := gateway.IsSuperAdmin(ctx)
	if err != nil {
		return nil, err
	}
	if isAdmin {
		req.IsOwner = gconv.PtrInt(0)
	}

	caller, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	req.Creator = caller.UserName

	var rows []*entity.AsynchModel
	rows, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
	if err != nil {
		return nil, err
	}

	// Decode the JSONB-backed columns of each row in place.
	for i := range rows {
		row := rows[i]
		row.Form = util.ParseJSONField(row.Form)
		row.RequestMapping = util.ParseJSONField(row.RequestMapping)
		row.ResponseMapping = util.ParseJSONField(row.ResponseMapping)
		row.ResponseBody = util.ParseJSONField(row.ResponseBody)
		row.TokenConfig = util.ParseJSONField(row.TokenConfig)
	}

	res = &dto.ListModelRes{List: rows, Total: total}
	return res, nil
}
// GetModelTypesFromConfig returns the known model types wrapped in a
// dto.TypeItem.
//
// Despite its name, it does not read a configuration file: the entries are
// copied from the package-level table public.ModelTypeName. A fresh map is
// built on every call so callers can modify the result without touching that
// shared table. The returned error is always nil.
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
	snapshot := make(map[int]string, len(public.ModelTypeName))
	for code, label := range public.ModelTypeName {
		snapshot[code] = label
	}
	res = &dto.TypeItem{Type: snapshot}
	return res, nil
}
// UpdateChatModel makes the model req.Id the current user's chat model.
//
// The target is looked up across tenants; it must exist and its ModelType
// must equal public.ModelTypeInference. The user's current chat model
// (Creator == user name, IsChatModel == 1), if it is a different model, is
// demoted to IsChatModel = 0, and the target is then set to IsChatModel = 1.
// Both writes run inside one gfdb transaction. Selecting the model that
// already is the chat model keeps it as the chat model; it is not toggled off.
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
	// The model being promoted must exist.
	newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
	})
	if err != nil {
		return err
	}
	if newModel == nil {
		return errors.New("新会话模型不存在")
	}
	// Validate the model being promoted, not the one being replaced: only
	// inference models may become the chat model.
	if newModel.ModelType != public.ModelTypeInference {
		return errors.New("当前模型为非推理模型,不能设置为会话模型")
	}

	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return err
	}
	// Look up the user's current chat model, if any.
	currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{
			Creator: user.UserName,
		},
		IsChatModel: gconv.PtrInt(1),
	})
	if err != nil {
		return err
	}

	// The callback's ctx carries the transaction; dao calls must use it.
	return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		// Demote the previous chat model unless it is the one being selected.
		if currentModel != nil && currentModel.Id != req.Id {
			if _, err := dao.Model.Update(ctx, &entity.AsynchModel{
				SQLBaseDO:   beans.SQLBaseDO{Id: currentModel.Id},
				IsChatModel: gconv.PtrInt(0),
			}); err != nil {
				return err
			}
		}
		// Promote the selected model.
		_, err := dao.Model.Update(ctx, &entity.AsynchModel{
			SQLBaseDO:   beans.SQLBaseDO{Id: req.Id},
			IsChatModel: gconv.PtrInt(1),
		})
		return err
	})
}
// GetIsChatModel returns the current user's chat model (IsChatModel == 1 and
// Creator == the caller's user name), with its JSON columns passed through
// util.ParseJSONField. It returns (nil, nil) when the user has no chat model.
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	// Chat models are tracked per creator (Create and UpdateChatModel both
	// filter by Creator); without this filter any user's chat model matches.
	model, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{
			Creator: user.UserName,
		},
		IsChatModel: gconv.PtrInt(1),
	})
	if err != nil {
		return nil, err
	}
	if model == nil {
		return nil, nil
	}
	model.Form = util.ParseJSONField(model.Form)
	model.RequestMapping = util.ParseJSONField(model.RequestMapping)
	model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
	model.ResponseBody = util.ParseJSONField(model.ResponseBody)
	model.TokenConfig = util.ParseJSONField(model.TokenConfig)
	return &dto.GetIsChatModelRes{
		Model: model,
	}, nil
}