diff --git a/crates/yaejuyang-supertonic/src/api.rs b/crates/yaejuyang-supertonic/src/api.rs index c5ef4c7..e3dc7a7 100644 --- a/crates/yaejuyang-supertonic/src/api.rs +++ b/crates/yaejuyang-supertonic/src/api.rs @@ -9,11 +9,14 @@ use serde::{Deserialize, Serialize}; pub struct Body { text: String, lang: String, + style_id: String, } pub async fn handler( State(state): State>, Json(payload): Json, ) -> Result, String> { - Ok(state.synthesize(payload.text, payload.lang).await?) + Ok(state + .synthesize(payload.text, payload.lang, payload.style_id) + .await?) } diff --git a/crates/yaejuyang-supertonic/src/main.rs b/crates/yaejuyang-supertonic/src/main.rs index 42f067b..a9b7927 100644 --- a/crates/yaejuyang-supertonic/src/main.rs +++ b/crates/yaejuyang-supertonic/src/main.rs @@ -1,12 +1,14 @@ use axum::{Router, routing::post}; -use std::path::PathBuf; use std::sync::Arc; +use std::{collections::HashMap, path::PathBuf}; pub mod api; pub mod tts; use tts::{TtsOpts, TtsPool, load_text_to_speech, load_voice_style}; +use crate::tts::engine::load_voice_style_map; + #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<(), Box> { dotenvy::dotenv().ok(); @@ -15,7 +17,7 @@ async fn main() -> Result<(), Box> { let model_dir = std::env::var("SUPERTONIC_MODEL_DIR").unwrap_or_else(|_| "./assets".to_string()); let voice_style_path = std::env::var("SUPERTONIC_VOICE_STYLE") - .unwrap_or_else(|_| format!("{model_dir}/voice_styles/M1.json")); + .unwrap_or_else(|_| format!("F1={model_dir}/voice_styles/F1.json")); let lang = std::env::var("SUPERTONIC_LANG").unwrap_or_else(|_| "en".to_string()); let total_step: usize = std::env::var("SUPERTONIC_TOTAL_STEP") .ok() @@ -40,15 +42,26 @@ async fn main() -> Result<(), Box> { tts::assets::ensure_assets(&model_path, &hf_repo)?; let onnx_dir_for_init = model_path.join("onnx").to_string_lossy().into_owned(); - let voice_style_for_init = voice_style_path.clone(); + let voice_style_for_init = voice_style_path + .split(",") + .filter_map(|i| { + i.split_once("=").or_else(|| { + tracing::error!("Voice style '{i}' is not valid."); + None + }) + }) + .map(|(k, v)| (k.to_owned(), PathBuf::from(v))) + .collect::>(); + let voice_style_map = Arc::new(load_voice_style_map(&voice_style_for_init)?); + let pool = Arc::new(TtsPool::spawn( workers, move |id| { let span = tracing::info_span!("worker", worker_id = id); let _enter = span.enter(); let tts = load_text_to_speech(&onnx_dir_for_init)?; - let style = load_voice_style(std::slice::from_ref(&voice_style_for_init), false)?; - Ok((tts, style)) + + Ok((tts, voice_style_map.clone())) }, TtsOpts { total_step, diff --git a/crates/yaejuyang-supertonic/src/tts/engine.rs b/crates/yaejuyang-supertonic/src/tts/engine.rs index bea3c85..ad00dc9 100644 --- a/crates/yaejuyang-supertonic/src/tts/engine.rs +++ b/crates/yaejuyang-supertonic/src/tts/engine.rs @@ -8,7 +8,7 @@ use serde::Deserialize; use std::collections::HashMap; use std::fs::File; use std::io::BufReader; -use std::path::Path; +use std::path::{Path, PathBuf}; #[derive(Debug, Clone, Deserialize)] pub struct Config { @@ -282,62 +282,59 @@ fn sample_noisy_latent( (noisy_latent, latent_mask) } -pub fn load_voice_style(voice_style_paths: &[String], verbose: bool) -> Result