196 lines
5.3 KiB
Go
196 lines
5.3 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
var text = "欢迎使用红动未来数字人服务平台,我们将为您提供最优质的AI数字人解决方案。"
|
||
|
||
type TTSCommonResponse struct {
|
||
Code int `json:"code"`
|
||
Msg string `json:"msg"`
|
||
Text string `json:"text"`
|
||
Audio string `json:"audio"`
|
||
}
|
||
|
||
func main() {
|
||
// 获取当前工作目录
|
||
outputDir, err := os.Getwd()
|
||
if err != nil {
|
||
fmt.Printf("获取当前目录失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
// 查找项目根目录(向上查找包含 go.mod 的目录)
|
||
outputDir = findProjectRoot(outputDir)
|
||
|
||
// 验证根目录是否正确(检查是否有 go.mod)
|
||
if _, err := os.Stat(outputDir + "/go.mod"); err != nil {
|
||
fmt.Printf("未找到项目根目录,当前目录: %s\n", outputDir)
|
||
os.Exit(1)
|
||
}
|
||
|
||
fmt.Println("=================== TTS测试开始 ===================")
|
||
fmt.Printf("输出目录: %s\n", outputDir)
|
||
fmt.Printf("随机文本: %s\n", text)
|
||
fmt.Printf("请求URL: http://127.0.0.1:8000/tts\n")
|
||
|
||
// 创建带超时的 HTTP 客户端(120秒超时)
|
||
client := &http.Client{
|
||
Timeout: 120 * time.Second,
|
||
}
|
||
|
||
resp, err := client.Post("http://127.0.0.1:8000/tts", "application/json", bytes.NewBufferString(fmt.Sprintf(`"%s"`, text)))
|
||
if err != nil {
|
||
fmt.Printf("请求失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 打印响应头
|
||
fmt.Printf("Content-Type: %s\n", resp.Header.Get("Content-Type"))
|
||
fmt.Printf("Content-Length: %s\n", resp.Header.Get("Content-Length"))
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
fmt.Printf("读取响应失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
fmt.Printf("状态码: %d, 响应大小: %d字节\n", resp.StatusCode, len(body))
|
||
|
||
// 打印响应内容的前200字节(用于调试)
|
||
if len(body) > 0 {
|
||
previewLen := minInt(200, len(body))
|
||
fmt.Printf("响应内容预览(前%d字节): ", previewLen)
|
||
if len(body) >= 4 && string(body[:4]) == "RIFF" {
|
||
// WAV文件头
|
||
fmt.Printf("WAV文件格式 (RIFF...)\n")
|
||
} else if len(body) >= 3 && string(body[:3]) == "ID3" {
|
||
// MP3 ID3标签
|
||
fmt.Printf("MP3 ID3格式\n")
|
||
} else if len(body) >= 2 && body[0] == 0xFF && (body[1]&0xE0) == 0xE0 {
|
||
// MP3帧同步
|
||
fmt.Printf("MP3帧格式\n")
|
||
} else {
|
||
// 可能是JSON或其他格式
|
||
fmt.Printf("%s\n", string(body[:previewLen]))
|
||
}
|
||
} else {
|
||
fmt.Printf("响应内容为空!\n")
|
||
os.Exit(1)
|
||
}
|
||
|
||
// 尝试解析JSON响应(包含base64音频)
|
||
var commonResp TTSCommonResponse
|
||
var audioData []byte
|
||
var ext string
|
||
|
||
if json.Unmarshal(body, &commonResp) == nil && commonResp.Audio != "" && commonResp.Audio != "base64_placeholder" {
|
||
fmt.Printf("检测到JSON响应,code=%d, msg=%s\n", commonResp.Code, commonResp.Msg)
|
||
fmt.Printf("Audio字段长度: %d 字符\n", len(commonResp.Audio))
|
||
|
||
// 检查是否成功
|
||
if commonResp.Code != 0 {
|
||
fmt.Printf("TTS服务返回错误: %s\n", commonResp.Msg)
|
||
os.Exit(1)
|
||
}
|
||
|
||
// 解码base64音频数据
|
||
decoded, err := base64.StdEncoding.DecodeString(commonResp.Audio)
|
||
if err != nil {
|
||
fmt.Printf("base64解码失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
if len(decoded) == 0 {
|
||
fmt.Printf("解码后数据为空!\n")
|
||
os.Exit(1)
|
||
}
|
||
|
||
audioData = decoded
|
||
fmt.Printf("解码后音频数据大小: %d 字节\n", len(audioData))
|
||
|
||
// 根据解码后的音频数据格式决定扩展名
|
||
if len(audioData) >= 4 && string(audioData[:4]) == "RIFF" {
|
||
ext = ".wav"
|
||
fmt.Printf("检测到WAV格式\n")
|
||
} else if len(audioData) >= 3 && string(audioData[:3]) == "ID3" || (len(audioData) >= 2 && audioData[0] == 0xFF && (audioData[1]&0xE0) == 0xE0) {
|
||
ext = ".mp3"
|
||
fmt.Printf("检测到MP3格式\n")
|
||
} else {
|
||
ext = ".wav" // 默认wav
|
||
fmt.Printf("未知格式,默认保存为 .wav\n")
|
||
}
|
||
} else {
|
||
// 直接是二进制音频数据
|
||
audioData = body
|
||
|
||
// 根据音频数据格式决定扩展名
|
||
if len(audioData) >= 4 && string(audioData[:4]) == "RIFF" {
|
||
ext = ".wav"
|
||
} else if len(audioData) >= 3 && string(audioData[:3]) == "ID3" || (len(audioData) >= 2 && audioData[0] == 0xFF && (audioData[1]&0xE0) == 0xE0) {
|
||
ext = ".mp3"
|
||
} else {
|
||
ext = ".wav" // 默认wav
|
||
}
|
||
}
|
||
|
||
// 保存音频文件
|
||
filename := fmt.Sprintf("%s/tts_output_%d%s", outputDir, time.Now().Unix(), ext)
|
||
if err = os.WriteFile(filename, audioData, 0644); err != nil {
|
||
fmt.Printf("写文件失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Printf("音频已保存: %s (%d字节)\n", filename, len(audioData))
|
||
fmt.Println("=================== TTS测试成功 ===================")
|
||
}
|
||
|
||
func maxInt(a, b int) int {
|
||
if a > b {
|
||
return a
|
||
}
|
||
return b
|
||
}
|
||
|
||
func minInt(a, b int) int {
|
||
if a < b {
|
||
return a
|
||
}
|
||
return b
|
||
}
|
||
|
||
// findProjectRoot 查找项目根目录(包含 go.mod 的目录)
|
||
func findProjectRoot(startDir string) string {
|
||
dir := startDir
|
||
for {
|
||
// 检查当前目录是否有 go.mod
|
||
if _, err := os.Stat(dir + "/go.mod"); err == nil {
|
||
return dir
|
||
}
|
||
|
||
// 如果已经是根目录或无法继续向上查找,返回当前目录
|
||
parentDir := dir[:maxInt(0, len(dir)-len("/"+getLastPathSegment(dir)))]
|
||
if parentDir == dir || parentDir == "" {
|
||
return startDir
|
||
}
|
||
|
||
dir = parentDir
|
||
}
|
||
}
|
||
|
||
// getLastPathSegment 获取路径的最后一部分
|
||
func getLastPathSegment(path string) string {
|
||
if idx := strings.LastIndex(path, "/"); idx != -1 {
|
||
return path[idx+1:]
|
||
}
|
||
return path
|
||
}
|