404 lines
12 KiB
Go
404 lines
12 KiB
Go
package asr
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"media/service/setup"
|
||
"net/http"
|
||
"os"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"runtime"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
)
|
||
|
||
// WhisperBackend identifies which whisper implementation performs a transcription.
type WhisperBackend int

// Known backends. The zero value is backendPython.
const (
	backendPython WhisperBackend = 0 // `python -m whisper`
	backendCLI    WhisperBackend = 1 // the openai-whisper `whisper` command
	backendCpp    WhisperBackend = 2 // whisper.cpp (whisper-cpp / whisper-cli binaries)
)
|
||
|
||
// whisperService groups the speech-recognition operations; it carries no state.
type whisperService struct{}

// Whisper is the package-wide speech-recognition service instance.
var Whisper = &whisperService{}
|
||
|
||
// TranscribeReq is the input to whisperService.Transcribe.
type TranscribeReq struct {
	AudioPath string // path of the audio file to transcribe
	// Model names the whisper model (tiny/base/small/medium/large). When empty,
	// Transcribe uses the whisper.model config value, defaulting to "small".
	Model string
	// Language is the language code to recognize. When empty, Transcribe uses
	// the whisper.language config value, defaulting to "zh" (Chinese).
	Language string
}

// TranscribeRes is the output of whisperService.Transcribe.
type TranscribeRes struct {
	Text string // full recognized text
	// Segments holds timestamped pieces of the transcript.
	// NOTE(review): none of the backends in this file fill it in yet.
	Segments []Segment
	// Model is the model the backend ran with; for whisper.cpp this is the
	// resolved model file path rather than a bare name.
	Model string
	// Language is the language code recognition ran with (may be empty when
	// the backend does not set it).
	Language   string
	OutputPath string // path of the .txt transcript file that was read
}

// Segment is one timestamped piece of a transcript.
type Segment struct {
	Start float64 `json:"start"` // start time in seconds
	End   float64 `json:"end"`   // end time in seconds
	Text  string  `json:"text"`  // text recognized in this interval
}
|
||
|
||
// Transcribe 对音频文件进行语音识别(自动检测后端,自动降级)
|
||
func (s *whisperService) Transcribe(ctx context.Context, req *TranscribeReq) (res *TranscribeRes, err error) {
|
||
// 1. 校验音频文件
|
||
if _, err = os.Stat(req.AudioPath); os.IsNotExist(err) {
|
||
return nil, fmt.Errorf("音频文件不存在: %s", req.AudioPath)
|
||
}
|
||
|
||
// 2. 设置默认值
|
||
model := req.Model
|
||
if model == "" {
|
||
model = g.Cfg().MustGet(ctx, "whisper.model", "small").String()
|
||
}
|
||
language := req.Language
|
||
if language == "" {
|
||
language = g.Cfg().MustGet(ctx, "whisper.language", "zh").String()
|
||
}
|
||
|
||
// 3. 检测后端,C++ 版找不到模型文件时自动降级
|
||
backend, whisperPath := s.detectBackend()
|
||
if backend == backendCpp {
|
||
modelPath := s.resolveCppModelPath(model)
|
||
if modelPath == "" {
|
||
g.Log().Warningf(ctx, "whisper.cpp 模型文件(%s)未找到,降级到 Python whisper", model)
|
||
backend = backendPython
|
||
} else {
|
||
g.Log().Infof(ctx, "语音识别(whisper.cpp): audio=%s, model=%s", req.AudioPath, modelPath)
|
||
return s.transcribeWithCpp(ctx, req, whisperPath, modelPath, language)
|
||
}
|
||
}
|
||
|
||
switch backend {
|
||
case backendCLI:
|
||
g.Log().Infof(ctx, "语音识别(CLI): audio=%s, model=%s, language=%s", req.AudioPath, model, language)
|
||
return s.transcribeWithCLI(ctx, req, whisperPath, model, language)
|
||
default:
|
||
g.Log().Infof(ctx, "语音识别(python): audio=%s, model=%s, language=%s", req.AudioPath, model, language)
|
||
return s.transcribeWithPython(ctx, req, model, language)
|
||
}
|
||
}
|
||
|
||
// transcribeWithCLI 使用 whisper CLI 命令
|
||
func (s *whisperService) transcribeWithCLI(ctx context.Context, req *TranscribeReq, whisperPath, model, language string) (res *TranscribeRes, err error) {
|
||
outputDir := filepath.Dir(req.AudioPath)
|
||
modelDir := g.Cfg().MustGet(ctx, "whisper.model_dir", "").String()
|
||
threads := g.Cfg().MustGet(ctx, "whisper.threads", 2).Int()
|
||
|
||
args := []string{
|
||
req.AudioPath,
|
||
"--model", model,
|
||
"--language", language,
|
||
"--output_dir", outputDir,
|
||
"--output_format", "txt",
|
||
"--threads", fmt.Sprintf("%d", threads),
|
||
}
|
||
if modelDir != "" {
|
||
args = append(args, "--model_dir", modelDir)
|
||
}
|
||
|
||
cmd := exec.CommandContext(ctx, whisperPath, args...)
|
||
output, execErr := cmd.CombinedOutput()
|
||
if execErr != nil {
|
||
g.Log().Errorf(ctx, "whisper CLI 执行失败: %v\n%s", execErr, string(output))
|
||
return nil, fmt.Errorf("语音识别失败: %v", execErr)
|
||
}
|
||
|
||
return s.readTxtResult(outputDir, req.AudioPath, model)
|
||
}
|
||
|
||
// transcribeWithPython 使用 python -m whisper
|
||
func (s *whisperService) transcribeWithPython(ctx context.Context, req *TranscribeReq, model, language string) (res *TranscribeRes, err error) {
|
||
// 查找 python
|
||
pythonPath, err := exec.LookPath("python3")
|
||
if err != nil {
|
||
pythonPath, err = exec.LookPath("python")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("未找到 python,请安装: pip3 install openai-whisper")
|
||
}
|
||
}
|
||
|
||
outputDir := filepath.Dir(req.AudioPath)
|
||
modelDir := g.Cfg().MustGet(ctx, "whisper.model_dir", "").String()
|
||
threads := g.Cfg().MustGet(ctx, "whisper.threads", 2).Int()
|
||
|
||
args := []string{
|
||
"-m", "whisper",
|
||
req.AudioPath,
|
||
"--model", model,
|
||
"--language", language,
|
||
"--output_dir", outputDir,
|
||
"--output_format", "txt",
|
||
"--threads", fmt.Sprintf("%d", threads),
|
||
}
|
||
if modelDir != "" {
|
||
args = append(args, "--model_dir", modelDir)
|
||
}
|
||
|
||
cmd := exec.CommandContext(ctx, pythonPath, args...)
|
||
output, execErr := cmd.CombinedOutput()
|
||
if execErr != nil {
|
||
g.Log().Errorf(ctx, "whisper(python) 执行失败: %v\n%s", execErr, string(output))
|
||
return nil, fmt.Errorf("语音识别失败: %v", execErr)
|
||
}
|
||
|
||
return s.readTxtResult(outputDir, req.AudioPath, model)
|
||
}
|
||
|
||
// readTxtResult 读取 whisper 输出的 txt 文件
|
||
func (s *whisperService) readTxtResult(outputDir, audioPath, model string) (res *TranscribeRes, err error) {
|
||
baseName := strings.TrimSuffix(filepath.Base(audioPath), filepath.Ext(audioPath))
|
||
txtPaths := []string{
|
||
filepath.Join(outputDir, baseName+".txt"),
|
||
filepath.Join(outputDir, baseName+"."+model+".txt"),
|
||
}
|
||
|
||
var textBytes []byte
|
||
var txtPath string
|
||
for _, p := range txtPaths {
|
||
if b, e := os.ReadFile(p); e == nil {
|
||
textBytes = b
|
||
txtPath = p
|
||
break
|
||
}
|
||
}
|
||
if textBytes == nil {
|
||
return nil, fmt.Errorf("读取识别结果文件失败")
|
||
}
|
||
|
||
res = &TranscribeRes{
|
||
Text: cleanTranscript(string(textBytes)),
|
||
Model: model,
|
||
OutputPath: txtPath,
|
||
}
|
||
return
|
||
}
|
||
|
||
// cleanTranscript flattens a transcript onto one line. Every CR, LF or CRLF
// becomes a space, each run of spaces/line breaks collapses to a single
// space, and surrounding whitespace is trimmed.
//
// The scan is a single byte-wise pass. The separators are ASCII bytes, which
// never occur inside a multi-byte UTF-8 sequence, so CJK text (and even
// invalid UTF-8) is copied through unchanged. This replaces a repeated
// Contains/ReplaceAll loop that is quadratic at best.
func cleanTranscript(text string) string {
	var b strings.Builder
	b.Grow(len(text))

	pendingSpace := false
	for i := 0; i < len(text); i++ {
		c := text[i]
		if c == ' ' || c == '\n' || c == '\r' {
			pendingSpace = true
			continue
		}
		// One space stands in for the whole separator run; never emit a leading one.
		if pendingSpace && b.Len() > 0 {
			b.WriteByte(' ')
		}
		pendingSpace = false
		b.WriteByte(c)
	}
	// A trailing run is dropped because pendingSpace is never flushed.
	// TrimSpace also strips other edge whitespace such as tabs.
	return strings.TrimSpace(b.String())
}
|
||
|
||
// detectBackend 检测可用的 whisper 后端,返回后端类型和可执行路径
|
||
func (s *whisperService) detectBackend() (WhisperBackend, string) {
|
||
// 1. 优先检测 C++ 版 whisper.cpp(最快,但参数格式不同)
|
||
for _, name := range []string{"whisper-cpp", "whisper-cli"} {
|
||
if path, err := exec.LookPath(name); err == nil {
|
||
return backendCpp, path
|
||
}
|
||
}
|
||
|
||
// 2. 检查 setup 检测到的 C++ 路径
|
||
if setup.DetectedWhisperPath != "" {
|
||
base := filepath.Base(setup.DetectedWhisperPath)
|
||
if base == "whisper-cpp" || base == "whisper-cli" {
|
||
if _, err := os.Stat(setup.DetectedWhisperPath); err == nil {
|
||
return backendCpp, setup.DetectedWhisperPath
|
||
}
|
||
}
|
||
}
|
||
|
||
// 3. 检测 Python CLI(whisper 命令)
|
||
if path, err := exec.LookPath("whisper"); err == nil {
|
||
return backendCLI, path
|
||
}
|
||
|
||
// 4. 检查 setup 检测到的 Python CLI 路径
|
||
if setup.DetectedWhisperPath != "" {
|
||
if _, err := os.Stat(setup.DetectedWhisperPath); err == nil {
|
||
return backendCLI, setup.DetectedWhisperPath
|
||
}
|
||
}
|
||
|
||
// 5. 检查配置中的路径
|
||
if p := g.Cfg().MustGet(context.Background(), "whisper.path", "").String(); p != "" {
|
||
if _, err := os.Stat(p); err == nil {
|
||
return backendCLI, p
|
||
}
|
||
}
|
||
|
||
return backendPython, ""
|
||
}
|
||
|
||
// resolveCppModelPath 查找或下载 whisper.cpp 模型文件
|
||
func (s *whisperService) resolveCppModelPath(model string) string {
|
||
modelName := strings.TrimPrefix(model, "ggml-")
|
||
modelName = strings.TrimSuffix(modelName, ".bin")
|
||
|
||
cppModelName := "ggml-" + modelName + ".bin"
|
||
home, _ := os.UserHomeDir()
|
||
|
||
// 目标路径:~/.cache/whisper/ggml-{model}.bin
|
||
targetDir := filepath.Join(home, ".cache", "whisper")
|
||
targetPath := filepath.Join(targetDir, cppModelName)
|
||
|
||
// 1. 如果已存在,直接返回
|
||
if _, err := os.Stat(targetPath); err == nil {
|
||
return targetPath
|
||
}
|
||
|
||
// 2. 检查其他常见位置
|
||
altPaths := []string{
|
||
cppModelName,
|
||
filepath.Join(home, ".cache", "whisper", "ggml-"+modelName+"-q5_0.bin"),
|
||
}
|
||
// macOS: Homebrew 安装的 whisper.cpp 模型路径
|
||
if runtime.GOOS == "darwin" {
|
||
altPaths = append(altPaths,
|
||
"/opt/homebrew/share/whisper-cpp/models/"+cppModelName,
|
||
"/usr/local/share/whisper-cpp/models/"+cppModelName,
|
||
)
|
||
}
|
||
// Linux: 常见系统安装路径
|
||
if runtime.GOOS == "linux" {
|
||
altPaths = append(altPaths,
|
||
"/usr/share/whisper-cpp/models/"+cppModelName,
|
||
"/usr/local/share/whisper-cpp/models/"+cppModelName,
|
||
)
|
||
}
|
||
for _, p := range altPaths {
|
||
if _, err := os.Stat(p); err == nil {
|
||
return p
|
||
}
|
||
}
|
||
|
||
// 3. 自动下载
|
||
modelSize := map[string]string{
|
||
"tiny": "75MB",
|
||
"base": "150MB",
|
||
"small": "500MB",
|
||
"medium": "1.5GB",
|
||
}
|
||
size, _ := modelSize[modelName]
|
||
|
||
// 下载源:先试 hf-mirror(国内可访问),失败再试官方
|
||
modelPath := fmt.Sprintf("ggerganov/whisper.cpp/resolve/main/%s", cppModelName)
|
||
urls := []string{
|
||
fmt.Sprintf("https://hf-mirror.com/%s", modelPath),
|
||
fmt.Sprintf("https://huggingface.co/%s", modelPath),
|
||
}
|
||
|
||
g.Log().Infof(context.TODO(), "[whisper.cpp] 正在下载模型 %s (%s)...", cppModelName, size)
|
||
|
||
// 创建目录
|
||
os.MkdirAll(targetDir, 0755)
|
||
|
||
// 下载文件(多个源,依次尝试)
|
||
var lastErr error
|
||
for _, url := range urls {
|
||
g.Log().Infof(context.TODO(), "[whisper.cpp] 下载地址: %s", url)
|
||
if err := s.downloadFile(url, targetPath, 5*time.Minute); err == nil {
|
||
g.Log().Infof(context.TODO(), "[whisper.cpp] 模型下载完成: %s", targetPath)
|
||
return targetPath
|
||
} else {
|
||
lastErr = err
|
||
g.Log().Warningf(context.TODO(), "[whisper.cpp] 从 %s 下载失败: %v,尝试下一个源...", url, err)
|
||
}
|
||
}
|
||
|
||
g.Log().Errorf(context.TODO(), "[whisper.cpp] 所有下载源均失败: %v", lastErr)
|
||
return ""
|
||
}
|
||
|
||
// downloadFile 下载文件到指定路径(支持超时)
|
||
func (s *whisperService) downloadFile(url, destPath string, timeout time.Duration) error {
|
||
tmpPath := destPath + ".tmp"
|
||
out, err := os.Create(tmpPath)
|
||
if err != nil {
|
||
return fmt.Errorf("创建临时文件失败: %v", err)
|
||
}
|
||
defer out.Close()
|
||
|
||
client := &http.Client{Timeout: timeout}
|
||
resp, err := client.Get(url)
|
||
if err != nil {
|
||
os.Remove(tmpPath)
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
os.Remove(tmpPath)
|
||
return fmt.Errorf("HTTP %d", resp.StatusCode)
|
||
}
|
||
|
||
written, err := io.Copy(out, resp.Body)
|
||
if err != nil {
|
||
os.Remove(tmpPath)
|
||
return err
|
||
}
|
||
|
||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||
return fmt.Errorf("文件重命名失败: %v", err)
|
||
}
|
||
|
||
g.Log().Infof(context.TODO(), "[whisper.cpp] 下载完成: %d bytes", written)
|
||
return nil
|
||
}
|
||
|
||
// transcribeWithCpp 使用 whisper.cpp(C++ 版,参数格式不同)
|
||
func (s *whisperService) transcribeWithCpp(ctx context.Context, req *TranscribeReq, binaryPath, model, language string) (res *TranscribeRes, err error) {
|
||
outputDir := filepath.Dir(req.AudioPath)
|
||
baseName := strings.TrimSuffix(filepath.Base(req.AudioPath), filepath.Ext(req.AudioPath))
|
||
outputPrefix := filepath.Join(outputDir, baseName)
|
||
threads := g.Cfg().MustGet(ctx, "whisper.threads", 2).Int()
|
||
|
||
// whisper.cpp 参数:
|
||
// -f input.mp3 输入文件
|
||
// -l zh 语言
|
||
// -t 2 线程数
|
||
// -otxt 输出 txt
|
||
// -of /path/prefix 输出文件前缀(自动加 .txt)
|
||
args := []string{
|
||
"-f", req.AudioPath,
|
||
"-l", language,
|
||
"-t", fmt.Sprintf("%d", threads),
|
||
"-otxt",
|
||
"-of", outputPrefix,
|
||
"-m", model,
|
||
}
|
||
|
||
cmd := exec.CommandContext(ctx, binaryPath, args...)
|
||
output, execErr := cmd.CombinedOutput()
|
||
if execErr != nil {
|
||
g.Log().Errorf(ctx, "whisper.cpp 执行失败: %v\n%s", execErr, string(output))
|
||
return nil, fmt.Errorf("语音识别失败: %v", execErr)
|
||
}
|
||
|
||
// whisper.cpp 输出: {prefix}.txt
|
||
txtPath := outputPrefix + ".txt"
|
||
textBytes, readErr := os.ReadFile(txtPath)
|
||
if readErr != nil {
|
||
return nil, fmt.Errorf("读取识别结果文件失败: %v", readErr)
|
||
}
|
||
|
||
res = &TranscribeRes{
|
||
Text: cleanTranscript(string(textBytes)),
|
||
Model: model,
|
||
Language: language,
|
||
OutputPath: txtPath,
|
||
}
|
||
return
|
||
}
|