Compare commits

...

2 Commits

View File

@@ -25,8 +25,17 @@ type modelService struct{}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
// 获取当前会话模型
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 获取当前用户会话模型
var model *entity.AsynchModel
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
IsChatModel: new(1),
})
if err != nil {
@@ -88,8 +97,15 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
//根据当前 isChatModel 来判断是否更新模型
if req.IsChatModel == gconv.PtrInt(1) {
//判断当前用户是否有会话模型
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 获取当前用户会话模型
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
IsChatModel: new(1),
})
if err != nil {
@@ -298,8 +314,16 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
if newModel == nil {
return errors.New("新会话模型不存在")
}
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 获取当前用户会话模型
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
IsChatModel: new(1),
})
if err != nil {
@@ -307,7 +331,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
if !g.IsEmpty(currentModel) {
if currentModel.ModelType != 1 {
if currentModel.ModelType != public.ModelTypeInference {
return errors.New("当前模型为非推理模型,不能设置为会话模型")
}
@@ -325,7 +349,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
// 设置当前为会话模型设为1
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1),
})
return err