add supertonic
This commit is contained in:
parent
a69c90c822
commit
69ec38d16b
13 changed files with 2850 additions and 0 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -9,3 +9,6 @@ packages/generated
|
||||||
cache
|
cache
|
||||||
db
|
db
|
||||||
docker-compose.yml
|
docker-compose.yml
|
||||||
|
|
||||||
|
target/
|
||||||
|
test.wav
|
||||||
1681
Cargo.lock
generated
Normal file
1681
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
3
Cargo.toml
Normal file
3
Cargo.toml
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
[workspace]
|
||||||
|
resolver = "2"
|
||||||
|
members = ["crates/yaejuyang-supertonic"]
|
||||||
2
crates/yaejuyang-supertonic/.gitignore
vendored
Normal file
2
crates/yaejuyang-supertonic/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
assets/
|
||||||
|
.env
|
||||||
30
crates/yaejuyang-supertonic/Cargo.toml
Normal file
30
crates/yaejuyang-supertonic/Cargo.toml
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
[package]
|
||||||
|
name = "yaejuyang-supertonic"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["webgpu"]
|
||||||
|
webgpu = [ "ort/webgpu" ]
|
||||||
|
cuda = [ "ort/cuda" ]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0.102"
|
||||||
|
axum = "0.8.9"
|
||||||
|
clap = "4.6.1"
|
||||||
|
crossbeam-channel = "0.5.15"
|
||||||
|
dotenvy = "0.15.7"
|
||||||
|
hound = "3.5.1"
|
||||||
|
ndarray = "0.17.2"
|
||||||
|
ort = "2.0.0-rc.12"
|
||||||
|
qwreey-utility-rs = "0.1.9"
|
||||||
|
rand = "0.10.1"
|
||||||
|
chacha20 = { version = "0.10.0-rc.5" }
|
||||||
|
rand_distr = "0.6.0"
|
||||||
|
regex = "1.12.3"
|
||||||
|
serde = { version = "1.0.228", features = ["derive"] }
|
||||||
|
serde_json = "1.0.149"
|
||||||
|
tokio = { version = "1.52.3", features = ["full"] }
|
||||||
|
tracing = "0.1.44"
|
||||||
|
tracing-subscriber = "0.3.23"
|
||||||
|
unicode-normalization = "0.1.25"
|
||||||
19
crates/yaejuyang-supertonic/src/api.rs
Normal file
19
crates/yaejuyang-supertonic/src/api.rs
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::{Json, extract::State};
|
||||||
|
|
||||||
|
use crate::tts::TtsPool;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct Body {
|
||||||
|
text: String,
|
||||||
|
lang: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handler(
|
||||||
|
State(state): State<Arc<TtsPool>>,
|
||||||
|
Json(payload): Json<Body>,
|
||||||
|
) -> Result<Vec<u8>, String> {
|
||||||
|
Ok(state.synthesize(payload.text, payload.lang).await?)
|
||||||
|
}
|
||||||
69
crates/yaejuyang-supertonic/src/main.rs
Normal file
69
crates/yaejuyang-supertonic/src/main.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
use axum::{Router, routing::post};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub mod api;
|
||||||
|
pub mod tts;
|
||||||
|
|
||||||
|
use tts::{TtsOpts, TtsPool, load_text_to_speech, load_voice_style};
|
||||||
|
|
||||||
|
#[tokio::main(flavor = "multi_thread")]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
dotenvy::dotenv().ok();
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
let model_dir = std::env::var("SUPERTONIC_MODEL_DIR")
|
||||||
|
.unwrap_or_else(|_| "./assets/supertonic-3".to_string());
|
||||||
|
let voice_style_path = std::env::var("SUPERTONIC_VOICE_STYLE")
|
||||||
|
.unwrap_or_else(|_| format!("{model_dir}/voice_styles/M1.json"));
|
||||||
|
let lang = std::env::var("SUPERTONIC_LANG").unwrap_or_else(|_| "en".to_string());
|
||||||
|
let total_step: usize = std::env::var("SUPERTONIC_TOTAL_STEP")
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| v.parse().ok())
|
||||||
|
.unwrap_or(8);
|
||||||
|
let speed: f32 = std::env::var("SUPERTONIC_SPEED")
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| v.parse().ok())
|
||||||
|
.unwrap_or(1.05);
|
||||||
|
let silence_duration: f32 = std::env::var("SUPERTONIC_SILENCE_DUR")
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| v.parse().ok())
|
||||||
|
.unwrap_or(0.3);
|
||||||
|
let workers: usize = std::env::var("SUPERTONIC_WORKERS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|v| v.parse().ok())
|
||||||
|
.unwrap_or(2);
|
||||||
|
let hf_repo = std::env::var("SUPERTONIC_HF_REPO")
|
||||||
|
.unwrap_or_else(|_| "https://huggingface.co/Supertone/supertonic-3".to_string());
|
||||||
|
|
||||||
|
let model_path = PathBuf::from(&model_dir);
|
||||||
|
tts::assets::ensure_assets(&model_path, &hf_repo)?;
|
||||||
|
|
||||||
|
let onnx_dir_for_init = model_path.join("onnx").to_string_lossy().into_owned();
|
||||||
|
let voice_style_for_init = voice_style_path.clone();
|
||||||
|
let pool = Arc::new(TtsPool::spawn(
|
||||||
|
workers,
|
||||||
|
move |id| {
|
||||||
|
let span = tracing::info_span!("worker", worker_id = id);
|
||||||
|
let _enter = span.enter();
|
||||||
|
let tts = load_text_to_speech(&onnx_dir_for_init)?;
|
||||||
|
let style = load_voice_style(std::slice::from_ref(&voice_style_for_init), false)?;
|
||||||
|
Ok((tts, style))
|
||||||
|
},
|
||||||
|
TtsOpts {
|
||||||
|
total_step,
|
||||||
|
speed,
|
||||||
|
silence_duration,
|
||||||
|
},
|
||||||
|
)?);
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/", post(api::handler))
|
||||||
|
.with_state(pool);
|
||||||
|
|
||||||
|
let addr = std::env::var("ADDR").unwrap_or_else(|_| "0.0.0.0:80".to_string());
|
||||||
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
|
axum::serve(listener, app).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
77
crates/yaejuyang-supertonic/src/tts/assets.rs
Normal file
77
crates/yaejuyang-supertonic/src/tts/assets.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use std::path::Path;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
/// Ensure the Supertonic model directory is present.
|
||||||
|
///
|
||||||
|
/// `model_dir` is the unpacked HuggingFace repo root — it should contain
|
||||||
|
/// `onnx/tts.json` and `voice_styles/*.json` after the clone. If
|
||||||
|
/// `model_dir/onnx/tts.json` already exists, this is a no-op (typical inside
|
||||||
|
/// containers that bake assets in at build time). Otherwise the HF repo is
|
||||||
|
/// cloned into `model_dir`, which must not already exist or must be empty.
|
||||||
|
///
|
||||||
|
/// Requires `git` and `git-lfs` on PATH.
|
||||||
|
pub fn ensure_assets(model_dir: &Path, hf_repo: &str) -> Result<()> {
|
||||||
|
if model_dir.join("onnx").join("tts.json").exists() {
|
||||||
|
tracing::info!(
|
||||||
|
model_dir = %model_dir.display(),
|
||||||
|
"supertonic assets already present"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_dir.exists() {
|
||||||
|
let is_empty = std::fs::read_dir(model_dir)
|
||||||
|
.with_context(|| format!("failed to read {}", model_dir.display()))?
|
||||||
|
.next()
|
||||||
|
.is_none();
|
||||||
|
if !is_empty {
|
||||||
|
bail!(
|
||||||
|
"SUPERTONIC_MODEL_DIR={} already exists and is not empty but does not \
|
||||||
|
contain onnx/tts.json — delete it or point SUPERTONIC_MODEL_DIR \
|
||||||
|
somewhere fresh",
|
||||||
|
model_dir.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else if let Some(parent) = model_dir.parent() {
|
||||||
|
if !parent.as_os_str().is_empty() {
|
||||||
|
std::fs::create_dir_all(parent)
|
||||||
|
.with_context(|| format!("failed to create {}", parent.display()))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
repo = hf_repo,
|
||||||
|
target = %model_dir.display(),
|
||||||
|
"cloning supertonic assets from HuggingFace"
|
||||||
|
);
|
||||||
|
|
||||||
|
let status = Command::new("git")
|
||||||
|
.args(["clone", "--depth=1", hf_repo])
|
||||||
|
.arg(model_dir)
|
||||||
|
.status()
|
||||||
|
.context("failed to invoke git — is git installed?")?;
|
||||||
|
if !status.success() {
|
||||||
|
bail!("git clone of {hf_repo} failed with status {status}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let lfs_status = Command::new("git")
|
||||||
|
.args(["-C"])
|
||||||
|
.arg(model_dir)
|
||||||
|
.args(["lfs", "pull"])
|
||||||
|
.status()
|
||||||
|
.context("failed to invoke git lfs — is git-lfs installed?")?;
|
||||||
|
if !lfs_status.success() {
|
||||||
|
bail!("git lfs pull failed with status {lfs_status}");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !model_dir.join("onnx").join("tts.json").exists() {
|
||||||
|
bail!(
|
||||||
|
"expected {} to exist after clone — the HF repo layout may have changed",
|
||||||
|
model_dir.join("onnx").join("tts.json").display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("supertonic assets ready");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
460
crates/yaejuyang-supertonic/src/tts/engine.rs
Normal file
460
crates/yaejuyang-supertonic/src/tts/engine.rs
Normal file
|
|
@ -0,0 +1,460 @@
|
||||||
|
use super::text::{UnicodeProcessor, VoiceStyleData, chunk_text, length_to_mask};
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use ndarray::{Array, Array3};
|
||||||
|
use ort::value::Value;
|
||||||
|
use ort::{ep::ExecutionProviderDispatch, session::Session};
|
||||||
|
use rand_distr::{Distribution, Normal};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub ae: AEConfig,
|
||||||
|
pub ttl: TTLConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct AEConfig {
|
||||||
|
pub sample_rate: i32,
|
||||||
|
pub base_chunk_size: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct TTLConfig {
|
||||||
|
pub chunk_compress_factor: i32,
|
||||||
|
pub latent_dim: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_cfgs<P: AsRef<Path>>(onnx_dir: P) -> Result<Config> {
|
||||||
|
let cfg_path = onnx_dir.as_ref().join("tts.json");
|
||||||
|
let file =
|
||||||
|
File::open(&cfg_path).with_context(|| format!("failed to open {}", cfg_path.display()))?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let cfgs: Config = serde_json::from_reader(reader)?;
|
||||||
|
Ok(cfgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Style {
|
||||||
|
pub ttl: Array3<f32>,
|
||||||
|
pub dp: Array3<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TextToSpeech {
|
||||||
|
cfgs: Config,
|
||||||
|
text_processor: UnicodeProcessor,
|
||||||
|
dp_ort: Session,
|
||||||
|
text_enc_ort: Session,
|
||||||
|
vector_est_ort: Session,
|
||||||
|
vocoder_ort: Session,
|
||||||
|
pub sample_rate: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextToSpeech {
|
||||||
|
pub fn new(
|
||||||
|
cfgs: Config,
|
||||||
|
text_processor: UnicodeProcessor,
|
||||||
|
dp_ort: Session,
|
||||||
|
text_enc_ort: Session,
|
||||||
|
vector_est_ort: Session,
|
||||||
|
vocoder_ort: Session,
|
||||||
|
) -> Self {
|
||||||
|
let sample_rate = cfgs.ae.sample_rate;
|
||||||
|
TextToSpeech {
|
||||||
|
cfgs,
|
||||||
|
text_processor,
|
||||||
|
dp_ort,
|
||||||
|
text_enc_ort,
|
||||||
|
vector_est_ort,
|
||||||
|
vocoder_ort,
|
||||||
|
sample_rate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer(
|
||||||
|
&mut self,
|
||||||
|
text_list: &[String],
|
||||||
|
lang_list: &[String],
|
||||||
|
style: &Style,
|
||||||
|
total_step: usize,
|
||||||
|
speed: f32,
|
||||||
|
) -> Result<(Vec<f32>, Vec<f32>)> {
|
||||||
|
let bsz = text_list.len();
|
||||||
|
|
||||||
|
let (text_ids, text_mask) = self.text_processor.call(text_list, lang_list)?;
|
||||||
|
|
||||||
|
let text_ids_array = {
|
||||||
|
let text_ids_shape = (bsz, text_ids[0].len());
|
||||||
|
let mut flat = Vec::new();
|
||||||
|
for row in &text_ids {
|
||||||
|
flat.extend_from_slice(row);
|
||||||
|
}
|
||||||
|
Array::from_shape_vec(text_ids_shape, flat)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let text_ids_value = Value::from_array(text_ids_array)?;
|
||||||
|
let text_mask_value = Value::from_array(text_mask.clone())?;
|
||||||
|
let style_dp_value = Value::from_array(style.dp.clone())?;
|
||||||
|
|
||||||
|
let dp_outputs = self.dp_ort.run(ort::inputs! {
|
||||||
|
"text_ids" => &text_ids_value,
|
||||||
|
"style_dp" => &style_dp_value,
|
||||||
|
"text_mask" => &text_mask_value
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (_, duration_data) = dp_outputs["duration"].try_extract_tensor::<f32>()?;
|
||||||
|
let mut duration: Vec<f32> = duration_data.to_vec();
|
||||||
|
|
||||||
|
for dur in duration.iter_mut() {
|
||||||
|
*dur /= speed;
|
||||||
|
}
|
||||||
|
|
||||||
|
let style_ttl_value = Value::from_array(style.ttl.clone())?;
|
||||||
|
let text_enc_outputs = self.text_enc_ort.run(ort::inputs! {
|
||||||
|
"text_ids" => &text_ids_value,
|
||||||
|
"style_ttl" => &style_ttl_value,
|
||||||
|
"text_mask" => &text_mask_value
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (text_emb_shape, text_emb_data) =
|
||||||
|
text_enc_outputs["text_emb"].try_extract_tensor::<f32>()?;
|
||||||
|
let text_emb = Array3::from_shape_vec(
|
||||||
|
(
|
||||||
|
text_emb_shape[0] as usize,
|
||||||
|
text_emb_shape[1] as usize,
|
||||||
|
text_emb_shape[2] as usize,
|
||||||
|
),
|
||||||
|
text_emb_data.to_vec(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let (mut xt, latent_mask) = sample_noisy_latent(
|
||||||
|
&duration,
|
||||||
|
self.sample_rate,
|
||||||
|
self.cfgs.ae.base_chunk_size,
|
||||||
|
self.cfgs.ttl.chunk_compress_factor,
|
||||||
|
self.cfgs.ttl.latent_dim,
|
||||||
|
);
|
||||||
|
|
||||||
|
let total_step_array = Array::from_elem(bsz, total_step as f32);
|
||||||
|
|
||||||
|
for step in 0..total_step {
|
||||||
|
let current_step_array = Array::from_elem(bsz, step as f32);
|
||||||
|
|
||||||
|
let xt_value = Value::from_array(xt.clone())?;
|
||||||
|
let text_emb_value = Value::from_array(text_emb.clone())?;
|
||||||
|
let latent_mask_value = Value::from_array(latent_mask.clone())?;
|
||||||
|
let text_mask_value2 = Value::from_array(text_mask.clone())?;
|
||||||
|
let current_step_value = Value::from_array(current_step_array)?;
|
||||||
|
let total_step_value = Value::from_array(total_step_array.clone())?;
|
||||||
|
|
||||||
|
let vector_est_outputs = self.vector_est_ort.run(ort::inputs! {
|
||||||
|
"noisy_latent" => &xt_value,
|
||||||
|
"text_emb" => &text_emb_value,
|
||||||
|
"style_ttl" => &style_ttl_value,
|
||||||
|
"latent_mask" => &latent_mask_value,
|
||||||
|
"text_mask" => &text_mask_value2,
|
||||||
|
"current_step" => ¤t_step_value,
|
||||||
|
"total_step" => &total_step_value
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (denoised_shape, denoised_data) =
|
||||||
|
vector_est_outputs["denoised_latent"].try_extract_tensor::<f32>()?;
|
||||||
|
xt = Array3::from_shape_vec(
|
||||||
|
(
|
||||||
|
denoised_shape[0] as usize,
|
||||||
|
denoised_shape[1] as usize,
|
||||||
|
denoised_shape[2] as usize,
|
||||||
|
),
|
||||||
|
denoised_data.to_vec(),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_latent_value = Value::from_array(xt)?;
|
||||||
|
let vocoder_outputs = self.vocoder_ort.run(ort::inputs! {
|
||||||
|
"latent" => &final_latent_value
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (_, wav_data) = vocoder_outputs["wav_tts"].try_extract_tensor::<f32>()?;
|
||||||
|
let wav: Vec<f32> = wav_data.to_vec();
|
||||||
|
|
||||||
|
Ok((wav, duration))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn synthesize(
|
||||||
|
&mut self,
|
||||||
|
text: &str,
|
||||||
|
lang: &str,
|
||||||
|
style: &Style,
|
||||||
|
total_step: usize,
|
||||||
|
speed: f32,
|
||||||
|
silence_duration: f32,
|
||||||
|
) -> Result<(Vec<f32>, f32)> {
|
||||||
|
let max_len = if lang == "ko" || lang == "ja" {
|
||||||
|
120
|
||||||
|
} else {
|
||||||
|
300
|
||||||
|
};
|
||||||
|
let chunks = chunk_text(text, Some(max_len));
|
||||||
|
|
||||||
|
let mut wav_cat: Vec<f32> = Vec::new();
|
||||||
|
let mut dur_cat: f32 = 0.0;
|
||||||
|
|
||||||
|
for (i, chunk) in chunks.iter().enumerate() {
|
||||||
|
let (wav, duration) = self.infer(
|
||||||
|
&[chunk.clone()],
|
||||||
|
&[lang.to_string()],
|
||||||
|
style,
|
||||||
|
total_step,
|
||||||
|
speed,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let dur = duration[0];
|
||||||
|
let wav_len = (self.sample_rate as f32 * dur) as usize;
|
||||||
|
let wav_chunk = &wav[..wav_len.min(wav.len())];
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
wav_cat.extend_from_slice(wav_chunk);
|
||||||
|
dur_cat = dur;
|
||||||
|
} else {
|
||||||
|
let silence_len = (silence_duration * self.sample_rate as f32) as usize;
|
||||||
|
let silence = vec![0.0f32; silence_len];
|
||||||
|
|
||||||
|
wav_cat.extend_from_slice(&silence);
|
||||||
|
wav_cat.extend_from_slice(wav_chunk);
|
||||||
|
dur_cat += silence_duration + dur;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((wav_cat, dur_cat))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_noisy_latent(
|
||||||
|
duration: &[f32],
|
||||||
|
sample_rate: i32,
|
||||||
|
base_chunk_size: i32,
|
||||||
|
chunk_compress: i32,
|
||||||
|
latent_dim: i32,
|
||||||
|
) -> (Array3<f32>, Array3<f32>) {
|
||||||
|
let bsz = duration.len();
|
||||||
|
let max_dur = duration.iter().fold(0.0f32, |a, &b| a.max(b));
|
||||||
|
|
||||||
|
let wav_len_max = (max_dur * sample_rate as f32) as usize;
|
||||||
|
let wav_lengths: Vec<usize> = duration
|
||||||
|
.iter()
|
||||||
|
.map(|&d| (d * sample_rate as f32) as usize)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let chunk_size = (base_chunk_size * chunk_compress) as usize;
|
||||||
|
let latent_len = wav_len_max.div_ceil(chunk_size);
|
||||||
|
let latent_dim_val = (latent_dim * chunk_compress) as usize;
|
||||||
|
|
||||||
|
let mut noisy_latent = Array3::<f32>::zeros((bsz, latent_dim_val, latent_len));
|
||||||
|
|
||||||
|
let normal = Normal::new(0.0f32, 1.0).unwrap();
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
|
||||||
|
for b in 0..bsz {
|
||||||
|
for d in 0..latent_dim_val {
|
||||||
|
for t in 0..latent_len {
|
||||||
|
noisy_latent[[b, d, t]] = normal.sample(&mut rng);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let latent_lengths: Vec<usize> = wav_lengths
|
||||||
|
.iter()
|
||||||
|
.map(|&len| len.div_ceil(chunk_size))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let latent_mask = length_to_mask(&latent_lengths, Some(latent_len));
|
||||||
|
|
||||||
|
for b in 0..bsz {
|
||||||
|
for d in 0..latent_dim_val {
|
||||||
|
for t in 0..latent_len {
|
||||||
|
noisy_latent[[b, d, t]] *= latent_mask[[b, 0, t]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(noisy_latent, latent_mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_voice_style(voice_style_paths: &[String], verbose: bool) -> Result<Style> {
|
||||||
|
let bsz = voice_style_paths.len();
|
||||||
|
|
||||||
|
let first_file =
|
||||||
|
File::open(&voice_style_paths[0]).context("Failed to open voice style file")?;
|
||||||
|
let first_reader = BufReader::new(first_file);
|
||||||
|
let first_data: VoiceStyleData = serde_json::from_reader(first_reader)?;
|
||||||
|
|
||||||
|
let ttl_dims = &first_data.style_ttl.dims;
|
||||||
|
let dp_dims = &first_data.style_dp.dims;
|
||||||
|
|
||||||
|
let ttl_dim1 = ttl_dims[1];
|
||||||
|
let ttl_dim2 = ttl_dims[2];
|
||||||
|
let dp_dim1 = dp_dims[1];
|
||||||
|
let dp_dim2 = dp_dims[2];
|
||||||
|
|
||||||
|
let ttl_size = bsz * ttl_dim1 * ttl_dim2;
|
||||||
|
let dp_size = bsz * dp_dim1 * dp_dim2;
|
||||||
|
let mut ttl_flat = vec![0.0f32; ttl_size];
|
||||||
|
let mut dp_flat = vec![0.0f32; dp_size];
|
||||||
|
|
||||||
|
for (i, path) in voice_style_paths.iter().enumerate() {
|
||||||
|
let file = File::open(path).context("Failed to open voice style file")?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let data: VoiceStyleData = serde_json::from_reader(reader)?;
|
||||||
|
|
||||||
|
let ttl_offset = i * ttl_dim1 * ttl_dim2;
|
||||||
|
let mut idx = 0;
|
||||||
|
for batch in &data.style_ttl.data {
|
||||||
|
for row in batch {
|
||||||
|
for &val in row {
|
||||||
|
ttl_flat[ttl_offset + idx] = val;
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dp_offset = i * dp_dim1 * dp_dim2;
|
||||||
|
idx = 0;
|
||||||
|
for batch in &data.style_dp.data {
|
||||||
|
for row in batch {
|
||||||
|
for &val in row {
|
||||||
|
dp_flat[dp_offset + idx] = val;
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let ttl_style = Array3::from_shape_vec((bsz, ttl_dim1, ttl_dim2), ttl_flat)?;
|
||||||
|
let dp_style = Array3::from_shape_vec((bsz, dp_dim1, dp_dim2), dp_flat)?;
|
||||||
|
|
||||||
|
if verbose {
|
||||||
|
tracing::info!("Loaded {} voice styles", bsz);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Style {
|
||||||
|
ttl: ttl_style,
|
||||||
|
dp: dp_style,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "webgpu")]
|
||||||
|
fn load_backend_webgpu(config: &HashMap<String, String>) -> Result<ExecutionProviderDispatch> {
|
||||||
|
let webgpu_device_id = config
|
||||||
|
.get("WEBGPU_DEVICE_ID")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "0".to_string())
|
||||||
|
.parse::<i32>()
|
||||||
|
.inspect_err(|e| tracing::error!("{e}"))?;
|
||||||
|
|
||||||
|
let webgpu = ort::ep::WebGPU::default().with_device_id(webgpu_device_id);
|
||||||
|
|
||||||
|
Ok(webgpu.build())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn load_backend_cuda(config: &HashMap<String, String>) -> Result<ExecutionProviderDispatch> {
|
||||||
|
let cuda_device_id = config
|
||||||
|
.get("CUDA_DEVICE_ID")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "0".to_string())
|
||||||
|
.parse::<i32>()
|
||||||
|
.inspect_err(|e| tracing::error!("{e}"))?;
|
||||||
|
|
||||||
|
let cuda = ort::ep::CUDA::default().with_device_id(cuda_device_id);
|
||||||
|
|
||||||
|
Ok(cuda.build())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_backends(config: &HashMap<String, String>) -> Vec<ExecutionProviderDispatch> {
|
||||||
|
let enabled_backends = config
|
||||||
|
.get("ENABLED_BACKENDS")
|
||||||
|
.map(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.split(",")
|
||||||
|
.map(Into::into)
|
||||||
|
.collect::<Vec<String>>();
|
||||||
|
|
||||||
|
enabled_backends.iter().filter_map(|name| {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
if name == "cuda" {
|
||||||
|
return load_backend_cuda(config)
|
||||||
|
.inspect_err(|err| {
|
||||||
|
tracing::error!("Failed to load backend *{}*: {:?}", name, err);
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "webgpu")]
|
||||||
|
if name == "webgpu" {
|
||||||
|
return load_backend_webgpu(config)
|
||||||
|
.inspect_err(|err| {
|
||||||
|
tracing::error!("Failed to load backend *{}*: {:?}", name, err);
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::error!(
|
||||||
|
"ENABLED_BACKENDS contains {}, but the binary is not compiled with {} backend support.",
|
||||||
|
name,
|
||||||
|
name
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_text_to_speech(onnx_dir: &str) -> Result<TextToSpeech> {
|
||||||
|
let cfgs = load_cfgs(onnx_dir)?;
|
||||||
|
|
||||||
|
let dp_path = format!("{}/duration_predictor.onnx", onnx_dir);
|
||||||
|
let text_enc_path = format!("{}/text_encoder.onnx", onnx_dir);
|
||||||
|
let vector_est_path = format!("{}/vector_estimator.onnx", onnx_dir);
|
||||||
|
let vocoder_path = format!("{}/vocoder.onnx", onnx_dir);
|
||||||
|
|
||||||
|
tracing::info!("Session successfully loaded with Vulkan GPU acceleration!");
|
||||||
|
|
||||||
|
let providers = load_backends(&std::env::vars().collect());
|
||||||
|
|
||||||
|
let dp_ort = Session::builder()?
|
||||||
|
.with_intra_threads(8)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.with_execution_providers(&providers)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.commit_from_file(&dp_path)?;
|
||||||
|
let text_enc_ort = Session::builder()?
|
||||||
|
.with_intra_threads(8)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.with_execution_providers(&providers)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.commit_from_file(&text_enc_path)?;
|
||||||
|
let vector_est_ort = Session::builder()?
|
||||||
|
.with_intra_threads(8)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.with_execution_providers(&providers)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.commit_from_file(&vector_est_path)?;
|
||||||
|
let vocoder_ort = Session::builder()?
|
||||||
|
.with_intra_threads(8)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.with_execution_providers(&providers)
|
||||||
|
.map_err(|e| anyhow!(e.message().to_string()))?
|
||||||
|
.commit_from_file(&vocoder_path)?;
|
||||||
|
|
||||||
|
let unicode_indexer_path = format!("{}/unicode_indexer.json", onnx_dir);
|
||||||
|
let text_processor = UnicodeProcessor::new(&unicode_indexer_path)?;
|
||||||
|
|
||||||
|
Ok(TextToSpeech::new(
|
||||||
|
cfgs,
|
||||||
|
text_processor,
|
||||||
|
dp_ort,
|
||||||
|
text_enc_ort,
|
||||||
|
vector_est_ort,
|
||||||
|
vocoder_ort,
|
||||||
|
))
|
||||||
|
}
|
||||||
8
crates/yaejuyang-supertonic/src/tts/mod.rs
Normal file
8
crates/yaejuyang-supertonic/src/tts/mod.rs
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
pub mod assets;
|
||||||
|
pub mod engine;
|
||||||
|
pub mod pool;
|
||||||
|
pub mod text;
|
||||||
|
pub mod wav;
|
||||||
|
|
||||||
|
pub use engine::{Style, TextToSpeech, load_text_to_speech, load_voice_style};
|
||||||
|
pub use pool::{TtsOpts, TtsPool};
|
||||||
98
crates/yaejuyang-supertonic/src/tts/pool.rs
Normal file
98
crates/yaejuyang-supertonic/src/tts/pool.rs
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use crossbeam_channel::{Sender, bounded};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
use super::engine::{Style, TextToSpeech};
|
||||||
|
use super::wav::wav_bytes;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct TtsOpts {
|
||||||
|
pub total_step: usize,
|
||||||
|
pub speed: f32,
|
||||||
|
pub silence_duration: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TtsJob {
|
||||||
|
pub text: String,
|
||||||
|
pub lang: String,
|
||||||
|
pub reply: oneshot::Sender<Result<Vec<u8>, String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TtsPool {
|
||||||
|
tx: Sender<TtsJob>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TtsPool {
|
||||||
|
pub fn spawn<I>(workers: usize, init: I, opts: TtsOpts) -> Result<Self>
|
||||||
|
where
|
||||||
|
I: Fn(u32) -> Result<(TextToSpeech, Style)> + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let workers = workers.max(1);
|
||||||
|
let (tx, rx) = bounded::<TtsJob>(workers * 4);
|
||||||
|
let init = Arc::new(init);
|
||||||
|
|
||||||
|
for worker_id in 0..workers {
|
||||||
|
let rx = rx.clone();
|
||||||
|
let init = init.clone();
|
||||||
|
thread::Builder::new()
|
||||||
|
.name(format!("supertonic-tts-{worker_id}"))
|
||||||
|
.spawn(move || {
|
||||||
|
let (mut tts, style) = match init(worker_id as u32) {
|
||||||
|
Ok(pair) => pair,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("worker {worker_id} init failed: {e:?}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tracing::info!("supertonic worker {worker_id} ready");
|
||||||
|
|
||||||
|
while let Ok(job) = rx.recv() {
|
||||||
|
let start_at = Instant::now();
|
||||||
|
|
||||||
|
let result = (|| -> Result<Vec<u8>, String> {
|
||||||
|
let (wav, _dur) = tts
|
||||||
|
.synthesize(
|
||||||
|
&job.text,
|
||||||
|
&job.lang,
|
||||||
|
&style,
|
||||||
|
opts.total_step,
|
||||||
|
opts.speed,
|
||||||
|
opts.silence_duration,
|
||||||
|
)
|
||||||
|
.map_err(|e| e.to_string())?;
|
||||||
|
wav_bytes(&wav, tts.sample_rate).map_err(|e| e.to_string())
|
||||||
|
})();
|
||||||
|
tracing::info!("synthesize taken {}ms", start_at.elapsed().as_millis());
|
||||||
|
|
||||||
|
let _ = job.reply.send(result);
|
||||||
|
}
|
||||||
|
tracing::info!("supertonic worker {worker_id} exiting");
|
||||||
|
})
|
||||||
|
.context("failed to spawn TTS worker thread")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(TtsPool { tx })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn synthesize(&self, text: String, lang: String) -> Result<Vec<u8>, String> {
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
let job = TtsJob {
|
||||||
|
text,
|
||||||
|
lang,
|
||||||
|
reply: reply_tx,
|
||||||
|
};
|
||||||
|
|
||||||
|
let tx = self.tx.clone();
|
||||||
|
tokio::task::spawn_blocking(move || tx.send(job))
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("dispatch task join failed: {e}"))?
|
||||||
|
.map_err(|_| "TTS pool channel closed".to_string())?;
|
||||||
|
|
||||||
|
reply_rx
|
||||||
|
.await
|
||||||
|
.map_err(|_| "TTS worker dropped reply channel".to_string())?
|
||||||
|
}
|
||||||
|
}
|
||||||
375
crates/yaejuyang-supertonic/src/tts/text.rs
Normal file
375
crates/yaejuyang-supertonic/src/tts/text.rs
Normal file
|
|
@ -0,0 +1,375 @@
|
||||||
|
use anyhow::{Result, bail};
|
||||||
|
use ndarray::Array3;
|
||||||
|
use regex::Regex;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::path::Path;
|
||||||
|
use unicode_normalization::UnicodeNormalization;
|
||||||
|
|
||||||
|
pub const AVAILABLE_LANGS: &[&str] = &[
|
||||||
|
"en", "ko", "ja", "ar", "bg", "cs", "da", "de", "el", "es", "et", "fi", "fr", "hi", "hr", "hu",
|
||||||
|
"id", "it", "lt", "lv", "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sv", "tr", "uk", "vi", "na",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub fn is_valid_lang(lang: &str) -> bool {
|
||||||
|
AVAILABLE_LANGS.contains(&lang)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct UnicodeProcessor {
|
||||||
|
indexer: Vec<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UnicodeProcessor {
|
||||||
|
pub fn new<P: AsRef<Path>>(unicode_indexer_json_path: P) -> Result<Self> {
|
||||||
|
let file = File::open(unicode_indexer_json_path)?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
let indexer: Vec<i64> = serde_json::from_reader(reader)?;
|
||||||
|
Ok(UnicodeProcessor { indexer })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call(
|
||||||
|
&self,
|
||||||
|
text_list: &[String],
|
||||||
|
lang_list: &[String],
|
||||||
|
) -> Result<(Vec<Vec<i64>>, Array3<f32>)> {
|
||||||
|
let mut processed_texts: Vec<String> = Vec::new();
|
||||||
|
for (text, lang) in text_list.iter().zip(lang_list.iter()) {
|
||||||
|
processed_texts.push(preprocess_text(text, lang)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let text_ids_lengths: Vec<usize> =
|
||||||
|
processed_texts.iter().map(|t| t.chars().count()).collect();
|
||||||
|
|
||||||
|
let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
|
||||||
|
|
||||||
|
let mut text_ids = Vec::new();
|
||||||
|
for text in &processed_texts {
|
||||||
|
let mut row = vec![0i64; max_len];
|
||||||
|
let unicode_vals = text_to_unicode_values(text);
|
||||||
|
for (j, &val) in unicode_vals.iter().enumerate() {
|
||||||
|
if val < self.indexer.len() {
|
||||||
|
row[j] = self.indexer[val];
|
||||||
|
} else {
|
||||||
|
row[j] = -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
text_ids.push(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
let text_mask = get_text_mask(&text_ids_lengths);
|
||||||
|
|
||||||
|
Ok((text_ids, text_mask))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn preprocess_text(text: &str, lang: &str) -> Result<String> {
|
||||||
|
let mut text: String = text.nfkd().collect();
|
||||||
|
|
||||||
|
let emoji_pattern = Regex::new(r"[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+").unwrap();
|
||||||
|
text = emoji_pattern.replace_all(&text, "").to_string();
|
||||||
|
|
||||||
|
let replacements = [
|
||||||
|
("\u{2013}", "-"),
|
||||||
|
("\u{2011}", "-"),
|
||||||
|
("\u{2014}", "-"),
|
||||||
|
("_", " "),
|
||||||
|
("\u{201C}", "\""),
|
||||||
|
("\u{201D}", "\""),
|
||||||
|
("\u{2018}", "'"),
|
||||||
|
("\u{2019}", "'"),
|
||||||
|
("\u{00B4}", "'"),
|
||||||
|
("`", "'"),
|
||||||
|
("[", " "),
|
||||||
|
("]", " "),
|
||||||
|
("|", " "),
|
||||||
|
("/", " "),
|
||||||
|
("#", " "),
|
||||||
|
("\u{2192}", " "),
|
||||||
|
("\u{2190}", " "),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (from, to) in &replacements {
|
||||||
|
text = text.replace(from, to);
|
||||||
|
}
|
||||||
|
|
||||||
|
let special_symbols = ["\u{2665}", "\u{2606}", "\u{2661}", "\u{00A9}", "\\"];
|
||||||
|
for symbol in &special_symbols {
|
||||||
|
text = text.replace(symbol, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
let expr_replacements = [
|
||||||
|
("@", " at "),
|
||||||
|
("e.g.,", "for example, "),
|
||||||
|
("i.e.,", "that is, "),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (from, to) in &expr_replacements {
|
||||||
|
text = text.replace(from, to);
|
||||||
|
}
|
||||||
|
|
||||||
|
text = Regex::new(r" ,")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, ",")
|
||||||
|
.to_string();
|
||||||
|
text = Regex::new(r" \.")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, ".")
|
||||||
|
.to_string();
|
||||||
|
text = Regex::new(r" !")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, "!")
|
||||||
|
.to_string();
|
||||||
|
text = Regex::new(r" \?")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, "?")
|
||||||
|
.to_string();
|
||||||
|
text = Regex::new(r" ;")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, ";")
|
||||||
|
.to_string();
|
||||||
|
text = Regex::new(r" :")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, ":")
|
||||||
|
.to_string();
|
||||||
|
text = Regex::new(r" '")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, "'")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
while text.contains("\"\"") {
|
||||||
|
text = text.replace("\"\"", "\"");
|
||||||
|
}
|
||||||
|
while text.contains("''") {
|
||||||
|
text = text.replace("''", "'");
|
||||||
|
}
|
||||||
|
while text.contains("``") {
|
||||||
|
text = text.replace("``", "`");
|
||||||
|
}
|
||||||
|
|
||||||
|
text = Regex::new(r"\s+")
|
||||||
|
.unwrap()
|
||||||
|
.replace_all(&text, " ")
|
||||||
|
.to_string();
|
||||||
|
text = text.trim().to_string();
|
||||||
|
|
||||||
|
if !text.is_empty() {
|
||||||
|
let ends_with_punct =
|
||||||
|
Regex::new(r#"[.!?;:,'"\u{201C}\u{201D}\u{2018}\u{2019})\]}\u{2026}\u{3002}\u{300D}\u{300F}\u{3011}\u{3009}\u{300B}\u{203A}\u{00BB}]$"#).unwrap();
|
||||||
|
if !ends_with_punct.is_match(&text) {
|
||||||
|
text.push('.');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !is_valid_lang(lang) {
|
||||||
|
bail!(
|
||||||
|
"Invalid language: {}. Available: {:?}",
|
||||||
|
lang,
|
||||||
|
AVAILABLE_LANGS
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
text = format!("<{}>{}</{}>", lang, text, lang);
|
||||||
|
|
||||||
|
Ok(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn text_to_unicode_values(text: &str) -> Vec<usize> {
|
||||||
|
text.chars().map(|c| c as usize).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn length_to_mask(lengths: &[usize], max_len: Option<usize>) -> Array3<f32> {
|
||||||
|
let bsz = lengths.len();
|
||||||
|
let max_len = max_len.unwrap_or_else(|| *lengths.iter().max().unwrap_or(&0));
|
||||||
|
|
||||||
|
let mut mask = Array3::<f32>::zeros((bsz, 1, max_len));
|
||||||
|
for (i, &len) in lengths.iter().enumerate() {
|
||||||
|
for j in 0..len.min(max_len) {
|
||||||
|
mask[[i, 0, j]] = 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mask
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_text_mask(text_ids_lengths: &[usize]) -> Array3<f32> {
|
||||||
|
let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
|
||||||
|
length_to_mask(text_ids_lengths, Some(max_len))
|
||||||
|
}
|
||||||
|
|
||||||
|
const MAX_CHUNK_LENGTH: usize = 300;
|
||||||
|
|
||||||
|
const ABBREVIATIONS: &[&str] = &[
|
||||||
|
"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.", "St.", "Ave.", "Rd.", "Blvd.", "Dept.",
|
||||||
|
"Inc.", "Ltd.", "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub fn chunk_text(text: &str, max_len: Option<usize>) -> Vec<String> {
|
||||||
|
let max_len = max_len.unwrap_or(MAX_CHUNK_LENGTH);
|
||||||
|
let text = text.trim();
|
||||||
|
|
||||||
|
if text.is_empty() {
|
||||||
|
return vec![String::new()];
|
||||||
|
}
|
||||||
|
|
||||||
|
let para_re = Regex::new(r"\n\s*\n").unwrap();
|
||||||
|
let paragraphs: Vec<&str> = para_re.split(text).collect();
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
|
||||||
|
for para in paragraphs {
|
||||||
|
let para = para.trim();
|
||||||
|
if para.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if para.len() <= max_len {
|
||||||
|
chunks.push(para.to_string());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let sentences = split_sentences(para);
|
||||||
|
let mut current = String::new();
|
||||||
|
let mut current_len = 0;
|
||||||
|
|
||||||
|
for sentence in sentences {
|
||||||
|
let sentence = sentence.trim();
|
||||||
|
if sentence.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let sentence_len = sentence.len();
|
||||||
|
if sentence_len > max_len {
|
||||||
|
if !current.is_empty() {
|
||||||
|
chunks.push(current.trim().to_string());
|
||||||
|
current.clear();
|
||||||
|
current_len = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let parts: Vec<&str> = sentence.split(',').collect();
|
||||||
|
for part in parts {
|
||||||
|
let part = part.trim();
|
||||||
|
if part.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let part_len = part.len();
|
||||||
|
if part_len > max_len {
|
||||||
|
let words: Vec<&str> = part.split_whitespace().collect();
|
||||||
|
let mut word_chunk = String::new();
|
||||||
|
let mut word_chunk_len = 0;
|
||||||
|
|
||||||
|
for word in words {
|
||||||
|
let word_len = word.len();
|
||||||
|
if word_chunk_len + word_len + 1 > max_len && !word_chunk.is_empty() {
|
||||||
|
chunks.push(word_chunk.trim().to_string());
|
||||||
|
word_chunk.clear();
|
||||||
|
word_chunk_len = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !word_chunk.is_empty() {
|
||||||
|
word_chunk.push(' ');
|
||||||
|
word_chunk_len += 1;
|
||||||
|
}
|
||||||
|
word_chunk.push_str(word);
|
||||||
|
word_chunk_len += word_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !word_chunk.is_empty() {
|
||||||
|
chunks.push(word_chunk.trim().to_string());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if current_len + part_len + 1 > max_len && !current.is_empty() {
|
||||||
|
chunks.push(current.trim().to_string());
|
||||||
|
current.clear();
|
||||||
|
current_len = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.is_empty() {
|
||||||
|
current.push_str(", ");
|
||||||
|
current_len += 2;
|
||||||
|
}
|
||||||
|
current.push_str(part);
|
||||||
|
current_len += part_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if current_len + sentence_len + 1 > max_len && !current.is_empty() {
|
||||||
|
chunks.push(current.trim().to_string());
|
||||||
|
current.clear();
|
||||||
|
current_len = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.is_empty() {
|
||||||
|
current.push(' ');
|
||||||
|
current_len += 1;
|
||||||
|
}
|
||||||
|
current.push_str(sentence);
|
||||||
|
current_len += sentence_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.is_empty() {
|
||||||
|
chunks.push(current.trim().to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if chunks.is_empty() {
|
||||||
|
vec![String::new()]
|
||||||
|
} else {
|
||||||
|
chunks
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split_sentences(text: &str) -> Vec<String> {
|
||||||
|
let re = Regex::new(r"([.!?])\s+").unwrap();
|
||||||
|
|
||||||
|
let matches: Vec<_> = re.find_iter(text).collect();
|
||||||
|
if matches.is_empty() {
|
||||||
|
return vec![text.to_string()];
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sentences = Vec::new();
|
||||||
|
let mut last_end = 0;
|
||||||
|
|
||||||
|
for m in matches {
|
||||||
|
let before_punc = &text[last_end..m.start()];
|
||||||
|
|
||||||
|
let mut is_abbrev = false;
|
||||||
|
for abbrev in ABBREVIATIONS {
|
||||||
|
let combined = format!("{}{}", before_punc.trim(), &text[m.start()..m.start() + 1]);
|
||||||
|
if combined.ends_with(abbrev) {
|
||||||
|
is_abbrev = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !is_abbrev {
|
||||||
|
sentences.push(text[last_end..m.end()].to_string());
|
||||||
|
last_end = m.end();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if last_end < text.len() {
|
||||||
|
sentences.push(text[last_end..].to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if sentences.is_empty() {
|
||||||
|
vec![text.to_string()]
|
||||||
|
} else {
|
||||||
|
sentences
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct VoiceStyleData {
|
||||||
|
pub style_ttl: StyleComponent,
|
||||||
|
pub style_dp: StyleComponent,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct StyleComponent {
|
||||||
|
pub data: Vec<Vec<Vec<f32>>>,
|
||||||
|
pub dims: Vec<usize>,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub _dtype: String,
|
||||||
|
}
|
||||||
25
crates/yaejuyang-supertonic/src/tts/wav.rs
Normal file
25
crates/yaejuyang-supertonic/src/tts/wav.rs
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use hound::{SampleFormat, WavSpec, WavWriter};
|
||||||
|
use std::io::Cursor;
|
||||||
|
|
||||||
|
pub fn wav_bytes(audio: &[f32], sample_rate: i32) -> Result<Vec<u8>> {
|
||||||
|
let spec = WavSpec {
|
||||||
|
channels: 1,
|
||||||
|
sample_rate: sample_rate as u32,
|
||||||
|
bits_per_sample: 16,
|
||||||
|
sample_format: SampleFormat::Int,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut buf = Cursor::new(Vec::<u8>::new());
|
||||||
|
{
|
||||||
|
let mut writer = WavWriter::new(&mut buf, spec)?;
|
||||||
|
for &sample in audio {
|
||||||
|
let clamped = sample.clamp(-1.0, 1.0);
|
||||||
|
let val = (clamped * 32767.0) as i16;
|
||||||
|
writer.write_sample(val)?;
|
||||||
|
}
|
||||||
|
writer.finalize()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(buf.into_inner())
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue