add multi voice support

This commit is contained in:
kimpure 2026-05-19 17:47:17 +00:00
parent 6380daa4a4
commit caa6db37ce
No known key found for this signature in database
7 changed files with 132 additions and 97 deletions

View file

@ -9,11 +9,14 @@ use serde::{Deserialize, Serialize};
/// Request payload that `handler` receives through its `Json<Body>` extractor.
pub struct Body {
/// Text to synthesize; forwarded as the first argument of `synthesize`.
text: String,
/// Language value forwarded to `synthesize` unchanged.
lang: String,
/// Identifier of the voice style to use; forwarded to `synthesize` as its
/// third argument.
style_id: String,
}
/// Request handler that synthesizes speech for one JSON payload.
///
/// Forwards the payload's `text`, `lang` and `style_id` to the shared
/// [`TtsPool`] and returns its result unchanged.
///
/// # Errors
///
/// Returns the `String` error produced by `TtsPool::synthesize`.
pub async fn handler(
    State(state): State<Arc<TtsPool>>,
    Json(payload): Json<Body>,
) -> Result<Vec<u8>, String> {
    // `synthesize` already yields `Result<Vec<u8>, String>`, so it is returned
    // directly instead of being unwrapped with `?` and re-wrapped in `Ok`.
    state
        .synthesize(payload.text, payload.lang, payload.style_id)
        .await
}

View file

@ -1,12 +1,14 @@
use axum::{Router, routing::post};
use std::path::PathBuf;
use std::sync::Arc;
use std::{collections::HashMap, path::PathBuf};
pub mod api;
pub mod tts;
use tts::{TtsOpts, TtsPool, load_text_to_speech, load_voice_style};
use crate::tts::engine::load_voice_style_map;
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
dotenvy::dotenv().ok();
@ -15,7 +17,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_dir =
std::env::var("SUPERTONIC_MODEL_DIR").unwrap_or_else(|_| "./assets".to_string());
let voice_style_path = std::env::var("SUPERTONIC_VOICE_STYLE")
.unwrap_or_else(|_| format!("{model_dir}/voice_styles/M1.json"));
.unwrap_or_else(|_| format!("F1={model_dir}/voice_styles/F1.json"));
let lang = std::env::var("SUPERTONIC_LANG").unwrap_or_else(|_| "en".to_string());
let total_step: usize = std::env::var("SUPERTONIC_TOTAL_STEP")
.ok()
@ -40,15 +42,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
tts::assets::ensure_assets(&model_path, &hf_repo)?;
let onnx_dir_for_init = model_path.join("onnx").to_string_lossy().into_owned();
let voice_style_for_init = voice_style_path.clone();
let voice_style_for_init = voice_style_path
.split(",")
.filter_map(|i| {
i.split_once("=").or_else(|| {
tracing::error!("Voice style '{i}' is not valid.");
None
})
})
.map(|(k, v)| (k.to_owned(), PathBuf::from(v)))
.collect::<HashMap<_, _>>();
let voice_style_map = Arc::new(load_voice_style_map(&voice_style_for_init)?);
let pool = Arc::new(TtsPool::spawn(
workers,
move |id| {
let span = tracing::info_span!("worker", worker_id = id);
let _enter = span.enter();
let tts = load_text_to_speech(&onnx_dir_for_init)?;
let style = load_voice_style(std::slice::from_ref(&voice_style_for_init), false)?;
Ok((tts, style))
Ok((tts, voice_style_map.clone()))
},
TtsOpts {
total_step,

View file

@ -8,7 +8,7 @@ use serde::Deserialize;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
@ -282,61 +282,58 @@ fn sample_noisy_latent(
(noisy_latent, latent_mask)
}
pub fn load_voice_style(voice_style_paths: &[String], verbose: bool) -> Result<Style> {
let bsz = voice_style_paths.len();
/// Loads every configured voice style and indexes the results by style name.
///
/// `voice_style_path_map` maps a style name to the file that is passed to
/// [`load_voice_style`]. The returned map has one entry per input entry, under
/// the same name.
///
/// # Errors
///
/// Stops at the first style that fails to load and returns that error,
/// annotated with the style name and path so the failing entry can be
/// identified.
pub fn load_voice_style_map(
    voice_style_path_map: &HashMap<String, PathBuf>,
) -> Result<HashMap<String, Style>> {
    let mut map = HashMap::with_capacity(voice_style_path_map.len());

    for (name, path) in voice_style_path_map {
        let style = load_voice_style(path).with_context(|| {
            format!("Failed to load voice style '{name}' from {}", path.display())
        })?;
        map.insert(name.clone(), style);
        tracing::info!("Voice style {name} loaded");
    }

    Ok(map)
}
pub fn load_voice_style(voice_style_path: &Path) -> Result<Style> {
let file =
File::open(&voice_style_path).context("Failed to open voice style file")?;
let file_reader = BufReader::new(file);
let file_data: VoiceStyleData = serde_json::from_reader(file_reader)?;
let ttl_dims = &file_data.style_ttl.dims;
let dp_dims = &file_data.style_dp.dims;
let ttl_dim1 = ttl_dims[1];
let ttl_dim2 = ttl_dims[2];
let dp_dim1 = dp_dims[1];
let dp_dim2 = dp_dims[2];
let ttl_size = bsz * ttl_dim1 * ttl_dim2;
let dp_size = bsz * dp_dim1 * dp_dim2;
let ttl_size = ttl_dim1 * ttl_dim2;
let dp_size = dp_dim1 * dp_dim2;
let mut ttl_flat = vec![0.0f32; ttl_size];
let mut dp_flat = vec![0.0f32; dp_size];
for (i, path) in voice_style_paths.iter().enumerate() {
let file = File::open(path).context("Failed to open voice style file")?;
let reader = BufReader::new(file);
let data: VoiceStyleData = serde_json::from_reader(reader)?;
let ttl_offset = i * ttl_dim1 * ttl_dim2;
let mut idx = 0;
for batch in &data.style_ttl.data {
for batch in &file_data.style_ttl.data {
for row in batch {
for &val in row {
ttl_flat[ttl_offset + idx] = val;
ttl_flat[idx] = val;
idx += 1;
}
}
}
let dp_offset = i * dp_dim1 * dp_dim2;
idx = 0;
for batch in &data.style_dp.data {
for batch in &file_data.style_dp.data {
for row in batch {
for &val in row {
dp_flat[dp_offset + idx] = val;
dp_flat[idx] = val;
idx += 1;
}
}
}
}
let ttl_style = Array3::from_shape_vec((bsz, ttl_dim1, ttl_dim2), ttl_flat)?;
let dp_style = Array3::from_shape_vec((bsz, dp_dim1, dp_dim2), dp_flat)?;
if verbose {
tracing::info!("Loaded {} voice styles", bsz);
}
let ttl_style = Array3::from_shape_vec((1, ttl_dim1, ttl_dim2), ttl_flat)?;
let dp_style = Array3::from_shape_vec((1, dp_dim1, dp_dim2), dp_flat)?;
Ok(Style {
ttl: ttl_style,

View file

@ -1,5 +1,6 @@
use anyhow::{Context, Result};
use crossbeam_channel::{Sender, bounded};
use std::collections::HashMap;
use std::sync::Arc;
use std::thread;
use std::time::Instant;
@ -18,6 +19,7 @@ pub struct TtsOpts {
/// One synthesis request sent over the pool's job channel to a worker thread.
pub struct TtsJob {
/// Text to synthesize.
pub text: String,
/// Language value supplied by the caller of `TtsPool::synthesize`.
pub lang: String,
/// Key the worker looks up in its style map; when the key is missing the
/// worker replies with a "Voice style ... not found" error.
pub style_id: String,
/// One-shot channel on which the worker sends the outcome: the produced bytes
/// on success (presumably WAV data — confirm) or an error message.
pub reply: oneshot::Sender<Result<Vec<u8>, String>>,
}
@ -28,7 +30,7 @@ pub struct TtsPool {
impl TtsPool {
pub fn spawn<I>(workers: usize, init: I, opts: TtsOpts) -> Result<Self>
where
I: Fn(u32) -> Result<(TextToSpeech, Style)> + Send + Sync + 'static,
I: Fn(u32) -> Result<(TextToSpeech, Arc<HashMap<String, Style>>)> + Send + Sync + 'static,
{
let workers = workers.max(1);
let (tx, rx) = bounded::<TtsJob>(workers * 4);
@ -40,7 +42,7 @@ impl TtsPool {
thread::Builder::new()
.name(format!("supertonic-tts-{worker_id}"))
.spawn(move || {
let (mut tts, style) = match init(worker_id as u32) {
let (mut tts, style_map) = match init(worker_id as u32) {
Ok(pair) => pair,
Err(e) => {
tracing::error!("worker {worker_id} init failed: {e:?}");
@ -52,6 +54,12 @@ impl TtsPool {
while let Ok(job) = rx.recv() {
let start_at = Instant::now();
let Some(style) = style_map.get(&job.style_id) else {
job.reply
.send(Err(format!("Voice style {} not found", job.style_id)));
continue;
};
let result = (|| -> Result<Vec<u8>, String> {
let (wav, _dur) = tts
.synthesize(
@ -77,11 +85,17 @@ impl TtsPool {
Ok(TtsPool { tx })
}
pub async fn synthesize(&self, text: String, lang: String) -> Result<Vec<u8>, String> {
pub async fn synthesize(
&self,
text: String,
lang: String,
style_id: String,
) -> Result<Vec<u8>, String> {
let (reply_tx, reply_rx) = oneshot::channel();
let job = TtsJob {
text,
lang,
style_id,
reply: reply_tx,
};

View file

@ -1,7 +1,7 @@
import { getOrCreateVoiceConnection } from "../util";
import { getUserProfile, hasGuildReadChannel } from "../db";
import { defineEvent } from "../event";
import { playVoice } from "../tts";
import { playVoice, PlayVoiceOptions } from "../tts";
import { Voice } from "../../db/generated/prisma/enums";
export default defineEvent("messageCreate", async (message) => {
@ -20,18 +20,25 @@ export default defineEvent("messageCreate", async (message) => {
let content = message.cleanContent;
let voice: Voice | null = null
let options: PlayVoiceOptions | undefined = undefined;
let matched: RegExpMatchArray | null = null;
if (content.startsWith("$t ")) {
voice = "TypeCast";
} else if (content.startsWith("$p ")) {
voice = "Papago";
} else if (content.startsWith("$s ")) {
} else if (matched = content.match(/^\$s(\S*) /)) {
voice = "Supertonic";
let style: string | undefined = undefined
if (matched[1].length) {
style = matched[1]
}
options = { supertonicStyleId: style };
} else if (content.match(/^\$\s/)) {
return;
}
if (voice) {
content = content.replace(/^\$[^ ]+ +/, "")
content = content.replace(/^\$\S+\s+/, "")
} else {
voice = profile.voice;
}
@ -61,5 +68,6 @@ export default defineEvent("messageCreate", async (message) => {
} catch(err) {
message.reply("말이 꼬이네요 ㅜ.ㅜ");
console.log("playVoice failed. ", err);
}
})

View file

@ -10,38 +10,6 @@ import { nyaize } from "../utils/nyaize";
import { OutputHandler } from "../utils/outputHandler";
import TTSSupertonicModel from "../tts/supertonic";
/**
 * Produces the audio buffer for `text` using the TTS backend named by `voice`.
 *
 * For each backend the text is first normalised with the model's `ttsify`,
 * then turned into a request id, and the buffer is fetched through the
 * model's `getMemcachedVoice`.
 *
 * @throws Error("Empty content") when `ttsify` leaves nothing to speak.
 * @throws Error("Unknown voice type: ...") when `voice` matches no backend.
 */
export async function createVoiceBuffer(voice: Voice, text: string): Promise<Buffer> {
    if (voice == "TypeCast") {
        const model = TTSTypecastModel.instance;
        const spoken = model.ttsify(text);
        if (!spoken.length) {
            throw new Error("Empty content");
        }
        const requestId = model.createRequestId(spoken);
        return await model.getMemcachedVoice(requestId);
    }

    if (voice == "Supertonic") {
        const model = TTSSupertonicModel.instance;
        const spoken = model.ttsify(text);
        if (!spoken.length) {
            throw new Error("Empty content");
        }
        const requestId = model.createRequestId(spoken);
        return await model.getMemcachedVoice(requestId);
    }

    if (voice == "Papago") {
        const model = TTSPapagoModel.instance;
        const spoken = model.ttsify(text);
        if (!spoken.length) {
            throw new Error("Empty content");
        }
        const requestId = model.createRequestId(spoken);
        return await model.getMemcachedVoice(requestId);
    }

    throw new Error(`Unknown voice type: ${voice}`);
}
class VoiceQueue {
private connection: VoiceConnection;
private list: AudioResource[];
@ -77,11 +45,16 @@ class VoiceQueue {
}
}
/** Optional per-call settings accepted by `playVoice`. */
export type PlayVoiceOptions = {
/** Voice style id passed to the Supertonic model's `createRequestId`; may be omitted. */
supertonicStyleId?: string,
};
export async function playVoice(
guild: Guild,
profile: DiscordUserProfile,
voice: Voice,
text: string
text: string,
options?: PlayVoiceOptions,
) {
if (profile.nya)
text = nyaize(text);
@ -91,17 +64,42 @@ export async function playVoice(
if (!connection)
throw new Error("Yaeju is not joined VoiceChat");
// Typecast is gated per user: reject the request when the profile is NOT
// allowed to use it. The check must be negated; throwing while `canTypecast`
// is true would block exactly the users who are permitted and let everyone
// else through.
if (voice == "TypeCast" && !profile.canTypecast) {
    throw new Error(`the user ${profile.userId} is can't use typecast voice`);
}

// Every backend follows the same steps: normalise the text with `ttsify`,
// refuse empty results, build a request id, then fetch the buffer through
// `getMemcachedVoice`. Only Supertonic takes an extra style id.
let voiceBuffer: Buffer;
if (voice == "TypeCast") {
    const content = TTSTypecastModel.instance.ttsify(text);
    if (!content.length)
        throw new Error("Empty content");

    voiceBuffer = await TTSTypecastModel.instance.getMemcachedVoice(
        TTSTypecastModel.instance.createRequestId(content)
    );
} else if (voice == "Supertonic") {
    const content = TTSSupertonicModel.instance.ttsify(text);
    if (!content.length)
        throw new Error("Empty content");

    voiceBuffer = await TTSSupertonicModel.instance.getMemcachedVoice(
        TTSSupertonicModel.instance.createRequestId(content, options?.supertonicStyleId)
    );
} else if (voice == "Papago") {
    const content = TTSPapagoModel.instance.ttsify(text);
    if (!content.length)
        throw new Error("Empty content");

    voiceBuffer = await TTSPapagoModel.instance.getMemcachedVoice(
        TTSPapagoModel.instance.createRequestId(content)
    );
} else {
    throw new Error(`Unknown voice type: ${voice}`);
}
VoiceQueue.fromConnection(connection).enqueue(
TTSModelBase.bufferToAudioResource(voiceBuffer)
);

View file

@ -18,6 +18,7 @@ export class TTSSupertonicModel extends TTSModelBase<TTSSupertonicModel.RequestI
const payload = {
text: voiceId.text,
lang: "ko",
style_id: voiceId.styleId
};
if (!process.env.SUPERTONIC_API_URL) {
@ -42,22 +43,23 @@ export class TTSSupertonicModel extends TTSModelBase<TTSSupertonicModel.RequestI
throw new Error(`invalid supertonic response ${await response.text()}`);
}
/**
 * Resolves the cache file path for a Supertonic request.
 *
 * The file name is hashed from both the style id and the text, so the same
 * text rendered with different styles is cached separately. The two parts are
 * serialised as a JSON array rather than concatenated: plain concatenation
 * would give ("abc", "F1") and ("abcF", "1") the same hash input and therefore
 * the same cache file.
 *
 * @param id request id holding the text and the style id
 * @returns path of the cached audio file under `SupertonicAudioCachePath`
 */
public getVoicePath(id: TTSSupertonicModel.RequestId): string {
    const audioFileName = TTSModelBase.hashAudioFile(
        JSON.stringify([id.styleId, id.text])
    );
    return join(TTSSupertonicModel.SupertonicAudioCachePath, audioFileName);
}
/**
 * Builds the request id for a piece of text.
 *
 * @param text text to synthesize
 * @param styleId voice style id; falls back to "F1" when omitted
 * @returns request id carrying the text and the resolved style id
 */
public createRequestId(text: string, styleId?: string): TTSSupertonicModel.RequestId {
    return {
        text,
        styleId: styleId ?? "F1",
    };
}
}
/** Companion namespace of the `TTSSupertonicModel` class: shared instance, request id type and cache location. */
export namespace TTSSupertonicModel {
    /** Shared model instance used by callers. */
    export const instance = new TTSSupertonicModel();
    /** Identifies one synthesis request: the text plus the voice style id. */
    export type RequestId = { text: string, styleId: string };
    /** Directory under the common audio cache path where Supertonic audio is stored. */
    export const SupertonicAudioCachePath = join(TTSModelBase.AudioCachePath, "supertonic");
}
export default TTSSupertonicModel;