Files
model-gateway/service/model/model_service.go
2026-06-10 16:24:29 +08:00

257 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package model
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// modelService groups the model-management operations of this package.
// It carries no state.
type modelService struct{}

// Model is the shared package-level modelService instance.
var Model = new(modelService)
// Create inserts a new model record built from req and returns its ID.
//
// req.IsOwner is overwritten before the insert: 0 when gateway.IsSuperAdmin
// reports true, 1 otherwise. The error returned by IsSuperAdmin is ignored, so
// a failed check results in the caller being treated as a non-admin.
//
// When req.IsChatModel is 1 the caller's previous chat model (if any) is
// cleared and the new record is inserted inside a single transaction, so a
// failed insert no longer leaves the user without any chat model. This relies
// on the dao layer honouring the transaction carried by ctx, the same way
// UpdateChatModel does.
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) {
	// 1. Decide IsOwner from the super-admin check.
	req.IsOwner = gconv.PtrInt(1)
	if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
		req.IsOwner = gconv.PtrInt(0)
	}

	// 2. Not a chat model: a plain insert is all that is needed.
	if req.IsChatModel == nil || *req.IsChatModel != 1 {
		id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
		if err != nil {
			return nil, err
		}
		return &dto.CreateModelRes{ID: id}, nil
	}

	// 3. Chat model: unset the old one and insert the new one atomically.
	var res *dto.CreateModelRes
	err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		if err := s.clearUserChatModel(ctx); err != nil {
			return err
		}
		id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
		if err != nil {
			return err
		}
		res = &dto.CreateModelRes{ID: id}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return res, nil
}
// Update modifies a model's configuration.
//
// Flow:
//  1. If req.IsChatModel is 1, checkChatModelUnique must pass first.
//  2. Super admins (gateway.IsSuperAdmin reports true) update the record
//     directly with IsOwner set to 0. The error from IsSuperAdmin is ignored,
//     so a failed check is treated as "not an admin".
//  3. Other callers get IsOwner set to 1. The target is looked up across
//     tenants; if it belongs to tenant 1 the request is inserted as a new
//     record instead of modifying the existing one, otherwise the record is
//     updated in place.
//
// An error is returned when the target model cannot be found, instead of
// dereferencing a nil lookup result.
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
	// 1. Chat-model uniqueness check.
	// NOTE(review): the check does not exclude the model being updated, so
	// re-saving the current chat model with IsChatModel=1 is rejected — confirm
	// this is intended.
	if req.IsChatModel != nil && *req.IsChatModel == 1 {
		if err := s.checkChatModelUnique(ctx); err != nil {
			return err
		}
	}

	// 2. Super admin: update directly.
	req.IsOwner = gconv.PtrInt(1)
	if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
		req.IsOwner = gconv.PtrInt(0)
		_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
		return err
	}

	// 3. Regular user: look the target up across tenants.
	existing, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
	})
	if err != nil {
		return err
	}
	// The lookup can yield a nil model without an error (UpdateChatModel guards
	// against the same case); fail cleanly rather than panic below.
	if existing == nil {
		return errors.New("模型不存在")
	}

	// Records of tenant 1 are not modified in place; a new record is inserted.
	if existing.TenantId == 1 {
		_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
		return err
	}

	_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
	return err
}
// Delete removes the model identified by req.ID through the dao layer and
// returns whatever error the dao reports.
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
	target := &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
	}
	if _, err := dao.Model.Delete(ctx, target); err != nil {
		return err
	}
	return nil
}
// Get returns the details of a single model.
//
// The lookup filters on req.ID, req.ModelName and req.IsChatModel, and is
// always restricted to records created by the current user. The response wraps
// whatever the dao returns, including a nil model.
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
	current, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}

	// NOTE(review): req.Creator is filled in when no ID is given, but the query
	// below uses the current user's name regardless — confirm which was meant.
	if g.IsEmpty(req.ID) {
		req.Creator = current.UserName
	}

	base := beans.SQLBaseDO{
		Id:      req.ID,
		Creator: current.UserName,
	}
	found, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO:   base,
		ModelName:   req.ModelName,
		IsChatModel: req.IsChatModel,
	})
	if err != nil {
		return nil, err
	}

	res := &dto.GetModelRes{Model: found}
	return res, nil
}
// List returns the models visible to the current user together with the
// total count.
//
// req is mutated before querying: IsOwner becomes 0 when gateway.IsSuperAdmin
// reports true and 1 otherwise (the error from that check is ignored), and
// Creator is set to the current user's name.
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) {
	// 1. Owner flag derived from the super-admin check.
	ownerFlag := 1
	if admin, _ := gateway.IsSuperAdmin(ctx); admin {
		ownerFlag = 0
	}
	req.IsOwner = gconv.PtrInt(ownerFlag)

	// 2. Scope the query to the current user.
	current, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	req.Creator = current.UserName

	// 3. Run the query.
	items, count, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
	if err != nil {
		return nil, err
	}

	res := &dto.ListModelRes{
		List:  items,
		Total: count,
	}
	return res, nil
}
// UpdateChatModel makes the model identified by req.Id the current user's
// chat model.
//
// Steps:
//  1. The target model must exist (looked up across tenants) and must be an
//     inference model (public.ModelTypeInference). The type check is applied
//     to the model being promoted, not to the chat model being replaced.
//  2. The user's current chat model, if any, is fetched.
//  3. Inside one transaction the previous chat model is unset (when it is a
//     different record) and the target model is flagged as the chat model.
//
// A lookup failure for the target model is reported as "model does not
// exist", whatever the underlying error was.
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
	// 1. Validate the new model.
	newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
		SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
	})
	if err != nil || newModel == nil {
		return errors.New("新会话模型不存在")
	}
	if newModel.ModelType != public.ModelTypeInference {
		return errors.New("当前模型为非推理模型,不能设置为会话模型")
	}

	// 2. Fetch the current user's existing chat model.
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return err
	}
	currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO:   beans.SQLBaseDO{Creator: user.UserName},
		IsChatModel: gconv.PtrInt(1),
	})
	if err != nil {
		return err
	}

	// 3. Swap the flag atomically: unset the old one, set the new one.
	return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		if !g.IsEmpty(currentModel) && currentModel.Id != req.Id {
			if _, err := dao.Model.Update(ctx, &entity.AsynchModel{
				SQLBaseDO:   beans.SQLBaseDO{Id: currentModel.Id},
				IsChatModel: gconv.PtrInt(0),
			}); err != nil {
				return err
			}
		}
		_, err := dao.Model.Update(ctx, &entity.AsynchModel{
			SQLBaseDO:   beans.SQLBaseDO{Id: req.Id},
			IsChatModel: gconv.PtrInt(1),
		})
		return err
	})
}
// GetIsChatModel returns the current user's chat model.
//
// When the user has no chat model the result is (nil, nil); callers must be
// prepared for a nil response without an error.
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
	current, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}

	filter := &entity.AsynchModel{
		SQLBaseDO:   beans.SQLBaseDO{Creator: current.UserName},
		IsChatModel: gconv.PtrInt(1),
	}
	chat, err := dao.Model.Get(ctx, filter)
	switch {
	case err != nil:
		return nil, err
	case chat == nil:
		return nil, nil
	}

	return &dto.GetIsChatModelRes{Model: chat}, nil
}
// ==================== 辅助方法 ====================
// clearUserChatModel unsets the chat-model flag on the current user's
// existing chat model.
//
// It is a no-op when the user has no chat model. A failure while looking the
// model up is returned to the caller rather than being treated as "nothing to
// clear", so a caller does not go on to create a second chat model after a
// failed lookup.
func (s *modelService) clearUserChatModel(ctx context.Context) error {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return err
	}

	model, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO:   beans.SQLBaseDO{Creator: user.UserName},
		IsChatModel: gconv.PtrInt(1),
	})
	if err != nil {
		return err
	}
	if model == nil {
		// No chat model set for this user: nothing to clear.
		return nil
	}

	_, err = dao.Model.Update(ctx, &entity.AsynchModel{
		SQLBaseDO:   beans.SQLBaseDO{Id: model.Id},
		IsChatModel: gconv.PtrInt(0),
	})
	return err
}
// checkChatModelUnique reports an error when the current user already has a
// model flagged as chat model, and nil when there is none. Lookup failures are
// returned unchanged.
func (s *modelService) checkChatModelUnique(ctx context.Context) error {
	current, err := utils.GetUserInfo(ctx)
	if err != nil {
		return err
	}

	filter := &entity.AsynchModel{
		SQLBaseDO:   beans.SQLBaseDO{Creator: current.UserName},
		IsChatModel: gconv.PtrInt(1),
	}
	existing, err := dao.Model.Get(ctx, filter)
	switch {
	case err != nil:
		return err
	case existing != nil:
		return errors.New("用户已存在会话模型")
	default:
		return nil
	}
}
// GetModelTypesFromConfig returns the known model types as a code-to-name map.
//
// Despite the name, nothing is read from a configuration file here: the values
// come from public.ModelTypeName. A fresh map is returned so callers cannot
// modify the shared table. The returned error is always nil.
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
	source := public.ModelTypeName
	snapshot := make(map[int]string, len(source))
	for code := range source {
		snapshot[code] = source[code]
	}
	res = &dto.TypeItem{Type: snapshot}
	return res, nil
}
// GetOperatorList returns the operator list held in public.OperatorList.
// The list is passed through as-is (not copied); the returned error is always
// nil.
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
	res = &dto.ListOperatorRes{List: public.OperatorList}
	return res, nil
}