cpuai committed on
Commit
4cf56ba
·
verified ·
1 Parent(s): dc47b30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -314
app.py CHANGED
@@ -1,39 +1,42 @@
1
- import spaces
2
- from dataclasses import dataclass
3
- import json
4
- import logging
5
  import os
6
- import random
7
  import re
8
  import sys
 
 
 
9
  import warnings
 
10
 
11
- from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
12
  import gradio as gr
13
  import torch
14
- from transformers import AutoModel, AutoTokenizer
15
 
16
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
17
 
18
- from diffusers import ZImagePipeline
19
  from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
 
20
 
21
- from pe import prompt_template
 
 
 
 
 
 
22
 
23
- # ==================== Environment Variables ==================================
24
  MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
25
  ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
26
  ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
27
  ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
28
  DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
29
  HF_TOKEN = os.environ.get("HF_TOKEN")
30
- # =============================================================================
31
-
32
 
33
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
34
  warnings.filterwarnings("ignore")
35
  logging.getLogger("transformers").setLevel(logging.ERROR)
36
 
 
37
  RES_CHOICES = {
38
  "1024": [
39
  "1024x1024 ( 1:1 )",
@@ -57,179 +60,63 @@ RES_CHOICES = {
57
  "1536x1024 ( 3:2 )",
58
  "1024x1536 ( 2:3 )",
59
  "1600x896 ( 16:9 )",
60
- "896x1600 ( 9:16 )", # not 900 coz divided by 16 needed
61
  "1680x720 ( 21:9 )",
62
  "720x1680 ( 9:21 )",
63
  ],
64
  }
65
 
66
- RESOLUTION_SET = []
67
- for resolutions in RES_CHOICES.values():
68
- RESOLUTION_SET.extend(resolutions)
69
-
70
  EXAMPLE_PROMPTS = [
71
  ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
72
- [
73
- "极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"
74
- ],
75
- [
76
- "一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。她穿着一件带有白色花朵图案的黑色露肩短上衣和深色牛仔裤。她的头微微倾斜,嘴唇嘟起做亲吻状,非常可爱俏皮。她右手拿着一部深灰色智能手机,遮住了部分脸,后置摄像头镜头对着镜子"
77
- ],
78
- [
79
- "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
80
- ],
81
- [
82
- '''A vertical digital illustration depicting a serene and majestic Chinese landscape, rendered in a style reminiscent of traditional Shanshui painting but with a modern, clean aesthetic. The scene is dominated by towering, steep cliffs in various shades of blue and teal, which frame a central valley. In the distance, layers of mountains fade into a light blue and white mist, creating a strong sense of atmospheric perspective and depth. A calm, turquoise river flows through the center of the composition, with a small, traditional Chinese boat, possibly a sampan, navigating its waters. The boat has a bright yellow canopy and a red hull, and it leaves a gentle wake behind it. It carries several indistinct figures of people. Sparse vegetation, including green trees and some bare-branched trees, clings to the rocky ledges and peaks. The overall lighting is soft and diffused, casting a tranquil glow over the entire scene. Centered in the image is overlaid text. At the top of the text block is a small, red, circular seal-like logo containing stylized characters. Below it, in a smaller, black, sans-serif font, are the words 'Zao-Xiang * East Beauty & West Fashion * Z-Image'. Directly beneath this, in a larger, elegant black serif font, is the word 'SHOW & SHARE CREATIVITY WITH THE WORLD'. Among them, there are "SHOW & SHARE", "CREATIVITY", and "WITH THE WORLD"'''
83
- ],
84
- [
85
- """一张虚构的英语电影《回忆之味》(The Taste of Memory)的电影海报。场景设置在一个质朴的19世纪风格厨房里。画面中央,一位红棕色头发、留着小胡子的中年男子(演员阿瑟·彭哈利根饰)站在一张木桌后,他身穿白色衬衫、黑色马甲和米色围裙,正看着一位女士,手中拿着一大块生红肉,下方是一个木制切菜板。在他的右边,一位梳着高髻的黑发女子(演员埃莉诺·万斯饰)倚靠在桌子上,温柔地对他微笑。她穿着浅色衬衫和一条上白下蓝的长裙。桌上除了放有切碎的葱和卷心菜丝的切菜板外,还有一个白色陶瓷盘、新鲜香草,左侧一个木箱上放着一串深色葡萄。背景是一面粗糙的灰白色抹灰墙,墙上挂着一幅风景画。最右边的一个台面上放着一盏复古油灯。海报上有大量的文字信息。左上角是白色的无衬线字体"ARTISAN FILMS PRESENTS",其下方是"ELEANOR VANCE"和"ACADEMY AWARD® WINNER"。右上角写着"ARTHUR PENHALIGON"和"GOLDEN GLOBE® AWARD WINNER"。顶部中央是圣丹斯电影节的桂冠标志,下方写着"SUNDANCE FILM FESTIVAL GRAND JURY PRIZE 2024"。主标题"THE TASTE OF MEMORY"以白色的大号衬线字体醒目地显示在下半部分。标题下方注明了"A FILM BY Tongyi Interaction Lab"。底部区域用白色小字列出了完整的演职员名单,包括"SCREENPLAY BY ANNA REID"、"CULINARY DIRECTION BY JAMES CARTER"以及Artisan Films、Riverstone Pictures和Heritage Media等众多出品公司标志。整体风格是写实主义,采用温暖柔和的灯光方案,营造出一种亲密的氛围。色调以棕色、米色和柔和的绿色等大地色系为主。两位演员的身体都在腰部被截断。"""
86
- ],
87
- [
88
- """一张方形构图的特写照片,主体是一片巨大的、鲜绿色的植物叶片,并叠加了文字,使其具有海报或杂志封面的外观。主要拍摄对象是一片厚实、有蜡质感的叶子,从左下角到右上角呈对角线弯曲穿过画面。其表面反光性很强,捕捉到一个明亮的直射光源,形成了一道突出的高光,亮面下显露出平行的精细叶脉。背景由其他深绿色的叶子组成,这些叶子轻微失焦,营造出浅景深效果,突出了前景的主叶片。整体风格是写实摄影,明亮的叶片与黑暗的阴影背景之间形成高对比度。图像上有多处渲染文字。左上角是白色的衬线字体文字"PIXEL-PEEPERS GUILD Presents"。右上角同样是白色衬线字体的文字"[Instant Noodle] 泡面调料包"。左侧垂直排列着标题"Render Distance: Max",为白色衬线字体。左下角是五个硕大的白色宋体汉字"显卡在...燃烧"。右下角是较小的白色衬线字体文字"Leica Glow™ Unobtanium X-1",其正上方是用白色宋体字书写的名字"蔡几"。识别出的核心实体包括品牌像素偷窥者协会、其产品线泡面调料包、相机型号买不到™ X-1以及摄影师名字造相。"""
89
- ],
90
  ]
91
 
92
-
93
- def get_resolution(resolution):
94
  match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
95
  if match:
96
  return int(match.group(1)), int(match.group(2))
97
  return 1024, 1024
98
 
99
 
100
- def load_models(model_path, enable_compile=False, attention_backend="native"):
101
- print(f"Loading models from {model_path}...")
102
-
103
- use_auth_token = HF_TOKEN if HF_TOKEN else True
104
-
105
- if not os.path.exists(model_path):
106
- vae = AutoencoderKL.from_pretrained(
107
- f"{model_path}",
108
- subfolder="vae",
109
- torch_dtype=torch.bfloat16,
110
- device_map="cuda",
111
- use_auth_token=use_auth_token,
112
- )
113
-
114
- text_encoder = AutoModel.from_pretrained(
115
- f"{model_path}",
116
- subfolder="text_encoder",
117
- torch_dtype=torch.bfloat16,
118
- device_map="cuda",
119
- use_auth_token=use_auth_token,
120
- ).eval()
121
-
122
- tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token)
123
- else:
124
- vae = AutoencoderKL.from_pretrained(
125
- os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
126
- )
127
-
128
- text_encoder = AutoModel.from_pretrained(
129
- os.path.join(model_path, "text_encoder"),
130
- torch_dtype=torch.bfloat16,
131
- device_map="cuda",
132
- ).eval()
133
-
134
- tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
135
-
136
- tokenizer.padding_side = "left"
137
-
138
- if enable_compile:
139
- print("Enabling torch.compile optimizations...")
140
- torch._inductor.config.conv_1x1_as_mm = True
141
- torch._inductor.config.coordinate_descent_tuning = True
142
- torch._inductor.config.epilogue_fusion = False
143
- torch._inductor.config.coordinate_descent_check_all_directions = True
144
- torch._inductor.config.max_autotune_gemm = True
145
- torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
146
- torch._inductor.config.triton.cudagraphs = False
147
-
148
- pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
149
-
150
- if enable_compile:
151
- pipe.vae.disable_tiling()
152
-
153
- if not os.path.exists(model_path):
154
- transformer = ZImageTransformer2DModel.from_pretrained(
155
- f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token
156
- ).to("cuda", torch.bfloat16)
157
- else:
158
- transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to(
159
- "cuda", torch.bfloat16
160
- )
161
-
162
- pipe.transformer = transformer
163
- pipe.transformer.set_attention_backend(attention_backend)
164
-
165
- if enable_compile:
166
- print("Compiling transformer...")
167
- pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
168
-
169
- pipe.to("cuda", torch.bfloat16)
170
-
171
- return pipe
172
-
173
-
174
- def generate_image(
175
- pipe,
176
- prompt,
177
- resolution="1024x1024",
178
- seed=-1,
179
- guidance_scale=5.0,
180
- num_inference_steps=50,
181
- shift=3.0,
182
- max_sequence_length=512,
183
- progress=gr.Progress(track_tqdm=True),
184
- ):
185
- width, height = get_resolution(resolution)
186
-
187
- if seed == -1:
188
- seed = torch.randint(0, 1000000, (1,)).item()
189
- print(f"Using seed: {seed}")
190
-
191
- generator = torch.Generator("cuda").manual_seed(seed)
192
-
193
- scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
194
- pipe.scheduler = scheduler
195
-
196
- image = pipe(
197
- prompt=prompt,
198
- height=height,
199
- width=width,
200
- guidance_scale=guidance_scale,
201
- num_inference_steps=num_inference_steps,
202
- generator=generator,
203
- max_sequence_length=max_sequence_length,
204
- ).images[0]
205
-
206
- return image
207
 
208
 
209
- def warmup_model(pipe, resolutions):
210
- print("Starting warmup phase...")
 
 
 
 
 
211
 
212
- dummy_prompt = "warmup"
213
 
214
- for res_str in resolutions:
215
- print(f"Warming up for resolution: {res_str}")
216
- try:
217
- for i in range(3):
218
- generate_image(
219
- pipe,
220
- prompt=dummy_prompt,
221
- resolution=res_str,
222
- num_inference_steps=9,
223
- guidance_scale=0.0,
224
- seed=42 + i,
225
- )
226
- except Exception as e:
227
- print(f"Warmup failed for {res_str}: {e}")
228
 
229
- print("Warmup completed.")
230
 
231
 
232
- # ==================== Prompt Expander ====================
233
  @dataclass
234
  class PromptOutput:
235
  status: bool
@@ -261,15 +148,12 @@ class APIPromptExpander(PromptExpander):
261
  base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
262
 
263
  if not api_key:
264
- print("Warning: DASHSCOPE_API_KEY not found.")
265
  return None
266
 
267
  return OpenAI(api_key=api_key, base_url=base_url)
268
- except ImportError:
269
- print("Please install openai: pip install openai")
270
- return None
271
  except Exception as e:
272
- print(f"Failed to initialize API client: {e}")
273
  return None
274
 
275
  def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
@@ -294,205 +178,245 @@ class APIPromptExpander(PromptExpander):
294
  temperature=0.7,
295
  top_p=0.8,
296
  )
297
-
298
- content = response.choices[0].message.content
299
  json_start = content.find("```json")
300
  if json_start != -1:
301
  json_end = content.find("```", json_start + 7)
302
  try:
303
- json_str = content[json_start + 7 : json_end].strip()
304
  data = json.loads(json_str)
305
  expanded_prompt = data.get("revised_prompt", content)
306
- except:
307
  expanded_prompt = content
308
  else:
309
  expanded_prompt = content
310
 
311
- return PromptOutput(
312
- status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
313
- )
314
  except Exception as e:
315
  return PromptOutput(False, "", seed, system_prompt, str(e))
316
 
317
 
318
- def create_prompt_expander(backend="api", **kwargs):
319
- if backend == "api":
320
- return APIPromptExpander(**kwargs)
321
- raise ValueError("Only 'api' backend is supported.")
322
-
323
-
324
- pipe = None
325
- prompt_expander = None
326
-
327
-
328
- def init_app():
329
- global pipe, prompt_expander
330
-
331
- try:
332
- pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
333
- print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
334
-
335
- if ENABLE_WARMUP:
336
- all_resolutions = []
337
- for cat in RES_CHOICES.values():
338
- all_resolutions.extend(cat)
339
- warmup_model(pipe, all_resolutions)
340
-
341
- except Exception as e:
342
- print(f"Error loading model: {e}")
343
- pipe = None
344
-
345
  try:
346
- prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
347
  print("Prompt expander initialized.")
348
  except Exception as e:
349
  print(f"Error initializing prompt expander: {e}")
350
  prompt_expander = None
351
 
352
 
353
- def prompt_enhance(prompt, enable_enhance):
354
  if not enable_enhance or not prompt_expander:
355
  return prompt, "Enhancement disabled or not available."
356
-
357
  if not prompt.strip():
358
  return "", "Please enter a prompt."
359
-
360
  try:
361
  result = prompt_expander(prompt)
362
  if result.status:
363
  return result.prompt, result.message
364
- else:
365
- return prompt, f"Enhancement failed: {result.message}"
366
  except Exception as e:
367
  return prompt, f"Error: {str(e)}"
368
 
369
 
370
- @spaces.GPU
371
- def generate(
372
- prompt, resolution, seed, steps, shift, enhance, random_seed, gallery_images, progress=gr.Progress(track_tqdm=True)
373
- ):
374
- if pipe is None:
375
- raise gr.Error("Model not loaded.")
376
-
377
- final_prompt = prompt
378
-
379
- if enhance:
380
- final_prompt, _ = prompt_enhance(prompt, True)
381
- print(f"Enhanced prompt: {final_prompt}")
382
-
383
- if random_seed:
384
- new_seed = random.randint(1, 1000000)
385
- else:
386
- new_seed = seed if seed != -1 else random.randint(1, 1000000)
387
-
388
- try:
389
- resolution_str = resolution.split(" ")[0]
390
- except:
391
- resolution_str = "1024x1024"
392
-
393
- image = generate_image(
394
- pipe=pipe,
395
- prompt=final_prompt,
396
- resolution=resolution_str,
397
- seed=new_seed,
398
- guidance_scale=0.0,
399
- num_inference_steps=int(steps + 1),
400
- shift=shift,
401
- )
402
-
403
- if gallery_images is None:
404
- gallery_images = []
405
- gallery_images.append(image)
406
-
407
- return gallery_images, str(new_seed)
408
 
 
409
 
410
- init_app()
 
411
 
412
- with gr.Blocks(title="Z-Image Demo") as demo:
413
- gr.Markdown(
414
- """<div align="center">
415
-
416
- # Z-Image Generation Demo
417
 
418
- [![GitHub](https://img.shields.io/badge/GitHub-Z--Image-181717?logo=github&logoColor=white)](https://github.com/Tongyi-MAI/Z-Image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
- *An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*
421
 
422
- </div>"""
 
 
 
 
 
 
423
  )
424
 
425
- with gr.Row():
426
- with gr.Column(scale=1):
427
- prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
428
- # PE components (Temporarily disabled)
429
- # with gr.Row():
430
- # enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
431
- # enhance_btn = gr.Button("Enhance Only")
 
 
 
432
 
433
- with gr.Row():
434
- choices = [int(k) for k in RES_CHOICES.keys()]
435
- res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
436
 
437
- initial_res_choices = RES_CHOICES["1024"]
438
- resolution = gr.Dropdown(value=initial_res_choices[0], choices=initial_res_choices, label="Width x Height (Ratio)")
 
 
 
439
 
440
- with gr.Row():
441
- seed = gr.Number(label="Seed", value=-1, precision=0)
442
- random_seed = gr.Checkbox(label="Random Seed", value=True)
443
 
444
- with gr.Row():
445
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
446
- shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
447
 
448
- generate_btn = gr.Button("Generate", variant="primary")
 
449
 
450
- # Example prompts
451
- gr.Markdown("### 📝 Example Prompts")
452
- gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
453
 
454
- with gr.Column(scale=1):
455
- output_gallery = gr.Gallery(
456
- label="Generated Images", columns=2, rows=2, height=600, object_fit="contain", format="png", interactive=False
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  )
458
- used_seed = gr.Textbox(label="Seed Used", interactive=False)
 
459
 
460
- def update_res_choices(_res_cat):
461
- if str(_res_cat) in RES_CHOICES:
462
- res_choices = RES_CHOICES[str(_res_cat)]
463
- else:
464
- res_choices = RES_CHOICES["1024"]
465
- return gr.update(value=res_choices[0], choices=res_choices)
466
 
467
- res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution)
468
 
469
- # PE enhancement button (Temporarily disabled)
470
- # enhance_btn.click(
471
- # prompt_enhance,
472
- # inputs=[prompt_input, enable_enhance],
473
- # outputs=[prompt_input, final_prompt_output]
474
- # )
 
 
 
 
 
475
 
476
- # Dummy enable_enhance variable set to False
477
- enable_enhance = gr.State(value=False)
478
 
479
- def update_seed(current_seed, random_seed_enabled):
480
- if random_seed_enabled:
481
- new_seed = random.randint(1, 1000000)
482
- else:
483
- new_seed = current_seed if current_seed != -1 else random.randint(1, 1000000)
484
- return gr.update(value=new_seed)
485
 
486
- generate_btn.click(update_seed, inputs=[seed, random_seed], outputs=[seed])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
488
- generate_btn.click(
489
- generate,
490
- inputs=[prompt_input, resolution, seed, steps, shift, enable_enhance, random_seed, output_gallery],
491
- outputs=[output_gallery, used_seed],
492
- )
493
 
494
- css='''
495
- .fillable{max-width: 1230px !important}
496
- '''
497
- if __name__ == "__main__":
498
- demo.launch(css=css)
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  import re
3
  import sys
4
+ import json
5
+ import random
6
+ import logging
7
  import warnings
8
+ from dataclasses import dataclass
9
 
 
10
  import gradio as gr
11
  import torch
 
12
 
13
+ import spaces # ZeroGPU: 动态分配 GPU 的关键
14
 
15
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline
16
  from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
17
+ from transformers import AutoModel, AutoTokenizer
18
 
19
+ # -------------------- 可选:Prompt Expander 依赖 --------------------
20
+ # 你的原项目里:from pe import prompt_template
21
+ # 为了“开箱即用”,这里做成可选导入:没有 pe.py 也不致命
22
+ try:
23
+ from pe import prompt_template # noqa
24
+ except Exception:
25
+ prompt_template = "You are a helpful prompt expander.\nUser prompt: {prompt}"
26
 
27
+ # -------------------- 环境变量 --------------------
28
  MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
29
  ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
30
  ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
31
  ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
32
  DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
33
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
34
 
35
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
36
  warnings.filterwarnings("ignore")
37
  logging.getLogger("transformers").setLevel(logging.ERROR)
38
 
39
+ # -------------------- 分辨率 --------------------
40
  RES_CHOICES = {
41
  "1024": [
42
  "1024x1024 ( 1:1 )",
 
60
  "1536x1024 ( 3:2 )",
61
  "1024x1536 ( 2:3 )",
62
  "1600x896 ( 16:9 )",
63
+ "896x1600 ( 9:16 )", # 注意:需要能被 16 整除
64
  "1680x720 ( 21:9 )",
65
  "720x1680 ( 9:21 )",
66
  ],
67
  }
68
 
 
 
 
 
69
  EXAMPLE_PROMPTS = [
70
  ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
71
+ ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
72
+ ["一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。她穿着一件带有白色花朵图案的黑色露肩短上衣和深色牛仔裤。她的头微微倾斜,嘴唇嘟起做亲吻状,非常可爱俏皮。她右手拿着一部深灰色智能手机,遮住了部分脸,后置摄像头镜头对着镜子"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  ]
74
 
75
+ # -------------------- 工具函数 --------------------
76
+ def get_resolution(resolution: str):
77
  match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
78
  if match:
79
  return int(match.group(1)), int(match.group(2))
80
  return 1024, 1024
81
 
82
 
83
+ def _from_pretrained_safe(cls, *args, token=None, use_auth_token=None, **kwargs):
84
+ """
85
+ 兼容 transformers/diffusers 不同版本的鉴权参数:
86
+ - 新版逐步用 token
87
+ - 旧版仍可能用 use_auth_token
88
+ """
89
+ if token is not None:
90
+ try:
91
+ return cls.from_pretrained(*args, token=token, **kwargs)
92
+ except TypeError:
93
+ pass
94
+ if use_auth_token is not None:
95
+ try:
96
+ return cls.from_pretrained(*args, use_auth_token=use_auth_token, **kwargs)
97
+ except TypeError:
98
+ pass
99
+ return cls.from_pretrained(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
 
102
+ def pick_dtype(device: torch.device) -> torch.dtype:
103
+ """
104
+ ZeroGPU/H200 一般 bfloat16 最合适;CPU 则使用 float32 更稳。
105
+ """
106
+ if device.type == "cuda":
107
+ return torch.bfloat16
108
+ return torch.float32
109
 
 
110
 
111
+ # -------------------- 全局状态(关键:启动阶段不触碰 CUDA) --------------------
112
+ pipe = None
113
+ pipe_device = torch.device("cpu")
114
+ pipe_dtype = torch.float32
 
 
 
 
 
 
 
 
 
 
115
 
116
+ prompt_expander = None
117
 
118
 
119
+ # -------------------- Prompt Expander(可选) --------------------
120
  @dataclass
121
  class PromptOutput:
122
  status: bool
 
148
  base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
149
 
150
  if not api_key:
151
+ print("Warning: DASHSCOPE_API_KEY not found. Prompt enhance will be disabled.")
152
  return None
153
 
154
  return OpenAI(api_key=api_key, base_url=base_url)
 
 
 
155
  except Exception as e:
156
+ print(f"Prompt expander init failed: {e}")
157
  return None
158
 
159
  def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
 
178
  temperature=0.7,
179
  top_p=0.8,
180
  )
181
+ content = response.choices[0].message.content or ""
 
182
  json_start = content.find("```json")
183
  if json_start != -1:
184
  json_end = content.find("```", json_start + 7)
185
  try:
186
+ json_str = content[json_start + 7: json_end].strip()
187
  data = json.loads(json_str)
188
  expanded_prompt = data.get("revised_prompt", content)
189
+ except Exception:
190
  expanded_prompt = content
191
  else:
192
  expanded_prompt = content
193
 
194
+ return PromptOutput(True, expanded_prompt, seed, system_prompt, content)
 
 
195
  except Exception as e:
196
  return PromptOutput(False, "", seed, system_prompt, str(e))
197
 
198
 
199
+ def init_prompt_expander():
200
+ global prompt_expander
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  try:
202
+ prompt_expander = APIPromptExpander(api_config={"model": "qwen3-max-preview"})
203
  print("Prompt expander initialized.")
204
  except Exception as e:
205
  print(f"Error initializing prompt expander: {e}")
206
  prompt_expander = None
207
 
208
 
209
+ def prompt_enhance(prompt: str, enable_enhance: bool):
210
  if not enable_enhance or not prompt_expander:
211
  return prompt, "Enhancement disabled or not available."
 
212
  if not prompt.strip():
213
  return "", "Please enter a prompt."
 
214
  try:
215
  result = prompt_expander(prompt)
216
  if result.status:
217
  return result.prompt, result.message
218
+ return prompt, f"Enhancement failed: {result.message}"
 
219
  except Exception as e:
220
  return prompt, f"Error: {str(e)}"
221
 
222
 
223
+ # -------------------- 模型加载(关键:提供“仅CPU加载”与“迁移到CUDA”) --------------------
224
+ def load_models_cpu_only(model_path: str):
225
+ """
226
+ 只在 CPU 上加载权重,避免 ZeroGPU 启动阶段触发 CUDA 初始化失败。
227
+ 真正推理前在 @spaces.GPU 函数里再搬到 cuda。
228
+ """
229
+ global pipe, pipe_device, pipe_dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
+ print(f"[load] Loading models (CPU only) from: {model_path}")
232
 
233
+ token = HF_TOKEN
234
+ use_auth_token = True if HF_TOKEN is None else HF_TOKEN
235
 
236
+ dtype = torch.float32
237
+ device = torch.device("cpu")
 
 
 
238
 
239
+ if not os.path.exists(model_path):
240
+ vae = _from_pretrained_safe(
241
+ AutoencoderKL, model_path, subfolder="vae",
242
+ torch_dtype=dtype,
243
+ token=token, use_auth_token=use_auth_token,
244
+ )
245
+ text_encoder = _from_pretrained_safe(
246
+ AutoModel, model_path, subfolder="text_encoder",
247
+ torch_dtype=dtype,
248
+ token=token, use_auth_token=use_auth_token,
249
+ ).eval()
250
+ tokenizer = _from_pretrained_safe(
251
+ AutoTokenizer, model_path, subfolder="tokenizer",
252
+ token=token, use_auth_token=use_auth_token,
253
+ )
254
+ transformer = _from_pretrained_safe(
255
+ ZImageTransformer2DModel, model_path, subfolder="transformer",
256
+ token=token, use_auth_token=use_auth_token,
257
+ )
258
+ else:
259
+ vae = _from_pretrained_safe(
260
+ AutoencoderKL, os.path.join(model_path, "vae"),
261
+ torch_dtype=dtype,
262
+ token=token, use_auth_token=use_auth_token,
263
+ )
264
+ text_encoder = _from_pretrained_safe(
265
+ AutoModel, os.path.join(model_path, "text_encoder"),
266
+ torch_dtype=dtype,
267
+ token=token, use_auth_token=use_auth_token,
268
+ ).eval()
269
+ tokenizer = _from_pretrained_safe(
270
+ AutoTokenizer, os.path.join(model_path, "tokenizer"),
271
+ token=token, use_auth_token=use_auth_token,
272
+ )
273
+ transformer = _from_pretrained_safe(
274
+ ZImageTransformer2DModel, os.path.join(model_path, "transformer"),
275
+ token=token, use_auth_token=use_auth_token,
276
+ )
277
 
278
+ tokenizer.padding_side = "left"
279
 
280
+ # 先用 CPU 组装 pipeline
281
+ _pipe = ZImagePipeline(
282
+ scheduler=None,
283
+ vae=vae,
284
+ text_encoder=text_encoder,
285
+ tokenizer=tokenizer,
286
+ transformer=transformer,
287
  )
288
 
289
+ # 注意:这里只能在 CPU 上设置,不要 .to("cuda")
290
+ try:
291
+ _pipe.transformer.set_attention_backend(ATTENTION_BACKEND)
292
+ except Exception as e:
293
+ print(f"[warn] set_attention_backend failed: {e}")
294
+
295
+ pipe = _pipe
296
+ pipe_device = device
297
+ pipe_dtype = dtype
298
+ print("[load] CPU pipeline ready.")
299
 
 
 
 
300
 
301
+ def move_pipe_to_device(target_device: torch.device):
302
+ """
303
+ 把已加载的 pipeline 从 CPU 迁移到 CUDA(只能在 @spaces.GPU 函数内部调用)。
304
+ """
305
+ global pipe, pipe_device, pipe_dtype
306
 
307
+ if pipe is None:
308
+ load_models_cpu_only(MODEL_PATH)
 
309
 
310
+ if pipe_device == target_device:
311
+ return
 
312
 
313
+ dtype = pick_dtype(target_device)
314
+ print(f"[move] Moving pipeline to {target_device} with dtype={dtype} ...")
315
 
316
+ # diffusers pipeline 支持 .to(device, dtype)
317
+ pipe.to(target_device, dtype=dtype)
 
318
 
319
+ # 编译仅在 CUDA 时启用;ZeroGPU 每次 attach GPU 后才有意义
320
+ if ENABLE_COMPILE and target_device.type == "cuda":
321
+ try:
322
+ print("[compile] Enabling torch.compile for transformer...")
323
+ torch._inductor.config.conv_1x1_as_mm = True
324
+ torch._inductor.config.coordinate_descent_tuning = True
325
+ torch._inductor.config.epilogue_fusion = False
326
+ torch._inductor.config.coordinate_descent_check_all_directions = True
327
+ torch._inductor.config.max_autotune_gemm = True
328
+ torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
329
+ torch._inductor.config.triton.cudagraphs = False
330
+
331
+ pipe.transformer = torch.compile(
332
+ pipe.transformer,
333
+ mode="max-autotune-no-cudagraphs",
334
+ fullgraph=False
335
  )
336
+ except Exception as e:
337
+ print(f"[warn] torch.compile failed, continue without compile: {e}")
338
 
339
+ pipe_device = target_device
340
+ pipe_dtype = dtype
341
+ print("[move] Done.")
 
 
 
342
 
 
343
 
344
+ def generate_image(
345
+ prompt: str,
346
+ resolution: str = "1024x1024",
347
+ seed: int = -1,
348
+ guidance_scale: float = 0.0,
349
+ num_inference_steps: int = 9,
350
+ shift: float = 3.0,
351
+ max_sequence_length: int = 512,
352
+ progress=gr.Progress(track_tqdm=True),
353
+ ):
354
+ global pipe, pipe_device, pipe_dtype
355
 
356
+ if pipe is None:
357
+ raise gr.Error("Pipeline not loaded.")
358
 
359
+ width, height = get_resolution(resolution)
 
 
 
 
 
360
 
361
+ if seed == -1:
362
+ seed = random.randint(1, 1_000_000)
363
+
364
+ device = pipe_device
365
+ generator = torch.Generator(device=device).manual_seed(int(seed))
366
+
367
+ # scheduler 每次生成都可重新设,避免共享状态问题
368
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
369
+
370
+ # 推理加速:inference_mode + autocast(CUDA 下 bfloat16/float16)
371
+ with torch.inference_mode():
372
+ if device.type == "cuda":
373
+ with torch.autocast(device_type="cuda", dtype=pipe_dtype):
374
+ image = pipe(
375
+ prompt=prompt,
376
+ height=int(height),
377
+ width=int(width),
378
+ guidance_scale=float(guidance_scale),
379
+ num_inference_steps=int(num_inference_steps),
380
+ generator=generator,
381
+ max_sequence_length=int(max_sequence_length),
382
+ ).images[0]
383
+ else:
384
+ image = pipe(
385
+ prompt=prompt,
386
+ height=int(height),
387
+ width=int(width),
388
+ guidance_scale=float(guidance_scale),
389
+ num_inference_steps=int(num_inference_steps),
390
+ generator=generator,
391
+ max_sequence_length=int(max_sequence_length),
392
+ ).images[0]
393
+
394
+ return image, str(seed)
395
+
396
+
397
+ def warmup_if_needed():
398
+ """
399
+ 可选 warmup:仅在 CUDA 上做一次小步数 warmup。
400
+ """
401
+ if not ENABLE_WARMUP:
402
+ return
403
+ if pipe_device.type != "cuda":
404
+ return
405
+ try:
406
+ print("[warmup] start ...")
407
+ _ = generate_image("warmup", "1024x1024", seed=42, num_inference_steps=5, guidance_scale=0.0)[0]
408
+ print("[warmup] done.")
409
+ except Exception as e:
410
+ print(f"[warmup] failed: {e}")
411
 
 
 
 
 
 
412
 
413
+ # -------------------- ZeroGPU 推理入口(关键:所有 CUDA 操作都在这里) --------------------
414
+ @spaces.GPU(duration=180) # 你可以按模型大小调大,比如 300
415
+ def generate(
416
+ prompt: str,
417
+ resolution: str,
418
+ seed: int,
419
+ steps: int,
420
+ shift: float,
421
+ enhance: bool,
422
+ random_s_