Spaces:
Paused
Paused
# --- Standard library -------------------------------------------------------
import os
import re

# --- Third-party --------------------------------------------------------------
# NOTE(review): `import spaces` is deliberately kept ahead of torch/diffusers,
# matching the original ordering; on ZeroGPU hardware the `spaces` package
# presumably has to be imported before CUDA is initialised — confirm before
# reordering.
import gradio as gr
from gradio_client import Client
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch
# Not referenced by the code visible in this file.
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# Client for the remote safety-checker Space used to screen prompts.
# HF_TOKEN is optional: os.environ.get yields None when it is unset.
hf_token = os.environ.get("HF_TOKEN")
client = Client("https://fffiloni-safety-checker-bot.hf.space/", hf_token=hf_token)
# Load the base SDXL pipeline with the DPO-finetuned UNet swapped in.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
unet_id = "mhdang/dpo-sdxl-text2image-v1"

# Load the finetuned UNet first and hand it to the pipeline constructor, so the
# base checkpoint's own UNet is never loaded only to be discarded afterwards.
unet = UNet2DConditionModel.from_pretrained(
    unet_id,
    subfolder="unet",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# Model CPU offload manages device placement itself (sub-models are moved to
# the GPU when they run), and it moves an already-GPU-resident pipeline back to
# the CPU. So the pipeline is intentionally not moved with .to("cuda") first;
# that transfer would only cost load time and peak GPU memory.
pipe.enable_model_cpu_offload()
# Decode the VAE in slices to reduce peak memory during decoding.
pipe.enable_vae_slicing()
def safety_check(user_prompt):
    """Ask the remote safety-checker Space about *user_prompt*.

    Forwards the prompt to the ``/infer`` endpoint of the module-level
    ``client`` and returns the raw prediction unchanged.
    """
    return client.predict(user_prompt, api_name="/infer")
def infer(prompt):
    """Generate one image for *prompt*, refusing prompts the safety checker flags.

    The prompt is first screened with ``safety_check``. If the reply contains
    the standalone word "Yes", the request is refused. Otherwise the SDXL-DPO
    pipeline runs with a guidance scale of 7.5.

    Args:
        prompt: Text prompt entered in the UI.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If the safety checker's reply contains "Yes".
    """
    # NOTE(review): `spaces` is imported at the top of the file but is not used
    # in the visible code; if this Space runs on ZeroGPU hardware, this function
    # presumably needs the @spaces.GPU decorator — confirm.
    print(f"""
—
{prompt}
""")
    verdict = safety_check(prompt)
    print(verdict)
    # The checker replies with free text; a standalone, case-sensitive "Yes"
    # is treated as "refuse this prompt".
    if re.search(r"\bYes\b", verdict):
        raise gr.Error("Don't ask for such things.")
    results = pipe(prompt, guidance_scale=7.5)
    return results.images[0]
# Keep the whole UI in a single narrow, horizontally centred column.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""

# NOTE(review): the original indentation was lost; the component nesting below
# is a reconstruction — confirm the intended layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
<h2 style="text-align: center;">
SDXL Using Direct Preference Optimization
</h2>
<p style="text-align: center;">
Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to human preferences by directly optimizing on human comparison data.
</p>
""")
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="An old man with a bird on his head")
                submit_btn = gr.Button("Submit")
        result = gr.Image(label="DPO SDXL Result")

        # Clickable example prompts; they run through the same `infer` path,
        # safety check included.
        gr.Examples(
            examples=[
                "Dragon, digital art, by Greg Rutkowski",
                "Armored knight holding sword",
                "A flat roof villa near a river with black walls and huge windows",
                "A calm and peaceful office",
                "Pirate guinea pig",
            ],
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )

    submit_btn.click(
        fn=infer,
        inputs=[prompt_in],
        outputs=[result],
    )

# Queue requests (generation is slow) and hide the auto-generated API docs page.
demo.queue().launch(show_api=False)