add supertonic

This commit is contained in:
kimpure 2026-05-19 14:37:27 +00:00
parent a69c90c822
commit 69ec38d16b
No known key found for this signature in database
13 changed files with 2850 additions and 0 deletions

3
.gitignore vendored
View file

@ -9,3 +9,6 @@ packages/generated
cache
db
docker-compose.yml
target/
test.wav

1681
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

3
Cargo.toml Normal file
View file

@ -0,0 +1,3 @@
[workspace]
resolver = "2"
members = ["crates/yaejuyang-supertonic"]

View file

@ -0,0 +1,2 @@
assets/
.env

View file

@ -0,0 +1,30 @@
[package]
name = "yaejuyang-supertonic"
version = "0.1.0"
edition = "2024"
[features]
default = ["webgpu"]
webgpu = [ "ort/webgpu" ]
cuda = [ "ort/cuda" ]
[dependencies]
anyhow = "1.0.102"
axum = "0.8.9"
clap = "4.6.1"
crossbeam-channel = "0.5.15"
dotenvy = "0.15.7"
hound = "3.5.1"
ndarray = "0.17.2"
ort = "2.0.0-rc.12"
qwreey-utility-rs = "0.1.9"
rand = "0.10.1"
chacha20 = { version = "0.10.0-rc.5" }
rand_distr = "0.6.0"
regex = "1.12.3"
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
tokio = { version = "1.52.3", features = ["full"] }
tracing = "0.1.44"
tracing-subscriber = "0.3.23"
unicode-normalization = "0.1.25"

View file

@ -0,0 +1,19 @@
use std::sync::Arc;
use axum::{Json, extract::State};
use crate::tts::TtsPool;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct Body {
text: String,
lang: String,
}
pub async fn handler(
State(state): State<Arc<TtsPool>>,
Json(payload): Json<Body>,
) -> Result<Vec<u8>, String> {
Ok(state.synthesize(payload.text, payload.lang).await?)
}

View file

@ -0,0 +1,69 @@
use axum::{Router, routing::post};
use std::path::PathBuf;
use std::sync::Arc;
pub mod api;
pub mod tts;
use tts::{TtsOpts, TtsPool, load_text_to_speech, load_voice_style};
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
dotenvy::dotenv().ok();
tracing_subscriber::fmt::init();
let model_dir = std::env::var("SUPERTONIC_MODEL_DIR")
.unwrap_or_else(|_| "./assets/supertonic-3".to_string());
let voice_style_path = std::env::var("SUPERTONIC_VOICE_STYLE")
.unwrap_or_else(|_| format!("{model_dir}/voice_styles/M1.json"));
let lang = std::env::var("SUPERTONIC_LANG").unwrap_or_else(|_| "en".to_string());
let total_step: usize = std::env::var("SUPERTONIC_TOTAL_STEP")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(8);
let speed: f32 = std::env::var("SUPERTONIC_SPEED")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(1.05);
let silence_duration: f32 = std::env::var("SUPERTONIC_SILENCE_DUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0.3);
let workers: usize = std::env::var("SUPERTONIC_WORKERS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(2);
let hf_repo = std::env::var("SUPERTONIC_HF_REPO")
.unwrap_or_else(|_| "https://huggingface.co/Supertone/supertonic-3".to_string());
let model_path = PathBuf::from(&model_dir);
tts::assets::ensure_assets(&model_path, &hf_repo)?;
let onnx_dir_for_init = model_path.join("onnx").to_string_lossy().into_owned();
let voice_style_for_init = voice_style_path.clone();
let pool = Arc::new(TtsPool::spawn(
workers,
move |id| {
let span = tracing::info_span!("worker", worker_id = id);
let _enter = span.enter();
let tts = load_text_to_speech(&onnx_dir_for_init)?;
let style = load_voice_style(std::slice::from_ref(&voice_style_for_init), false)?;
Ok((tts, style))
},
TtsOpts {
total_step,
speed,
silence_duration,
},
)?);
let app = Router::new()
.route("/", post(api::handler))
.with_state(pool);
let addr = std::env::var("ADDR").unwrap_or_else(|_| "0.0.0.0:80".to_string());
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}

View file

@ -0,0 +1,77 @@
use anyhow::{Context, Result, bail};
use std::path::Path;
use std::process::Command;
/// Ensure the Supertonic model directory is present.
///
/// `model_dir` is the unpacked HuggingFace repo root — it should contain
/// `onnx/tts.json` and `voice_styles/*.json` after the clone. If
/// `model_dir/onnx/tts.json` already exists, this is a no-op (typical inside
/// containers that bake assets in at build time). Otherwise the HF repo is
/// cloned into `model_dir`, which must not already exist or must be empty.
///
/// Requires `git` and `git-lfs` on PATH.
pub fn ensure_assets(model_dir: &Path, hf_repo: &str) -> Result<()> {
if model_dir.join("onnx").join("tts.json").exists() {
tracing::info!(
model_dir = %model_dir.display(),
"supertonic assets already present"
);
return Ok(());
}
if model_dir.exists() {
let is_empty = std::fs::read_dir(model_dir)
.with_context(|| format!("failed to read {}", model_dir.display()))?
.next()
.is_none();
if !is_empty {
bail!(
"SUPERTONIC_MODEL_DIR={} already exists and is not empty but does not \
contain onnx/tts.json delete it or point SUPERTONIC_MODEL_DIR \
somewhere fresh",
model_dir.display()
);
}
} else if let Some(parent) = model_dir.parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent)
.with_context(|| format!("failed to create {}", parent.display()))?;
}
}
tracing::info!(
repo = hf_repo,
target = %model_dir.display(),
"cloning supertonic assets from HuggingFace"
);
let status = Command::new("git")
.args(["clone", "--depth=1", hf_repo])
.arg(model_dir)
.status()
.context("failed to invoke git — is git installed?")?;
if !status.success() {
bail!("git clone of {hf_repo} failed with status {status}");
}
let lfs_status = Command::new("git")
.args(["-C"])
.arg(model_dir)
.args(["lfs", "pull"])
.status()
.context("failed to invoke git lfs — is git-lfs installed?")?;
if !lfs_status.success() {
bail!("git lfs pull failed with status {lfs_status}");
}
if !model_dir.join("onnx").join("tts.json").exists() {
bail!(
"expected {} to exist after clone — the HF repo layout may have changed",
model_dir.join("onnx").join("tts.json").display()
);
}
tracing::info!("supertonic assets ready");
Ok(())
}

View file

@ -0,0 +1,460 @@
use super::text::{UnicodeProcessor, VoiceStyleData, chunk_text, length_to_mask};
use anyhow::{Context, Result, anyhow};
use ndarray::{Array, Array3};
use ort::value::Value;
use ort::{ep::ExecutionProviderDispatch, session::Session};
use rand_distr::{Distribution, Normal};
use serde::Deserialize;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
pub ae: AEConfig,
pub ttl: TTLConfig,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AEConfig {
pub sample_rate: i32,
pub base_chunk_size: i32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct TTLConfig {
pub chunk_compress_factor: i32,
pub latent_dim: i32,
}
pub fn load_cfgs<P: AsRef<Path>>(onnx_dir: P) -> Result<Config> {
let cfg_path = onnx_dir.as_ref().join("tts.json");
let file =
File::open(&cfg_path).with_context(|| format!("failed to open {}", cfg_path.display()))?;
let reader = BufReader::new(file);
let cfgs: Config = serde_json::from_reader(reader)?;
Ok(cfgs)
}
pub struct Style {
pub ttl: Array3<f32>,
pub dp: Array3<f32>,
}
pub struct TextToSpeech {
cfgs: Config,
text_processor: UnicodeProcessor,
dp_ort: Session,
text_enc_ort: Session,
vector_est_ort: Session,
vocoder_ort: Session,
pub sample_rate: i32,
}
impl TextToSpeech {
pub fn new(
cfgs: Config,
text_processor: UnicodeProcessor,
dp_ort: Session,
text_enc_ort: Session,
vector_est_ort: Session,
vocoder_ort: Session,
) -> Self {
let sample_rate = cfgs.ae.sample_rate;
TextToSpeech {
cfgs,
text_processor,
dp_ort,
text_enc_ort,
vector_est_ort,
vocoder_ort,
sample_rate,
}
}
fn infer(
&mut self,
text_list: &[String],
lang_list: &[String],
style: &Style,
total_step: usize,
speed: f32,
) -> Result<(Vec<f32>, Vec<f32>)> {
let bsz = text_list.len();
let (text_ids, text_mask) = self.text_processor.call(text_list, lang_list)?;
let text_ids_array = {
let text_ids_shape = (bsz, text_ids[0].len());
let mut flat = Vec::new();
for row in &text_ids {
flat.extend_from_slice(row);
}
Array::from_shape_vec(text_ids_shape, flat)?
};
let text_ids_value = Value::from_array(text_ids_array)?;
let text_mask_value = Value::from_array(text_mask.clone())?;
let style_dp_value = Value::from_array(style.dp.clone())?;
let dp_outputs = self.dp_ort.run(ort::inputs! {
"text_ids" => &text_ids_value,
"style_dp" => &style_dp_value,
"text_mask" => &text_mask_value
})?;
let (_, duration_data) = dp_outputs["duration"].try_extract_tensor::<f32>()?;
let mut duration: Vec<f32> = duration_data.to_vec();
for dur in duration.iter_mut() {
*dur /= speed;
}
let style_ttl_value = Value::from_array(style.ttl.clone())?;
let text_enc_outputs = self.text_enc_ort.run(ort::inputs! {
"text_ids" => &text_ids_value,
"style_ttl" => &style_ttl_value,
"text_mask" => &text_mask_value
})?;
let (text_emb_shape, text_emb_data) =
text_enc_outputs["text_emb"].try_extract_tensor::<f32>()?;
let text_emb = Array3::from_shape_vec(
(
text_emb_shape[0] as usize,
text_emb_shape[1] as usize,
text_emb_shape[2] as usize,
),
text_emb_data.to_vec(),
)?;
let (mut xt, latent_mask) = sample_noisy_latent(
&duration,
self.sample_rate,
self.cfgs.ae.base_chunk_size,
self.cfgs.ttl.chunk_compress_factor,
self.cfgs.ttl.latent_dim,
);
let total_step_array = Array::from_elem(bsz, total_step as f32);
for step in 0..total_step {
let current_step_array = Array::from_elem(bsz, step as f32);
let xt_value = Value::from_array(xt.clone())?;
let text_emb_value = Value::from_array(text_emb.clone())?;
let latent_mask_value = Value::from_array(latent_mask.clone())?;
let text_mask_value2 = Value::from_array(text_mask.clone())?;
let current_step_value = Value::from_array(current_step_array)?;
let total_step_value = Value::from_array(total_step_array.clone())?;
let vector_est_outputs = self.vector_est_ort.run(ort::inputs! {
"noisy_latent" => &xt_value,
"text_emb" => &text_emb_value,
"style_ttl" => &style_ttl_value,
"latent_mask" => &latent_mask_value,
"text_mask" => &text_mask_value2,
"current_step" => &current_step_value,
"total_step" => &total_step_value
})?;
let (denoised_shape, denoised_data) =
vector_est_outputs["denoised_latent"].try_extract_tensor::<f32>()?;
xt = Array3::from_shape_vec(
(
denoised_shape[0] as usize,
denoised_shape[1] as usize,
denoised_shape[2] as usize,
),
denoised_data.to_vec(),
)?;
}
let final_latent_value = Value::from_array(xt)?;
let vocoder_outputs = self.vocoder_ort.run(ort::inputs! {
"latent" => &final_latent_value
})?;
let (_, wav_data) = vocoder_outputs["wav_tts"].try_extract_tensor::<f32>()?;
let wav: Vec<f32> = wav_data.to_vec();
Ok((wav, duration))
}
pub fn synthesize(
&mut self,
text: &str,
lang: &str,
style: &Style,
total_step: usize,
speed: f32,
silence_duration: f32,
) -> Result<(Vec<f32>, f32)> {
let max_len = if lang == "ko" || lang == "ja" {
120
} else {
300
};
let chunks = chunk_text(text, Some(max_len));
let mut wav_cat: Vec<f32> = Vec::new();
let mut dur_cat: f32 = 0.0;
for (i, chunk) in chunks.iter().enumerate() {
let (wav, duration) = self.infer(
&[chunk.clone()],
&[lang.to_string()],
style,
total_step,
speed,
)?;
let dur = duration[0];
let wav_len = (self.sample_rate as f32 * dur) as usize;
let wav_chunk = &wav[..wav_len.min(wav.len())];
if i == 0 {
wav_cat.extend_from_slice(wav_chunk);
dur_cat = dur;
} else {
let silence_len = (silence_duration * self.sample_rate as f32) as usize;
let silence = vec![0.0f32; silence_len];
wav_cat.extend_from_slice(&silence);
wav_cat.extend_from_slice(wav_chunk);
dur_cat += silence_duration + dur;
}
}
Ok((wav_cat, dur_cat))
}
}
fn sample_noisy_latent(
duration: &[f32],
sample_rate: i32,
base_chunk_size: i32,
chunk_compress: i32,
latent_dim: i32,
) -> (Array3<f32>, Array3<f32>) {
let bsz = duration.len();
let max_dur = duration.iter().fold(0.0f32, |a, &b| a.max(b));
let wav_len_max = (max_dur * sample_rate as f32) as usize;
let wav_lengths: Vec<usize> = duration
.iter()
.map(|&d| (d * sample_rate as f32) as usize)
.collect();
let chunk_size = (base_chunk_size * chunk_compress) as usize;
let latent_len = wav_len_max.div_ceil(chunk_size);
let latent_dim_val = (latent_dim * chunk_compress) as usize;
let mut noisy_latent = Array3::<f32>::zeros((bsz, latent_dim_val, latent_len));
let normal = Normal::new(0.0f32, 1.0).unwrap();
let mut rng = rand::rng();
for b in 0..bsz {
for d in 0..latent_dim_val {
for t in 0..latent_len {
noisy_latent[[b, d, t]] = normal.sample(&mut rng);
}
}
}
let latent_lengths: Vec<usize> = wav_lengths
.iter()
.map(|&len| len.div_ceil(chunk_size))
.collect();
let latent_mask = length_to_mask(&latent_lengths, Some(latent_len));
for b in 0..bsz {
for d in 0..latent_dim_val {
for t in 0..latent_len {
noisy_latent[[b, d, t]] *= latent_mask[[b, 0, t]];
}
}
}
(noisy_latent, latent_mask)
}
pub fn load_voice_style(voice_style_paths: &[String], verbose: bool) -> Result<Style> {
let bsz = voice_style_paths.len();
let first_file =
File::open(&voice_style_paths[0]).context("Failed to open voice style file")?;
let first_reader = BufReader::new(first_file);
let first_data: VoiceStyleData = serde_json::from_reader(first_reader)?;
let ttl_dims = &first_data.style_ttl.dims;
let dp_dims = &first_data.style_dp.dims;
let ttl_dim1 = ttl_dims[1];
let ttl_dim2 = ttl_dims[2];
let dp_dim1 = dp_dims[1];
let dp_dim2 = dp_dims[2];
let ttl_size = bsz * ttl_dim1 * ttl_dim2;
let dp_size = bsz * dp_dim1 * dp_dim2;
let mut ttl_flat = vec![0.0f32; ttl_size];
let mut dp_flat = vec![0.0f32; dp_size];
for (i, path) in voice_style_paths.iter().enumerate() {
let file = File::open(path).context("Failed to open voice style file")?;
let reader = BufReader::new(file);
let data: VoiceStyleData = serde_json::from_reader(reader)?;
let ttl_offset = i * ttl_dim1 * ttl_dim2;
let mut idx = 0;
for batch in &data.style_ttl.data {
for row in batch {
for &val in row {
ttl_flat[ttl_offset + idx] = val;
idx += 1;
}
}
}
let dp_offset = i * dp_dim1 * dp_dim2;
idx = 0;
for batch in &data.style_dp.data {
for row in batch {
for &val in row {
dp_flat[dp_offset + idx] = val;
idx += 1;
}
}
}
}
let ttl_style = Array3::from_shape_vec((bsz, ttl_dim1, ttl_dim2), ttl_flat)?;
let dp_style = Array3::from_shape_vec((bsz, dp_dim1, dp_dim2), dp_flat)?;
if verbose {
tracing::info!("Loaded {} voice styles", bsz);
}
Ok(Style {
ttl: ttl_style,
dp: dp_style,
})
}
#[cfg(feature = "webgpu")]
fn load_backend_webgpu(config: &HashMap<String, String>) -> Result<ExecutionProviderDispatch> {
let webgpu_device_id = config
.get("WEBGPU_DEVICE_ID")
.cloned()
.unwrap_or_else(|| "0".to_string())
.parse::<i32>()
.inspect_err(|e| tracing::error!("{e}"))?;
let webgpu = ort::ep::WebGPU::default().with_device_id(webgpu_device_id);
Ok(webgpu.build())
}
#[cfg(feature = "cuda")]
fn load_backend_cuda(config: &HashMap<String, String>) -> Result<ExecutionProviderDispatch> {
let cuda_device_id = config
.get("CUDA_DEVICE_ID")
.cloned()
.unwrap_or_else(|| "0".to_string())
.parse::<i32>()
.inspect_err(|e| tracing::error!("{e}"))?;
let cuda = ort::ep::CUDA::default().with_device_id(cuda_device_id);
Ok(cuda.build())
}
fn load_backends(config: &HashMap<String, String>) -> Vec<ExecutionProviderDispatch> {
let enabled_backends = config
.get("ENABLED_BACKENDS")
.map(|v| v.as_str())
.unwrap_or("")
.split(",")
.map(Into::into)
.collect::<Vec<String>>();
enabled_backends.iter().filter_map(|name| {
#[cfg(feature = "cuda")]
if name == "cuda" {
return load_backend_cuda(config)
.inspect_err(|err| {
tracing::error!("Failed to load backend *{}*: {:?}", name, err);
})
.ok();
}
#[cfg(feature = "webgpu")]
if name == "webgpu" {
return load_backend_webgpu(config)
.inspect_err(|err| {
tracing::error!("Failed to load backend *{}*: {:?}", name, err);
})
.ok();
}
tracing::error!(
"ENABLED_BACKENDS contains {}, but the binary is not compiled with {} backend support.",
name,
name
);
None
}).collect()
}
pub fn load_text_to_speech(onnx_dir: &str) -> Result<TextToSpeech> {
let cfgs = load_cfgs(onnx_dir)?;
let dp_path = format!("{}/duration_predictor.onnx", onnx_dir);
let text_enc_path = format!("{}/text_encoder.onnx", onnx_dir);
let vector_est_path = format!("{}/vector_estimator.onnx", onnx_dir);
let vocoder_path = format!("{}/vocoder.onnx", onnx_dir);
tracing::info!("Session successfully loaded with Vulkan GPU acceleration!");
let providers = load_backends(&std::env::vars().collect());
let dp_ort = Session::builder()?
.with_intra_threads(8)
.map_err(|e| anyhow!(e.message().to_string()))?
.with_execution_providers(&providers)
.map_err(|e| anyhow!(e.message().to_string()))?
.commit_from_file(&dp_path)?;
let text_enc_ort = Session::builder()?
.with_intra_threads(8)
.map_err(|e| anyhow!(e.message().to_string()))?
.with_execution_providers(&providers)
.map_err(|e| anyhow!(e.message().to_string()))?
.commit_from_file(&text_enc_path)?;
let vector_est_ort = Session::builder()?
.with_intra_threads(8)
.map_err(|e| anyhow!(e.message().to_string()))?
.with_execution_providers(&providers)
.map_err(|e| anyhow!(e.message().to_string()))?
.commit_from_file(&vector_est_path)?;
let vocoder_ort = Session::builder()?
.with_intra_threads(8)
.map_err(|e| anyhow!(e.message().to_string()))?
.with_execution_providers(&providers)
.map_err(|e| anyhow!(e.message().to_string()))?
.commit_from_file(&vocoder_path)?;
let unicode_indexer_path = format!("{}/unicode_indexer.json", onnx_dir);
let text_processor = UnicodeProcessor::new(&unicode_indexer_path)?;
Ok(TextToSpeech::new(
cfgs,
text_processor,
dp_ort,
text_enc_ort,
vector_est_ort,
vocoder_ort,
))
}

View file

@ -0,0 +1,8 @@
pub mod assets;
pub mod engine;
pub mod pool;
pub mod text;
pub mod wav;
pub use engine::{Style, TextToSpeech, load_text_to_speech, load_voice_style};
pub use pool::{TtsOpts, TtsPool};

View file

@ -0,0 +1,98 @@
use anyhow::{Context, Result};
use crossbeam_channel::{Sender, bounded};
use std::sync::Arc;
use std::thread;
use std::time::Instant;
use tokio::sync::oneshot;
use super::engine::{Style, TextToSpeech};
use super::wav::wav_bytes;
#[derive(Clone, Copy, Debug)]
pub struct TtsOpts {
pub total_step: usize,
pub speed: f32,
pub silence_duration: f32,
}
pub struct TtsJob {
pub text: String,
pub lang: String,
pub reply: oneshot::Sender<Result<Vec<u8>, String>>,
}
pub struct TtsPool {
tx: Sender<TtsJob>,
}
impl TtsPool {
pub fn spawn<I>(workers: usize, init: I, opts: TtsOpts) -> Result<Self>
where
I: Fn(u32) -> Result<(TextToSpeech, Style)> + Send + Sync + 'static,
{
let workers = workers.max(1);
let (tx, rx) = bounded::<TtsJob>(workers * 4);
let init = Arc::new(init);
for worker_id in 0..workers {
let rx = rx.clone();
let init = init.clone();
thread::Builder::new()
.name(format!("supertonic-tts-{worker_id}"))
.spawn(move || {
let (mut tts, style) = match init(worker_id as u32) {
Ok(pair) => pair,
Err(e) => {
tracing::error!("worker {worker_id} init failed: {e:?}");
return;
}
};
tracing::info!("supertonic worker {worker_id} ready");
while let Ok(job) = rx.recv() {
let start_at = Instant::now();
let result = (|| -> Result<Vec<u8>, String> {
let (wav, _dur) = tts
.synthesize(
&job.text,
&job.lang,
&style,
opts.total_step,
opts.speed,
opts.silence_duration,
)
.map_err(|e| e.to_string())?;
wav_bytes(&wav, tts.sample_rate).map_err(|e| e.to_string())
})();
tracing::info!("synthesize taken {}ms", start_at.elapsed().as_millis());
let _ = job.reply.send(result);
}
tracing::info!("supertonic worker {worker_id} exiting");
})
.context("failed to spawn TTS worker thread")?;
}
Ok(TtsPool { tx })
}
pub async fn synthesize(&self, text: String, lang: String) -> Result<Vec<u8>, String> {
let (reply_tx, reply_rx) = oneshot::channel();
let job = TtsJob {
text,
lang,
reply: reply_tx,
};
let tx = self.tx.clone();
tokio::task::spawn_blocking(move || tx.send(job))
.await
.map_err(|e| format!("dispatch task join failed: {e}"))?
.map_err(|_| "TTS pool channel closed".to_string())?;
reply_rx
.await
.map_err(|_| "TTS worker dropped reply channel".to_string())?
}
}

View file

@ -0,0 +1,375 @@
use anyhow::{Result, bail};
use ndarray::Array3;
use regex::Regex;
use serde::Deserialize;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use unicode_normalization::UnicodeNormalization;
pub const AVAILABLE_LANGS: &[&str] = &[
"en", "ko", "ja", "ar", "bg", "cs", "da", "de", "el", "es", "et", "fi", "fr", "hi", "hr", "hu",
"id", "it", "lt", "lv", "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sv", "tr", "uk", "vi", "na",
];
pub fn is_valid_lang(lang: &str) -> bool {
AVAILABLE_LANGS.contains(&lang)
}
pub struct UnicodeProcessor {
indexer: Vec<i64>,
}
impl UnicodeProcessor {
pub fn new<P: AsRef<Path>>(unicode_indexer_json_path: P) -> Result<Self> {
let file = File::open(unicode_indexer_json_path)?;
let reader = BufReader::new(file);
let indexer: Vec<i64> = serde_json::from_reader(reader)?;
Ok(UnicodeProcessor { indexer })
}
pub fn call(
&self,
text_list: &[String],
lang_list: &[String],
) -> Result<(Vec<Vec<i64>>, Array3<f32>)> {
let mut processed_texts: Vec<String> = Vec::new();
for (text, lang) in text_list.iter().zip(lang_list.iter()) {
processed_texts.push(preprocess_text(text, lang)?);
}
let text_ids_lengths: Vec<usize> =
processed_texts.iter().map(|t| t.chars().count()).collect();
let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
let mut text_ids = Vec::new();
for text in &processed_texts {
let mut row = vec![0i64; max_len];
let unicode_vals = text_to_unicode_values(text);
for (j, &val) in unicode_vals.iter().enumerate() {
if val < self.indexer.len() {
row[j] = self.indexer[val];
} else {
row[j] = -1;
}
}
text_ids.push(row);
}
let text_mask = get_text_mask(&text_ids_lengths);
Ok((text_ids, text_mask))
}
}
pub fn preprocess_text(text: &str, lang: &str) -> Result<String> {
let mut text: String = text.nfkd().collect();
let emoji_pattern = Regex::new(r"[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+").unwrap();
text = emoji_pattern.replace_all(&text, "").to_string();
let replacements = [
("\u{2013}", "-"),
("\u{2011}", "-"),
("\u{2014}", "-"),
("_", " "),
("\u{201C}", "\""),
("\u{201D}", "\""),
("\u{2018}", "'"),
("\u{2019}", "'"),
("\u{00B4}", "'"),
("`", "'"),
("[", " "),
("]", " "),
("|", " "),
("/", " "),
("#", " "),
("\u{2192}", " "),
("\u{2190}", " "),
];
for (from, to) in &replacements {
text = text.replace(from, to);
}
let special_symbols = ["\u{2665}", "\u{2606}", "\u{2661}", "\u{00A9}", "\\"];
for symbol in &special_symbols {
text = text.replace(symbol, "");
}
let expr_replacements = [
("@", " at "),
("e.g.,", "for example, "),
("i.e.,", "that is, "),
];
for (from, to) in &expr_replacements {
text = text.replace(from, to);
}
text = Regex::new(r" ,")
.unwrap()
.replace_all(&text, ",")
.to_string();
text = Regex::new(r" \.")
.unwrap()
.replace_all(&text, ".")
.to_string();
text = Regex::new(r" !")
.unwrap()
.replace_all(&text, "!")
.to_string();
text = Regex::new(r" \?")
.unwrap()
.replace_all(&text, "?")
.to_string();
text = Regex::new(r" ;")
.unwrap()
.replace_all(&text, ";")
.to_string();
text = Regex::new(r" :")
.unwrap()
.replace_all(&text, ":")
.to_string();
text = Regex::new(r" '")
.unwrap()
.replace_all(&text, "'")
.to_string();
while text.contains("\"\"") {
text = text.replace("\"\"", "\"");
}
while text.contains("''") {
text = text.replace("''", "'");
}
while text.contains("``") {
text = text.replace("``", "`");
}
text = Regex::new(r"\s+")
.unwrap()
.replace_all(&text, " ")
.to_string();
text = text.trim().to_string();
if !text.is_empty() {
let ends_with_punct =
Regex::new(r#"[.!?;:,'"\u{201C}\u{201D}\u{2018}\u{2019})\]}\u{2026}\u{3002}\u{300D}\u{300F}\u{3011}\u{3009}\u{300B}\u{203A}\u{00BB}]$"#).unwrap();
if !ends_with_punct.is_match(&text) {
text.push('.');
}
}
if !is_valid_lang(lang) {
bail!(
"Invalid language: {}. Available: {:?}",
lang,
AVAILABLE_LANGS
);
}
text = format!("<{}>{}</{}>", lang, text, lang);
Ok(text)
}
pub fn text_to_unicode_values(text: &str) -> Vec<usize> {
text.chars().map(|c| c as usize).collect()
}
pub fn length_to_mask(lengths: &[usize], max_len: Option<usize>) -> Array3<f32> {
let bsz = lengths.len();
let max_len = max_len.unwrap_or_else(|| *lengths.iter().max().unwrap_or(&0));
let mut mask = Array3::<f32>::zeros((bsz, 1, max_len));
for (i, &len) in lengths.iter().enumerate() {
for j in 0..len.min(max_len) {
mask[[i, 0, j]] = 1.0;
}
}
mask
}
pub fn get_text_mask(text_ids_lengths: &[usize]) -> Array3<f32> {
let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
length_to_mask(text_ids_lengths, Some(max_len))
}
const MAX_CHUNK_LENGTH: usize = 300;
const ABBREVIATIONS: &[&str] = &[
"Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.", "St.", "Ave.", "Rd.", "Blvd.", "Dept.",
"Inc.", "Ltd.", "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
];
pub fn chunk_text(text: &str, max_len: Option<usize>) -> Vec<String> {
let max_len = max_len.unwrap_or(MAX_CHUNK_LENGTH);
let text = text.trim();
if text.is_empty() {
return vec![String::new()];
}
let para_re = Regex::new(r"\n\s*\n").unwrap();
let paragraphs: Vec<&str> = para_re.split(text).collect();
let mut chunks = Vec::new();
for para in paragraphs {
let para = para.trim();
if para.is_empty() {
continue;
}
if para.len() <= max_len {
chunks.push(para.to_string());
continue;
}
let sentences = split_sentences(para);
let mut current = String::new();
let mut current_len = 0;
for sentence in sentences {
let sentence = sentence.trim();
if sentence.is_empty() {
continue;
}
let sentence_len = sentence.len();
if sentence_len > max_len {
if !current.is_empty() {
chunks.push(current.trim().to_string());
current.clear();
current_len = 0;
}
let parts: Vec<&str> = sentence.split(',').collect();
for part in parts {
let part = part.trim();
if part.is_empty() {
continue;
}
let part_len = part.len();
if part_len > max_len {
let words: Vec<&str> = part.split_whitespace().collect();
let mut word_chunk = String::new();
let mut word_chunk_len = 0;
for word in words {
let word_len = word.len();
if word_chunk_len + word_len + 1 > max_len && !word_chunk.is_empty() {
chunks.push(word_chunk.trim().to_string());
word_chunk.clear();
word_chunk_len = 0;
}
if !word_chunk.is_empty() {
word_chunk.push(' ');
word_chunk_len += 1;
}
word_chunk.push_str(word);
word_chunk_len += word_len;
}
if !word_chunk.is_empty() {
chunks.push(word_chunk.trim().to_string());
}
} else {
if current_len + part_len + 1 > max_len && !current.is_empty() {
chunks.push(current.trim().to_string());
current.clear();
current_len = 0;
}
if !current.is_empty() {
current.push_str(", ");
current_len += 2;
}
current.push_str(part);
current_len += part_len;
}
}
continue;
}
if current_len + sentence_len + 1 > max_len && !current.is_empty() {
chunks.push(current.trim().to_string());
current.clear();
current_len = 0;
}
if !current.is_empty() {
current.push(' ');
current_len += 1;
}
current.push_str(sentence);
current_len += sentence_len;
}
if !current.is_empty() {
chunks.push(current.trim().to_string());
}
}
if chunks.is_empty() {
vec![String::new()]
} else {
chunks
}
}
fn split_sentences(text: &str) -> Vec<String> {
let re = Regex::new(r"([.!?])\s+").unwrap();
let matches: Vec<_> = re.find_iter(text).collect();
if matches.is_empty() {
return vec![text.to_string()];
}
let mut sentences = Vec::new();
let mut last_end = 0;
for m in matches {
let before_punc = &text[last_end..m.start()];
let mut is_abbrev = false;
for abbrev in ABBREVIATIONS {
let combined = format!("{}{}", before_punc.trim(), &text[m.start()..m.start() + 1]);
if combined.ends_with(abbrev) {
is_abbrev = true;
break;
}
}
if !is_abbrev {
sentences.push(text[last_end..m.end()].to_string());
last_end = m.end();
}
}
if last_end < text.len() {
sentences.push(text[last_end..].to_string());
}
if sentences.is_empty() {
vec![text.to_string()]
} else {
sentences
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct VoiceStyleData {
pub style_ttl: StyleComponent,
pub style_dp: StyleComponent,
}
#[derive(Debug, Clone, Deserialize)]
pub struct StyleComponent {
pub data: Vec<Vec<Vec<f32>>>,
pub dims: Vec<usize>,
#[serde(rename = "type")]
pub _dtype: String,
}

View file

@ -0,0 +1,25 @@
use anyhow::Result;
use hound::{SampleFormat, WavSpec, WavWriter};
use std::io::Cursor;
pub fn wav_bytes(audio: &[f32], sample_rate: i32) -> Result<Vec<u8>> {
let spec = WavSpec {
channels: 1,
sample_rate: sample_rate as u32,
bits_per_sample: 16,
sample_format: SampleFormat::Int,
};
let mut buf = Cursor::new(Vec::<u8>::new());
{
let mut writer = WavWriter::new(&mut buf, spec)?;
for &sample in audio {
let clamped = sample.clamp(-1.0, 1.0);
let val = (clamped * 32767.0) as i16;
writer.write_sample(val)?;
}
writer.finalize()?;
}
Ok(buf.into_inner())
}