255 lines
6.7 KiB
Go
255 lines
6.7 KiB
Go
package model
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"model-gateway/common/util"
|
||
"model-gateway/consts/public"
|
||
"model-gateway/dao"
|
||
"model-gateway/model/dto"
|
||
"model-gateway/model/entity"
|
||
"model-gateway/service/gateway"
|
||
|
||
"gitea.com/red-future/common/beans"
|
||
"gitea.com/red-future/common/db/gfdb"
|
||
"gitea.com/red-future/common/utils"
|
||
"github.com/gogf/gf/v2/database/gdb"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
// Model is the package-level modelService instance.
var Model = &modelService{}
|
||
|
||
// modelService carries the model operations defined in this file; it has no fields.
type modelService struct{}
|
||
|
||
// Create 创建模型
|
||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) {
|
||
// 1)如果设为会话模型,先把该用户旧会话模型取消
|
||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||
if err := s.clearUserChatModel(ctx); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
// 2)判断是否超管,决定 isOwner
|
||
req.IsOwner = gconv.PtrInt(1)
|
||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||
req.IsOwner = gconv.PtrInt(0)
|
||
}
|
||
|
||
// 3)入库
|
||
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &dto.CreateModelRes{ID: id}, nil
|
||
}
|
||
|
||
// Update 更新模型配置
|
||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||
// 1)会话模型唯一性校验
|
||
if req.IsChatModel != nil && *req.IsChatModel == 1 {
|
||
if err := s.checkChatModelUnique(ctx); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
// 2)超管创建/普通用户更新
|
||
req.IsOwner = gconv.PtrInt(1)
|
||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||
req.IsOwner = gconv.PtrInt(0)
|
||
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||
return err
|
||
}
|
||
// 3)跨租户判断:超管的模型不允许直接修改,走插入新记录
|
||
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if model.TenantId == 1 {
|
||
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||
return err
|
||
}
|
||
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||
return err
|
||
}
|
||
|
||
// Delete 删除模型
|
||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||
})
|
||
return err
|
||
}
|
||
|
||
// Get 获取模型详情
|
||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||
user, err := utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if g.IsEmpty(req.ID) {
|
||
req.Creator = user.UserName
|
||
}
|
||
modelReq := new(entity.AsynchModel)
|
||
err = gconv.Struct(req, modelReq)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
model, err := dao.Model.Get(ctx, modelReq)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &dto.GetModelRes{
|
||
Model: model,
|
||
}, nil
|
||
}
|
||
|
||
// List 获取模型列表
|
||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) {
|
||
// 1)判断超管
|
||
req.IsOwner = gconv.PtrInt(1)
|
||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||
req.IsOwner = gconv.PtrInt(0)
|
||
}
|
||
|
||
// 2)获取当前用户
|
||
user, err := utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Creator = user.UserName
|
||
|
||
// 3)查询
|
||
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &dto.ListModelRes{List: models, Total: total}, nil
|
||
}
|
||
|
||
// UpdateChatModel 设置会话模型
|
||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||
// 1)校验新模型存在
|
||
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||
})
|
||
if err != nil || newModel == nil {
|
||
return errors.New("新会话模型不存在")
|
||
}
|
||
|
||
// 2)获取当前用户的会话模型
|
||
user, err := utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||
IsChatModel: new(1),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 3)事务:取消旧的 + 设置新的
|
||
return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||
if !g.IsEmpty(currentModel) {
|
||
if currentModel.ModelType != public.ModelTypeInference {
|
||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||
}
|
||
if currentModel.Id != req.Id {
|
||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||
IsChatModel: gconv.PtrInt(0),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
|
||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||
IsChatModel: gconv.PtrInt(1),
|
||
})
|
||
return err
|
||
})
|
||
}
|
||
|
||
// GetIsChatModel 获取当前用户会话模型
|
||
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||
user, err := utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||
IsChatModel: new(1),
|
||
})
|
||
if err != nil || model == nil {
|
||
return nil, err
|
||
}
|
||
return &dto.GetIsChatModelRes{Model: model}, nil
|
||
}
|
||
|
||
// ==================== Helper methods ====================
|
||
|
||
// clearUserChatModel 清除当前用户旧会话模型
|
||
func (s *modelService) clearUserChatModel(ctx context.Context) error {
|
||
user, err := utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||
IsChatModel: new(1),
|
||
})
|
||
if err != nil || model == nil {
|
||
return nil
|
||
}
|
||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||
IsChatModel: gconv.PtrInt(0),
|
||
})
|
||
return err
|
||
}
|
||
|
||
// checkChatModelUnique 校验用户是否已有会话模型
|
||
func (s *modelService) checkChatModelUnique(ctx context.Context) error {
|
||
user, err := utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||
IsChatModel: new(1),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if model != nil {
|
||
return errors.New("用户已存在会话模型")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||
// 返回副本,避免外部修改
|
||
types := make(map[int]string, len(public.ModelTypeName))
|
||
for k, v := range public.ModelTypeName {
|
||
types[k] = v
|
||
}
|
||
return &dto.TypeItem{
|
||
Type: types,
|
||
}, nil
|
||
}
|
||
|
||
// GetOperatorList 获取运营商列表
|
||
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
|
||
return &dto.ListOperatorRes{
|
||
List: public.OperatorList,
|
||
}, nil
|
||
}
|