256 lines
6.3 KiB
Go
256 lines
6.3 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"model-gateway/dao"
|
||
"model-gateway/model/dto"
|
||
"model-gateway/model/entity"
|
||
|
||
"gitea.com/red-future/common/beans"
|
||
"gitea.com/red-future/common/db/gfdb"
|
||
"gitea.com/red-future/common/http"
|
||
"gitea.com/red-future/common/utils"
|
||
"github.com/gogf/gf/v2/database/gdb"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
// modelService groups the model-management operations of this package. It
// carries no state of its own; every method reaches its dependencies through
// package-level values such as dao.Model.
type modelService struct{}

// Model is the shared modelService instance exposed to callers.
var Model = new(modelService)
|
||
|
||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||
headers := forwardHeaders(ctx)
|
||
var r = make(map[string]bool)
|
||
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||
return false, err
|
||
}
|
||
return r["isSuperAdmin"], err
|
||
}
|
||
|
||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||
// 获取当前会话模型
|
||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||
var model *entity.AsynchModel
|
||
model, err = dao.Model.GetByIsChatModel(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 如果有会话模型,那就改变为 0
|
||
if model != nil {
|
||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||
ID: model.Id,
|
||
IsChatModel: gconv.PtrInt(0),
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
}
|
||
|
||
req.IsOwner = gconv.PtrInt(1)
|
||
admin, err := s.IsSuperAdmin(ctx)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if admin {
|
||
req.IsOwner = gconv.PtrInt(0)
|
||
}
|
||
id, err := dao.Model.Insert(ctx, req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &dto.CreateModelRes{ID: id}, nil
|
||
}
|
||
|
||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||
//根据当前 isChatModel 来判断是否更新模型
|
||
if req.IsChatModel == gconv.PtrInt(1) {
|
||
//判断当前用户是否有会话模型
|
||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if model != nil {
|
||
return errors.New("用户已存在会话模型,不能创建")
|
||
}
|
||
}
|
||
|
||
req.IsOwner = gconv.PtrInt(1)
|
||
admin, err := s.IsSuperAdmin(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if admin {
|
||
req.IsOwner = gconv.PtrInt(0)
|
||
_, err = dao.Model.Update(ctx, req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
var user *beans.User
|
||
user, err = utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
||
var count int
|
||
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
|
||
ID: req.ID,
|
||
Creator: user.UserName,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if count == 0 {
|
||
insertDto := new(dto.CreateModelReq)
|
||
err = gconv.Struct(req, insertDto)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
_, err = dao.Model.Insert(ctx, insertDto)
|
||
return err
|
||
}
|
||
_, err = dao.Model.Update(ctx, req)
|
||
return err
|
||
}
|
||
|
||
func (s *modelService) Delete(ctx context.Context, id string) error {
|
||
_, err := dao.Model.DeleteByID(ctx, id)
|
||
return err
|
||
}
|
||
|
||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
||
model, err := dao.Model.Get(ctx, id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
model.Form = ParseJSONField(model.Form)
|
||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||
return model, nil
|
||
}
|
||
|
||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||
var models []*entity.AsynchModel
|
||
|
||
req.IsOwner = gconv.PtrInt(1)
|
||
admin, err := s.IsSuperAdmin(ctx)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if admin {
|
||
req.IsOwner = gconv.PtrInt(0)
|
||
}
|
||
|
||
var user *beans.User
|
||
user, err = utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
req.Creator = user.UserName
|
||
|
||
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 处理列表中每条记录的 JSONB 字段
|
||
for _, m := range models {
|
||
m.Form = ParseJSONField(m.Form)
|
||
m.RequestMapping = ParseJSONField(m.RequestMapping)
|
||
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
|
||
m.ResponseBody = ParseJSONField(m.ResponseBody)
|
||
}
|
||
return models, total, nil
|
||
}
|
||
|
||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
|
||
typeMap := make(map[int]string)
|
||
|
||
// 读取配置
|
||
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
|
||
for k, v := range configMap {
|
||
typeID := gconv.Int(k)
|
||
typeName := gconv.String(v)
|
||
if typeID > 0 && typeName != "" {
|
||
typeMap[typeID] = typeName
|
||
}
|
||
}
|
||
// 如果配置为空,使用默认值
|
||
if len(typeMap) == 0 {
|
||
typeMap = map[int]string{
|
||
1: "推理模型",
|
||
2: "图片模型",
|
||
3: "音频模型",
|
||
4: "向量化模型",
|
||
5: "全模态模型",
|
||
}
|
||
}
|
||
return typeMap
|
||
}
|
||
|
||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||
// 校验新会话模型是否存在
|
||
newModel, err := dao.Model.Get(ctx, req.Id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if newModel == nil {
|
||
return errors.New("新会话模型不存在")
|
||
}
|
||
|
||
// 获取当前用户会话模型
|
||
currentModel, err := dao.Model.GetByIsChatModel(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||
if !g.IsEmpty(currentModel) {
|
||
if currentModel.ModelType != 1 {
|
||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||
}
|
||
|
||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
||
if currentModel.Id != req.Id {
|
||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||
ID: currentModel.Id,
|
||
IsChatModel: gconv.PtrInt(0),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
|
||
// 设置当前为会话模型(设为1)
|
||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||
ID: req.Id,
|
||
IsChatModel: gconv.PtrInt(1),
|
||
})
|
||
return err
|
||
})
|
||
return err
|
||
}
|
||
|
||
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
|
||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if model == nil {
|
||
return nil, nil
|
||
}
|
||
model.Form = ParseJSONField(model.Form)
|
||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||
return model, nil
|
||
}
|