ollieollie commited on
Commit
7661963
·
verified ·
1 Parent(s): 9a10769

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -23
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import random
 
2
  import numpy as np
3
  import torch
4
  import gradio as gr
 
5
  from chatterbox.tts_turbo import ChatterboxTurboTTS
6
 
 
7
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
  EVENT_TAGS = [
@@ -11,13 +14,10 @@ EVENT_TAGS = [
11
  "[sniff]", "[gasp]", "[chuckle]", "[laugh]"
12
  ]
13
 
14
- # --- REFINED CSS ---
15
- # 1. tag-container: Forces the row to wrap items instead of scrolling. Removes borders/backgrounds.
16
- # 2. tag-btn: Sets the specific look (indigo theme) and stops them from stretching.
17
  CUSTOM_CSS = """
18
  .tag-container {
19
  display: flex !important;
20
- flex-wrap: wrap !important; /* This fixes the one-per-line issue */
21
  gap: 8px !important;
22
  margin-top: 5px !important;
23
  margin-bottom: 10px !important;
@@ -52,20 +52,19 @@ INSERT_TAG_JS = """
52
 
53
  const start = textarea.selectionStart;
54
  const end = textarea.selectionEnd;
55
-
56
  let prefix = " ";
57
  let suffix = " ";
58
-
59
  if (start === 0) prefix = "";
60
  else if (current_text[start - 1] === ' ') prefix = "";
61
-
62
  if (end < current_text.length && current_text[end] === ' ') suffix = "";
63
 
64
  return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
65
  }
66
  """
67
 
68
-
69
  def set_seed(seed: int):
70
  torch.manual_seed(seed)
71
  torch.cuda.manual_seed(seed)
@@ -74,12 +73,16 @@ def set_seed(seed: int):
74
  np.random.seed(seed)
75
 
76
 
 
77
  def load_model():
78
  print(f"Loading Chatterbox-Turbo on {DEVICE}...")
79
  model = ChatterboxTurboTTS.from_pretrained(DEVICE)
80
  return model
81
 
82
-
 
 
 
83
  def generate(
84
  model,
85
  text,
@@ -92,6 +95,7 @@ def generate(
92
  repetition_penalty,
93
  norm_loudness
94
  ):
 
95
  if model is None:
96
  model = ChatterboxTurboTTS.from_pretrained(DEVICE)
97
 
@@ -119,23 +123,19 @@ with gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo:
119
  with gr.Row():
120
  with gr.Column():
121
  text = gr.Textbox(
122
- value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and um all that jazz. Would you like me to get some prices for you?",
123
  label="Text to synthesize (max chars 300)",
124
  max_lines=5,
125
- elem_id="main_textbox"
126
  )
127
 
128
- # --- Event Tags ---
129
- # Switched back to Row, but applied specific CSS to force wrapping
130
  with gr.Row(elem_classes=["tag-container"]):
131
  for tag in EVENT_TAGS:
132
- # elem_classes targets the button specifically
133
  btn = gr.Button(tag, elem_classes=["tag-btn"])
134
-
135
  btn.click(
136
- fn=None,
137
- inputs=[btn, text],
138
- outputs=text,
139
  js=INSERT_TAG_JS
140
  )
141
 
@@ -179,8 +179,4 @@ with gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo:
179
  outputs=audio_output,
180
  )
181
 
182
- if __name__ == "__main__":
183
- demo.queue(
184
- max_size=50,
185
- default_concurrency_limit=1,
186
- ).launch(share=True)
 
1
  import random
2
+ import os
3
  import numpy as np
4
  import torch
5
  import gradio as gr
6
+ import spaces
7
  from chatterbox.tts_turbo import ChatterboxTurboTTS
8
 
9
+ # Check for GPU, but ZeroGPU handles the actual assignment dynamically
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  EVENT_TAGS = [
 
14
  "[sniff]", "[gasp]", "[chuckle]", "[laugh]"
15
  ]
16
 
 
 
 
17
  CUSTOM_CSS = """
18
  .tag-container {
19
  display: flex !important;
20
+ flex-wrap: wrap !important;
21
  gap: 8px !important;
22
  margin-top: 5px !important;
23
  margin-bottom: 10px !important;
 
52
 
53
  const start = textarea.selectionStart;
54
  const end = textarea.selectionEnd;
55
+
56
  let prefix = " ";
57
  let suffix = " ";
58
+
59
  if (start === 0) prefix = "";
60
  else if (current_text[start - 1] === ' ') prefix = "";
61
+
62
  if (end < current_text.length && current_text[end] === ' ') suffix = "";
63
 
64
  return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
65
  }
66
  """
67
 
 
68
  def set_seed(seed: int):
69
  torch.manual_seed(seed)
70
  torch.cuda.manual_seed(seed)
 
73
  np.random.seed(seed)
74
 
75
 
76
+ # We don't need to decorate load_model, it runs on CPU or during startup
77
  def load_model():
78
  print(f"Loading Chatterbox-Turbo on {DEVICE}...")
79
  model = ChatterboxTurboTTS.from_pretrained(DEVICE)
80
  return model
81
 
82
+ # --- 2. THE CRITICAL DECORATOR ---
83
+ # This tells ZeroGPU to assign a GPU to this specific function call.
84
+ # The duration param is optional but helps with scheduling (e.g. 60s limit).
85
+ @spaces.GPU
86
  def generate(
87
  model,
88
  text,
 
95
  repetition_penalty,
96
  norm_loudness
97
  ):
98
+ # Reload model inside the GPU context if it was lost (ZeroGPU quirk)
99
  if model is None:
100
  model = ChatterboxTurboTTS.from_pretrained(DEVICE)
101
 
 
123
  with gr.Row():
124
  with gr.Column():
125
  text = gr.Textbox(
126
+ value="Congratulations Miss Connor! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
127
  label="Text to synthesize (max chars 300)",
128
  max_lines=5,
129
+ elem_id="main_textbox"
130
  )
131
 
 
 
132
  with gr.Row(elem_classes=["tag-container"]):
133
  for tag in EVENT_TAGS:
 
134
  btn = gr.Button(tag, elem_classes=["tag-btn"])
 
135
  btn.click(
136
+ fn=None,
137
+ inputs=[btn, text],
138
+ outputs=text,
139
  js=INSERT_TAG_JS
140
  )
141
 
 
179
  outputs=audio_output,
180
  )
181
 
182
+ demo.launch(mcp_server=True)