update readme
.gitignore
@@ -1,2 +1,5 @@
 __pycache__
 tmp
+*_local.py
+*.jpg
+*.png

README.md
@@ -1,6 +1,6 @@
 # HiDream-I1

-`HiDream-I1` is a series of state-of-the-art open-source image generation models featuring a 16 billion parameter rectified flow transformer with Mixture of Experts architecture, designed to create high-quality images from text prompts.
+`HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds.

 ## Project Updates

 - ```2025/4/7```: We've open-sourced the text-to-image model **HiDream-I1**.

@@ -13,8 +13,8 @@ We offer both the full version and distilled models. For more information about
 | Name | Script | Inference Steps | HuggingFace repo |
 | --------------- | -------------------------------------------------- | --------------- | ---------------------- |
 | HiDream-I1-Full | [inference.py](./inference.py) | 50 | 🤗 [HiDream-I1-Full](https://huggingface.co/HiDream-ai/HiDream-I1-Full) |
-| HiDream-I1-Dev | [inference_distilled.py](./inference_distilled.py) | 28 | 🤗 [HiDream-I1-Dev](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) |
-| HiDream-I1-Fast | [inference_distilled.py](./inference_distilled.py) | 16 | 🤗 [HiDream-I1-Fast](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) |
+| HiDream-I1-Dev | [inference.py](./inference.py) | 28 | 🤗 [HiDream-I1-Dev](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) |
+| HiDream-I1-Fast | [inference.py](./inference.py) | 16 | 🤗 [HiDream-I1-Fast](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) |

 ## Quick Start

@@ -26,19 +26,26 @@ pip install -r requirements.txt
 Then you can run the inference scripts to generate images:

 ``` python
 # For full model inference
-python ./inference.py
+python ./inference.py --model_type full

 # For distilled dev model inference
-INFERENCE_STEP=28 PRETRAINED_MODEL_NAME_OR_PATH=HiDream-ai/HiDream-I1-Dev python inference_distilled.py
+python ./inference.py --model_type dev

 # For distilled fast model inference
-INFERENCE_STEP=16 PRETRAINED_MODEL_NAME_OR_PATH=HiDream-ai/HiDream-I1-Fast python inference_distilled.py
+python ./inference.py --model_type fast
 ```

 > **Note:** The inference script will automatically download `meta-llama/Meta-Llama-3.1-8B-Instruct` model files. If you encounter network issues, you can download these files ahead of time and place them in the appropriate cache directory to avoid download failures during inference.
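
For pre-downloading, a minimal sketch using `huggingface_hub` (an assumption on my part, not part of the committed README; Llama 3.1 is a gated model, so an authenticated `huggingface-cli login` is also required first):

``` python
# Hedged sketch: pre-fetch the Llama text encoder into the local Hugging Face
# cache so later inference runs resolve it without hitting the network.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
```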

+## Gradio Demo
+
+We also provide a Gradio demo for interactive image generation. You can run the demo with:
+
+``` python
+python gradio_demo.py
+```
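
The demo launches with Gradio's defaults (local access only). If you need to reach it from another machine, Gradio's standard launch options apply; a sketch of an optional tweak, not part of the committed script:

``` python
# Optional: listen on all interfaces and create a temporary public share link.
demo.launch(server_name="0.0.0.0", share=True)
```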

 ## Evaluation Metrics

gradio_demo.py (new file, +190)
@@ -0,0 +1,190 @@
import torch
import gradio as gr
from hi_diffusers import HiDreamImagePipeline
from hi_diffusers import HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}

# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]

# Load models
def load_models(model_type):
    config = MODEL_CONFIGS[model_type]
    pretrained_model_name_or_path = config["path"]
    # Instantiate the scheduler class configured for this model variant (the
    # original line hard-coded FlowUniPCMultistepScheduler, ignoring the
    # per-model "scheduler" entry above).
    scheduler = config["scheduler"](num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
        LLAMA_MODEL_NAME,
        use_fast=False)

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16).to("cuda")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16).to("cuda")

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model_name_or_path,
        scheduler=scheduler,
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16
    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer

    return pipe, config

# Parse resolution string to get height and width
def parse_resolution(resolution_str):
    if "1024 × 1024" in resolution_str:
        return 1024, 1024
    elif "768 × 1360" in resolution_str:
        return 768, 1360
    elif "1360 × 768" in resolution_str:
        return 1360, 768
    elif "880 × 1168" in resolution_str:
        return 880, 1168
    elif "1168 × 880" in resolution_str:
        return 1168, 880
    elif "1248 × 832" in resolution_str:
        return 1248, 832
    elif "832 × 1248" in resolution_str:
        return 832, 1248
    else:
        return 1024, 1024  # Default fallback

# Generate image function
def generate_image(model_type, prompt, resolution, seed):
    global pipe, current_model

    # Reload model if needed
    if model_type != current_model:
        del pipe
        torch.cuda.empty_cache()
        print(f"Loading {model_type} model...")
        pipe, config = load_models(model_type)
        current_model = model_type
        print(f"{model_type} model loaded successfully!")

    # Get configuration for current model
    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]

    # Parse resolution
    height, width = parse_resolution(resolution)

    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()

    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images

    return images[0], seed

# Initialize with default model
print("Loading default model (full)...")
current_model = "full"
pipe, _ = load_models(current_model)
print("Model loaded successfully!")

# Create Gradio interface
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("# HiDream Image Generator")

    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="full",
                label="Model Type",
                info="Select model variant"
            )

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                lines=3
            )

            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
                info="Select image resolution"
            )

            seed = gr.Number(
                label="Seed (use -1 for random)",
                value=-1,
                precision=0
            )

            generate_btn = gr.Button("Generate Image")
            seed_used = gr.Number(label="Seed Used", interactive=False)

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used]
    )

# Launch app
if __name__ == "__main__":
    demo.launch()

@@ -337,7 +337,7 @@ class HiDreamImageTransformer2DModel(
         return x

     def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.patch_size * self.patch_size
+        pz2 = self.config.patch_size * self.config.patch_size
         if isinstance(x, torch.Tensor):
             B, C = x.shape[0], x.shape[1]
             device = x.device
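
This one-line change reads `patch_size` from the model's config object instead of as a plain attribute. A plausible reason (my inference, not stated in the commit): diffusers-style models record their `__init__` arguments on `self.config` via `register_to_config`, and direct attribute access only works through a deprecation shim in some diffusers versions. A minimal sketch of the pattern, using a hypothetical `TinyModel` rather than the actual HiDream class:

``` python
# Sketch of the diffusers ConfigMixin pattern the change relies on.
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class TinyModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, patch_size: int = 2):
        super().__init__()
        # patch_size is recorded on self.config, not set as self.patch_size

model = TinyModel()
print(model.config.patch_size)  # -> 2
```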

inference.py (+109 −13)
@@ -1,15 +1,65 @@
 import torch
+import argparse
 from hi_diffusers import HiDreamImagePipeline
 from hi_diffusers import HiDreamImageTransformer2DModel
 from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
 from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
-pretrained_model_name_or_path = "HiDream-ai/HiDream-I1-Full"
-scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_type", type=str, default="dev")
+args = parser.parse_args()
+model_type = args.model_type
+
+MODEL_PREFIX = "HiDream-ai"
+LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+# Model configurations
+MODEL_CONFIGS = {
+    "dev": {
+        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
+        "guidance_scale": 0.0,
+        "num_inference_steps": 28,
+        "shift": 6.0,
+        "scheduler": FlashFlowMatchEulerDiscreteScheduler
+    },
+    "full": {
+        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
+        "guidance_scale": 5.0,
+        "num_inference_steps": 50,
+        "shift": 3.0,
+        "scheduler": FlowUniPCMultistepScheduler
+    },
+    "fast": {
+        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
+        "guidance_scale": 0.0,
+        "num_inference_steps": 16,
+        "shift": 3.0,
+        "scheduler": FlashFlowMatchEulerDiscreteScheduler
+    }
+}
+
+# Resolution options
+RESOLUTION_OPTIONS = [
+    "1024 × 1024 (Square)",
+    "768 × 1360 (Portrait)",
+    "1360 × 768 (Landscape)",
+    "880 × 1168 (Portrait)",
+    "1168 × 880 (Landscape)",
+    "1248 × 832 (Landscape)",
+    "832 × 1248 (Portrait)"
+]
+
+# Load models
+def load_models(model_type):
+    config = MODEL_CONFIGS[model_type]
+    pretrained_model_name_or_path = config["path"]
+    # Instantiate the scheduler class configured for this model variant (the
+    # original line hard-coded FlowUniPCMultistepScheduler, ignoring the
+    # per-model "scheduler" entry above).
+    scheduler = config["scheduler"](num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
-        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        LLAMA_MODEL_NAME,
        use_fast=False)

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
-        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16).to("cuda")

@@ -25,18 +75,64 @@ pipe = HiDreamImagePipeline.from_pretrained(
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16
-    ).to("cuda")
+    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer

-prompt = "A cat holding a sign that says \"Hi-Dreams.ai\"."
+    return pipe, config
+
+# Parse resolution string to get height and width
+def parse_resolution(resolution_str):
+    if "1024 × 1024" in resolution_str:
+        return 1024, 1024
+    elif "768 × 1360" in resolution_str:
+        return 768, 1360
+    elif "1360 × 768" in resolution_str:
+        return 1360, 768
+    elif "880 × 1168" in resolution_str:
+        return 880, 1168
+    elif "1168 × 880" in resolution_str:
+        return 1168, 880
+    elif "1248 × 832" in resolution_str:
+        return 1248, 832
+    elif "832 × 1248" in resolution_str:
+        return 832, 1248
+    else:
+        return 1024, 1024  # Default fallback
+
+# Generate image function
+def generate_image(pipe, model_type, prompt, resolution, seed):
+    # Get configuration for current model
+    config = MODEL_CONFIGS[model_type]
+    guidance_scale = config["guidance_scale"]
+    num_inference_steps = config["num_inference_steps"]
+
+    # Parse resolution
+    height, width = parse_resolution(resolution)
+
+    # Handle seed
+    if seed == -1:
+        seed = torch.randint(0, 1000000, (1,)).item()
+
+    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
-        height=1024,
-        width=1024,
-        guidance_scale=5.0,
-        num_inference_steps=50,
+        height=height,
+        width=width,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
-        generator=torch.Generator("cuda").manual_seed(42)
+        generator=generator
    ).images
-for i, image in enumerate(images):
-    image.save(f"{i}.jpg")
+
+    return images[0], seed
+
+# Initialize with default model
print("Loading default model (full)...")
|
||||
+pipe, _ = load_models(model_type)
+print("Model loaded successfully!")
+prompt = "A cat holding a sign that says \"Hi-Dreams.ai\"."
+resolution = "1024 × 1024 (Square)"
+seed = -1
+image, seed = generate_image(pipe, model_type, prompt, resolution, seed)
+image.save("output.png")

inference_distilled.py (deleted, −46)
@@ -1,46 +0,0 @@
-import os
-import torch
-from hi_diffusers import HiDreamImagePipeline
-from hi_diffusers import HiDreamImageTransformer2DModel
-from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
-from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
-
-INFERENCE_STEP = int(os.getenv("INFERENCE_STEP", "28"))
-PRETRAINED_MODEL_NAME_OR_PATH = os.getenv("PRETRAINED_MODEL_NAME_OR_PATH", "HiDream-I1-Dev")
-
-scheduler = FlashFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=6.0, use_dynamic_shifting=False)
-tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
-    "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    use_fast=False)
-text_encoder_4 = LlamaForCausalLM.from_pretrained(
-    "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    output_hidden_states=True,
-    output_attentions=True,
-    torch_dtype=torch.bfloat16).to("cuda")
-
-transformer = HiDreamImageTransformer2DModel.from_pretrained(
-    PRETRAINED_MODEL_NAME_OR_PATH,
-    subfolder="transformer",
-    torch_dtype=torch.bfloat16).to("cuda")
-
-pipe = HiDreamImagePipeline.from_pretrained(
-    PRETRAINED_MODEL_NAME_OR_PATH,
-    scheduler=scheduler,
-    tokenizer_4=tokenizer_4,
-    text_encoder_4=text_encoder_4,
-    torch_dtype=torch.bfloat16
-).to("cuda")
-pipe.transformer = transformer
-
-prompt = "A cat holding a sign that says \"Hi-Dreams.ai\"."
-images = pipe(
-    prompt,
-    height=1024,
-    width=1024,
-    guidance_scale=0.0,
-    num_inference_steps=INFERENCE_STEP,
-    num_images_per_prompt=1,
-    generator=torch.Generator("cuda").manual_seed(42)
-).images
-for i, image in enumerate(images):
-    image.save(f"{i}.jpg")