# Source: HiDream-I1-nf4/inference.py
# Snapshot: 2025-04-07 15:31:48 +08:00 (139 lines, 4.3 KiB, Python)
#
# Note: this file uses the Unicode multiplication sign "×" (U+00D7) in its
# resolution labels. It is easily confused with the ASCII letter "x"; the
# labels and the code that matches them must use the same character.
import torch
import argparse
from hi_diffusers import HiDreamImagePipeline
from hi_diffusers import HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default="dev")
args = parser.parse_args()
model_type = args.model_type
MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Model configurations
MODEL_CONFIGS = {
"dev": {
"path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
"guidance_scale": 0.0,
"num_inference_steps": 28,
"shift": 6.0,
"scheduler": FlashFlowMatchEulerDiscreteScheduler
},
"full": {
"path": f"{MODEL_PREFIX}/HiDream-I1-Full",
"guidance_scale": 5.0,
"num_inference_steps": 50,
"shift": 3.0,
"scheduler": FlowUniPCMultistepScheduler
},
"fast": {
"path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
"guidance_scale": 0.0,
"num_inference_steps": 16,
"shift": 3.0,
"scheduler": FlashFlowMatchEulerDiscreteScheduler
}
}
# Resolution options
RESOLUTION_OPTIONS = [
"1024 × 1024 (Square)",
"768 × 1360 (Portrait)",
"1360 × 768 (Landscape)",
"880 × 1168 (Portrait)",
"1168 × 880 (Landscape)",
"1248 × 832 (Landscape)",
"832 × 1248 (Portrait)"
]
# Load models
def load_models(model_type):
config = MODEL_CONFIGS[model_type]
pretrained_model_name_or_path = config["path"]
scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
LLAMA_MODEL_NAME,
use_fast=False)
text_encoder_4 = LlamaForCausalLM.from_pretrained(
LLAMA_MODEL_NAME,
output_hidden_states=True,
output_attentions=True,
torch_dtype=torch.bfloat16).to("cuda")
transformer = HiDreamImageTransformer2DModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=torch.bfloat16).to("cuda")
pipe = HiDreamImagePipeline.from_pretrained(
pretrained_model_name_or_path,
scheduler=scheduler,
tokenizer_4=tokenizer_4,
text_encoder_4=text_encoder_4,
torch_dtype=torch.bfloat16
).to("cuda", torch.bfloat16)
pipe.transformer = transformer
return pipe, config
# Parse resolution string to get height and width
def parse_resolution(resolution_str):
if "1024 × 1024" in resolution_str:
return 1024, 1024
elif "768 × 1360" in resolution_str:
return 768, 1360
elif "1360 × 768" in resolution_str:
return 1360, 768
elif "880 × 1168" in resolution_str:
return 880, 1168
elif "1168 × 880" in resolution_str:
return 1168, 880
elif "1248 × 832" in resolution_str:
return 1248, 832
elif "832 × 1248" in resolution_str:
return 832, 1248
else:
return 1024, 1024 # Default fallback
# Generate image function
def generate_image(pipe, model_type, prompt, resolution, seed):
# Get configuration for current model
config = MODEL_CONFIGS[model_type]
guidance_scale = config["guidance_scale"]
num_inference_steps = config["num_inference_steps"]
# Parse resolution
height, width = parse_resolution(resolution)
# Handle seed
if seed == -1:
seed = torch.randint(0, 1000000, (1,)).item()
generator = torch.Generator("cuda").manual_seed(seed)
images = pipe(
prompt,
height=height,
width=width,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
num_images_per_prompt=1,
generator=generator
).images
return images[0], seed
# Initialize with default model
print("Loading default model (full)...")
pipe, _ = load_models(model_type)
print("Model loaded successfully!")
prompt = "A cat holding a sign that says \"Hi-Dreams.ai\"."
resolution = "1024 × 1024 (Square)"
seed = -1
image, seed = generate_image(pipe, model_type, prompt, resolution, seed)
image.save("output.png")