Compare commits

31 Commits

| SHA1 |
| --- |
| 2b63a34470 |
| 593238c877 |
| 97c56cef94 |
| 1c740cc0b2 |
| 2b44fa7748 |
| 3c7c1b808c |
| f51b4c9a16 |
| 526a151f96 |
| 26f220b443 |
| 5e133fcd33 |
| bc2b7f9d3e |
| 0e585bfdbe |
| 03f617f943 |
| c19bc8683a |
| 5454087833 |
| eb7435d49a |
| 91c0bf98ad |
| c0ad4d4b39 |
| f203cdd78f |
| 3187ae752e |
| 1f178ff452 |
| 5e57ac1409 |
| 99c7dfa554 |
| 0efca9dd1a |
| 854a88f0ab |
| 10fc2afbbb |
| c6453bc773 |
| e14b033da4 |
| 13834643ec |
| 71c7fe2166 |
| 5fd395d0fe |
`.gitignore` (+2)

@@ -5,3 +5,5 @@ tmp
 *.png
 *.tar
 *.txt
+*.egg-info
+dist/
`.python-version` (new file, +1)

@@ -0,0 +1 @@
+3.10
`README.md`

@@ -1,100 +1,67 @@
-# HiDream-I1
+# HiDream-I1 4Bit Quantized Model
+
+This repository is a fork of `HiDream-I1` quantized to 4 bits, allowing the full model to run in less than 16 GB of VRAM.
+
+The original repository can be found [here](https://github.com/HiDream-ai/HiDream-I1).
+
+> `HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds.
 
 
 
-`HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds.
-
-<span style="color: #FF5733; font-weight: bold">For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/).</span>
-
-## Project Updates
-- 🤗 **April 8, 2025**: We've launched a Hugging Face Space for **HiDream-I1-Dev**. Experience our model firsthand at [https://huggingface.co/spaces/HiDream-ai/HiDream-I1-Dev](https://huggingface.co/spaces/HiDream-ai/HiDream-I1-Dev)!
-- 🚀 **April 7, 2025**: We've open-sourced the text-to-image model **HiDream-I1**.
-
-
 ## Models
 
-We offer both the full version and distilled models. For more information about the models, please refer to the link under Usage.
+We offer both the full version and distilled models. Their parameter counts are the same, so they require the same amount of GPU memory to run; the distilled models are faster because they use fewer inference steps.
 
-| Name            | Script                         | Inference Steps | HuggingFace repo |
-| --------------- | ------------------------------ | --------------- | ---------------- |
-| HiDream-I1-Full | [inference.py](./inference.py) | 50 | 🤗 [HiDream-I1-Full](https://huggingface.co/HiDream-ai/HiDream-I1-Full) |
-| HiDream-I1-Dev  | [inference.py](./inference.py) | 28 | 🤗 [HiDream-I1-Dev](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) |
-| HiDream-I1-Fast | [inference.py](./inference.py) | 16 | 🤗 [HiDream-I1-Fast](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) |
+| Name            | Min VRAM | Steps | HuggingFace |
+| --------------- | -------- | ----- | ----------- |
+| HiDream-I1-Full | 16 GB    | 50    | 🤗 [Original](https://huggingface.co/HiDream-ai/HiDream-I1-Full) / [NF4](https://huggingface.co/azaneko/HiDream-I1-Full-nf4) |
+| HiDream-I1-Dev  | 16 GB    | 28    | 🤗 [Original](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) / [NF4](https://huggingface.co/azaneko/HiDream-I1-Dev-nf4) |
+| HiDream-I1-Fast | 16 GB    | 16    | 🤗 [Original](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) / [NF4](https://huggingface.co/azaneko/HiDream-I1-Fast-nf4) |
 
 ## Hardware Requirements
 
 - GPU Architecture: NVIDIA `>= Ampere` (e.g. A100, H100, A40, RTX 3090, RTX 4090)
 - GPU RAM: `>= 16 GB`
 - CPU RAM: `>= 16 GB`
 
 ## Quick Start
 Please make sure you have installed [Flash Attention](https://github.com/Dao-AILab/flash-attention). We recommend CUDA 12.4 for manual installation.
 
 Simply run:
 
 ```
-pip install -r requirements.txt
+pip install hdi1 --no-build-isolation
 ```
 
-Then you can run the inference scripts to generate images:
+> [!NOTE]
+> It's recommended that you start a new Python environment for this package to avoid dependency conflicts.
+> To do that, you can use `conda create -n hdi1 python=3.12` and then `conda activate hdi1`.
+> Or you can use `python3 -m venv venv` and then `source venv/bin/activate` on Linux or `venv\Scripts\activate` on Windows.
+
+### Command Line Interface
+
+Then you can run the module to generate images:
 
 ``` python
-# For full model inference
-python ./inference.py --model_type full
-
-# For distilled dev model inference
-python ./inference.py --model_type dev
-
-# For distilled fast model inference
-python ./inference.py --model_type fast
+python -m hdi1 "A cat holding a sign that says 'hello world'"
+
+# or you can specify the model
+python -m hdi1 "A cat holding a sign that says 'hello world'" -m fast
 ```
 
-> **Note:** The inference script will automatically download `meta-llama/Meta-Llama-3.1-8B-Instruct` model files. If you encounter network issues, you can download these files ahead of time and place them in the appropriate cache directory to avoid download failures during inference.
-
-## Gradio Demo
+> [!NOTE]
+> The inference script will try to automatically download `meta-llama/Llama-3.1-8B-Instruct` model files. You need to [agree to the license of the Llama model](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) on your HuggingFace account and log in using `huggingface-cli login` in order to use the automatic downloader.
 
-We also provide a Gradio demo for interactive image generation. You can run the demo with:
+### Web Dashboard
+
+We also provide a web dashboard for interactive image generation. You can start it by running:
 
 ``` python
-python gradio_demo.py
+python -m hdi1.web
 ```
 
 ## Evaluation Metrics
 
 ### DPG-Bench
 
 | Model          | Overall   | Global | Entity | Attribute | Relation | Other |
 | -------------- | --------- | ------ | ------ | --------- | -------- | ----- |
 | PixArt-alpha   | 71.11     | 74.97  | 79.32  | 78.60     | 82.57    | 76.96 |
 | SDXL           | 74.65     | 83.27  | 82.43  | 80.91     | 86.76    | 80.41 |
 | DALL-E 3       | 83.50     | 90.97  | 89.61  | 88.39     | 90.58    | 89.83 |
 | Flux.1-dev     | 83.79     | 85.80  | 86.79  | 89.98     | 90.04    | 89.90 |
 | SD3-Medium     | 84.08     | 87.90  | 91.01  | 88.83     | 80.70    | 88.68 |
 | Janus-Pro-7B   | 84.19     | 86.90  | 88.90  | 89.40     | 89.32    | 89.48 |
 | CogView4-6B    | 85.13     | 83.85  | 90.35  | 91.17     | 91.14    | 87.29 |
 | **HiDream-I1** | **85.89** | 76.44  | 90.22  | 89.48     | 93.74    | 91.83 |
 
 ### GenEval
 
 | Model          | Overall  | Single Obj. | Two Obj. | Counting | Colors | Position | Color attribution |
 | -------------- | -------- | ----------- | -------- | -------- | ------ | -------- | ----------------- |
 | SDXL           | 0.55     | 0.98        | 0.74     | 0.39     | 0.85   | 0.15     | 0.23              |
 | PixArt-alpha   | 0.48     | 0.98        | 0.50     | 0.44     | 0.80   | 0.08     | 0.07              |
 | Flux.1-dev     | 0.66     | 0.98        | 0.79     | 0.73     | 0.77   | 0.22     | 0.45              |
 | DALL-E 3       | 0.67     | 0.96        | 0.87     | 0.47     | 0.83   | 0.43     | 0.45              |
 | CogView4-6B    | 0.73     | 0.99        | 0.86     | 0.66     | 0.79   | 0.48     | 0.58              |
 | SD3-Medium     | 0.74     | 0.99        | 0.94     | 0.72     | 0.89   | 0.33     | 0.60              |
 | Janus-Pro-7B   | 0.80     | 0.99        | 0.89     | 0.59     | 0.90   | 0.79     | 0.66              |
 | **HiDream-I1** | **0.83** | 1.00        | 0.98     | 0.79     | 0.91   | 0.60     | 0.72              |
 
 ### HPSv2.1 benchmark
 
 | Model                 | Averaged  | Animation | Concept-art | Painting | Photo |
 | --------------------- | --------- | --------- | ----------- | -------- | ----- |
 | Stable Diffusion v2.0 | 26.38     | 27.09     | 26.02       | 25.68    | 26.73 |
 | Midjourney V6         | 30.29     | 32.02     | 30.29       | 29.74    | 29.10 |
 | SDXL                  | 30.64     | 32.84     | 31.36       | 30.86    | 27.48 |
 | Dall-E3               | 31.44     | 32.39     | 31.09       | 31.18    | 31.09 |
 | SD3                   | 31.53     | 32.60     | 31.82       | 32.06    | 29.62 |
 | Midjourney V5         | 32.33     | 34.05     | 32.47       | 32.24    | 30.56 |
 | CogView4-6B           | 32.31     | 33.23     | 32.60       | 32.89    | 30.52 |
 | Flux.1-dev            | 32.47     | 33.87     | 32.27       | 32.62    | 31.11 |
 | stable cascade        | 32.95     | 34.58     | 33.13       | 33.29    | 30.78 |
 | **HiDream-I1**        | **33.82** | 35.05     | 33.74       | 33.88    | 32.61 |
 
 
 
 ## License
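To make the new CLI concrete: `python -m hdi1` (defined in `hdi1/__main__.py` below) reduces to a couple of library calls. A minimal programmatic sketch, assuming the `hdi1` package is installed and a CUDA GPU is available; the prompt and output path are placeholders:

```python
from hdi1.nf4 import load_models, generate_image

# "dev", "full", or "fast": the same choices as the CLI's -m flag
pipe, _ = load_models("fast")

# resolution is a (width, height) tuple; seed -1 picks a random seed
image, seed = generate_image(
    pipe, "fast",
    "A cat holding a sign that says 'hello world'",
    (1024, 1024), -1,
)
image.save("output.png")
print(f"Seed used: {seed}")
```

Guidance scale and step count come from the per-model config, so only the variant name, prompt, resolution, and seed need to be supplied.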
`gradio_demo.py` (deleted, -190 lines)

@@ -1,190 +0,0 @@
import torch
import gradio as gr
from hi_diffusers import HiDreamImagePipeline
from hi_diffusers import HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}

# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]

# Load models
def load_models(model_type):
    config = MODEL_CONFIGS[model_type]
    pretrained_model_name_or_path = config["path"]
    scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
        LLAMA_MODEL_NAME,
        use_fast=False)

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16).to("cuda")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16).to("cuda")

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model_name_or_path,
        scheduler=scheduler,
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16
    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer

    return pipe, config

# Parse resolution string to get height and width
def parse_resolution(resolution_str):
    if "1024 × 1024" in resolution_str:
        return 1024, 1024
    elif "768 × 1360" in resolution_str:
        return 768, 1360
    elif "1360 × 768" in resolution_str:
        return 1360, 768
    elif "880 × 1168" in resolution_str:
        return 880, 1168
    elif "1168 × 880" in resolution_str:
        return 1168, 880
    elif "1248 × 832" in resolution_str:
        return 1248, 832
    elif "832 × 1248" in resolution_str:
        return 832, 1248
    else:
        return 1024, 1024  # Default fallback

# Generate image function
def generate_image(model_type, prompt, resolution, seed):
    global pipe, current_model

    # Reload model if needed
    if model_type != current_model:
        del pipe
        torch.cuda.empty_cache()
        print(f"Loading {model_type} model...")
        pipe, config = load_models(model_type)
        current_model = model_type
        print(f"{model_type} model loaded successfully!")

    # Get configuration for current model
    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]

    # Parse resolution
    height, width = parse_resolution(resolution)

    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()

    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images

    return images[0], seed

# Initialize with default model
print("Loading default model (full)...")
current_model = "full"
pipe, _ = load_models(current_model)
print("Model loaded successfully!")

# Create Gradio interface
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("# HiDream Image Generator")

    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="full",
                label="Model Type",
                info="Select model variant"
            )

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                lines=3
            )

            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
                info="Select image resolution"
            )

            seed = gr.Number(
                label="Seed (use -1 for random)",
                value=-1,
                precision=0
            )

            generate_btn = gr.Button("Generate Image")
            seed_used = gr.Number(label="Seed Used", interactive=False)

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used]
    )

# Launch app
if __name__ == "__main__":
    demo.launch()
`hdi1/__main__.py` (new file, +43)

@@ -0,0 +1,43 @@
from .nf4 import *

import argparse
import time
import logging


if __name__ == "__main__":
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    parser = argparse.ArgumentParser()

    parser.add_argument("prompt", type=str, help="Prompt to generate image from")

    parser.add_argument("-m", "--model", type=str, default="dev",
                        help="Model to use",
                        choices=["dev", "full", "fast"])

    parser.add_argument("-s", "--seed", type=int, default=-1,
                        help="Seed for generation")

    parser.add_argument("-r", "--res", type=str, default="1024x1024",
                        help="Resolution for generation",
                        choices=["1024x1024", "768x1360", "1360x768", "880x1168", "1168x880", "1248x832", "832x1248"])

    parser.add_argument("-o", "--output", type=str, default="output.png")

    args = parser.parse_args()
    model_type = args.model

    # Load the requested model variant
    print(f"Loading model {model_type}...")
    pipe, _ = load_models(model_type)
    print("Model loaded successfully!")

    st = time.time()

    resolution = tuple(map(int, args.res.strip().split("x")))
    image, seed = generate_image(pipe, model_type, args.prompt, resolution, args.seed)
    image.save(args.output)

    print(f"Image saved to {args.output}, elapsed time: {time.time() - st:.2f} seconds")
    print(f"Seed used: {seed}")
`hdi1/nf4.py` (new file, +108)

@@ -0,0 +1,108 @@
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

from . import HiDreamImagePipeline
from . import HiDreamImageTransformer2DModel
from .schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler


MODEL_PREFIX = "azaneko"
LLAMA_MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"


# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full-nf4",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast-nf4",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}


def log_vram(msg: str):
    print(f"{msg} (used {torch.cuda.memory_allocated() / 1024**2:.2f} MB VRAM)\n")


def load_models(model_type: str):
    config = MODEL_CONFIGS[model_type]

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
    log_vram("✅ Tokenizer loaded!")

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        return_dict_in_generate=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    log_vram("✅ Text encoder loaded!")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        config["path"],
        subfolder="transformer",
        torch_dtype=torch.bfloat16
    )
    log_vram("✅ Transformer loaded!")

    pipe = HiDreamImagePipeline.from_pretrained(
        config["path"],
        scheduler=config["scheduler"](num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False),
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16,
    )
    pipe.transformer = transformer
    log_vram("✅ Pipeline loaded!")
    pipe.enable_sequential_cpu_offload()

    return pipe, config


@torch.inference_mode()
def generate_image(pipe: HiDreamImagePipeline, model_type: str, prompt: str, resolution: tuple[int, int], seed: int):
    # Get configuration for current model
    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]

    # Parse resolution
    width, height = resolution

    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()

    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images

    return images[0], seed
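A note on the memory behavior above: `load_models` already prints allocated VRAM after each stage via `log_vram`, and `enable_sequential_cpu_offload()` streams submodules to the GPU only while they run, which is how the pipeline stays within the 16 GB budget the README promises. A small sketch of bracketing a load with the same helper (assumes a CUDA device and the `hdi1` package installed; printed numbers vary by GPU):

```python
from hdi1.nf4 import load_models, log_vram

log_vram("Before load")             # baseline allocation on this device
pipe, config = load_models("fast")  # itself logs VRAM after each component
log_vram("After load")              # compare against the baseline

# Per-variant settings ride along in the returned config
print(config["num_inference_steps"], config["guidance_scale"])  # 16 0.0 for "fast"
```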
`hdi1/web.py` (new file, +323)

@@ -0,0 +1,323 @@
import torch
import gradio as gr
import logging
import os
import tempfile
import glob
from datetime import datetime
from PIL import Image
from .nf4 import *

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Output directory for saving images
OUTPUT_DIR = os.path.join("outputs")

# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]

# Scheduler options (flow-matching only)
SCHEDULER_OPTIONS = [
    "FlashFlowMatchEulerDiscreteScheduler",
    "FlowUniPCMultistepScheduler"
]

# Image format options
IMAGE_FORMAT_OPTIONS = ["PNG", "JPEG", "WEBP"]

# Parse resolution string to get width and height
def parse_resolution(resolution_str):
    try:
        return tuple(map(int, resolution_str.split("(")[0].strip().split(" × ")))
    except (ValueError, IndexError) as e:
        raise ValueError("Invalid resolution format") from e

def clean_previous_temp_files():
    """Delete temporary files from previous generations matching the hdi1_* pattern and log Gradio temp files."""
    temp_dir = tempfile.gettempdir()
    patterns = [os.path.join(temp_dir, f"hdi1_*.{ext}") for ext in ["png", "jpeg", "webp"]]
    deleted_files = []

    # Clean hdi1_* files
    for pattern in patterns:
        for temp_file in glob.glob(pattern):
            try:
                os.remove(temp_file)
                deleted_files.append(temp_file)
                logger.info(f"Deleted temporary file: {temp_file}")
            except OSError as e:
                logger.warning(f"Failed to delete temporary file {temp_file}: {str(e)}")

    # Log Gradio temp files (for monitoring)
    gradio_temp_dir = os.path.join(temp_dir, "gradio")
    if os.path.exists(gradio_temp_dir):
        for root, _, files in os.walk(gradio_temp_dir):
            for file in files:
                if file.endswith((".png", ".jpeg", ".webp")):
                    gradio_file = os.path.join(root, file)
                    logger.info(f"Found Gradio temporary file: {gradio_file}")

    return deleted_files

def clean_all_temp_files():
    """Manually clean hdi1_* and Gradio temporary files."""
    status_message = "Starting temporary file cleanup..."
    logger.info(status_message)

    try:
        # Clean hdi1_* files
        deleted_files = clean_previous_temp_files()

        # Clean Gradio temp files
        temp_dir = tempfile.gettempdir()
        gradio_temp_dir = os.path.join(temp_dir, "gradio")
        if os.path.exists(gradio_temp_dir):
            for root, _, files in os.walk(gradio_temp_dir):
                for file in files:
                    if file.endswith((".png", ".jpeg", ".webp")):
                        gradio_file = os.path.join(root, file)
                        try:
                            os.remove(gradio_file)
                            deleted_files.append(gradio_file)
                            logger.info(f"Deleted Gradio temporary file: {gradio_file}")
                        except OSError as e:
                            logger.warning(f"Failed to delete Gradio temporary file {gradio_file}: {str(e)}")

        status_message = f"Cleanup complete. Deleted {len(deleted_files)} files."
        logger.info(status_message)
        return status_message
    except Exception as e:
        error_message = f"Cleanup error: {str(e)}"
        logger.error(error_message)
        return error_message

def gen_img_helper(model, prompt, res, seed, scheduler, guidance_scale, num_inference_steps, shift, image_format):
    global pipe, current_model
    status_message = "Starting image generation..."

    try:
        # Clean up previous temporary files
        status_message = "Cleaning up previous temporary files..."
        logger.info(status_message)
        clean_previous_temp_files()
        status_message = "Previous temporary files cleaned."

        # Validate inputs
        if not prompt or len(prompt.strip()) == 0:
            raise ValueError("Prompt cannot be empty")
        if not isinstance(seed, (int, float)) or seed < -1:
            raise ValueError("Seed must be -1 or a non-negative integer")
        if num_inference_steps < 1 or num_inference_steps > 100:
            raise ValueError("Number of inference steps must be between 1 and 100")
        if guidance_scale < 0 or guidance_scale > 10:
            raise ValueError("Guidance scale must be between 0 and 10")
        if shift < 1 or shift > 10:
            raise ValueError("Shift must be between 1 and 10")

        # 1. Check if the model matches the loaded model; load it if not
        if model != current_model:
            status_message = f"Unloading model {current_model}..."
            logger.info(status_message)
            if pipe is not None:
                del pipe
                torch.cuda.empty_cache()

            status_message = f"Loading model {model}..."
            logger.info(status_message)
            pipe, _ = load_models(model)
            current_model = model
            status_message = "Model loaded successfully!"
            logger.info(status_message)

        # 2. Update scheduler
        config = MODEL_CONFIGS[model]
        scheduler_map = {
            "FlashFlowMatchEulerDiscreteScheduler": FlashFlowMatchEulerDiscreteScheduler,
            "FlowUniPCMultistepScheduler": FlowUniPCMultistepScheduler
        }
        if scheduler not in scheduler_map:
            raise ValueError(f"Invalid scheduler: {scheduler}")
        scheduler_class = scheduler_map[scheduler]
        device = pipe._execution_device

        # Set scheduler with shift for flow-matching schedulers
        pipe.scheduler = scheduler_class(num_train_timesteps=1000, shift=shift, use_dynamic_shifting=False)

        # 3. Generate image
        status_message = "Generating image..."
        logger.info(status_message)
        res = parse_resolution(res)
        image, seed = generate_image(pipe, model, prompt, res, seed, guidance_scale, num_inference_steps)

        # 4. Save image locally in the selected format
        status_message = "Saving image locally..."
        logger.info(status_message)
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        file_extension = image_format.lower()
        output_path = os.path.join(OUTPUT_DIR, f"output_{timestamp}.{file_extension}")
        if image_format == "JPEG":
            image = image.convert("RGB")  # JPEG doesn't support RGBA
        image.save(output_path, format=image_format)
        logger.info(f"Image saved to {output_path}")

        # 5. Prepare image for download in the selected format
        status_message = "Preparing image for download..."
        logger.info(status_message)
        download_filename = f"generated_image_{timestamp}.{file_extension}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_extension}", prefix="hdi1_") as temp_file:
            if image_format == "JPEG":
                image = image.convert("RGB")  # Ensure JPEG compatibility
            image.save(temp_file, format=image_format)
            temp_file_path = temp_file.name
            logger.info(f"Temporary file created at {temp_file_path}")

        status_message = "Image generation complete!"
        logger.info(status_message)
        return image, seed, f"Image saved to: {output_path}", temp_file_path, status_message

    except Exception as e:
        error_message = f"Error: {str(e)}"
        logger.error(error_message)
        return None, None, None, None, error_message

def generate_image(pipe, model_type, prompt, resolution, seed, guidance_scale, num_inference_steps):
    try:
        # Parse resolution
        width, height = resolution

        # Handle seed
        if seed == -1:
            seed = torch.randint(0, 1000000, (1,)).item()

        generator = torch.Generator("cuda").manual_seed(seed)

        # Common parameters
        params = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "num_images_per_prompt": 1,
            "generator": generator
        }

        images = pipe(**params).images
        return images[0], seed
    except Exception as e:
        raise RuntimeError(f"Image generation failed: {str(e)}") from e

if __name__ == "__main__":
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    # Initialize globals without loading a model
    current_model = None
    pipe = None

    # Create Gradio interface
    with gr.Blocks(title="HiDream-I1-nf4 Dashboard") as demo:
        gr.Markdown("# HiDream-I1-nf4 Dashboard")
        gr.Markdown("**Note**: Use the 'Download Image' link below to download the image in your selected format (PNG, JPEG, or WEBP). The image preview's own download button always uses WEBP format.")

        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=list(MODEL_CONFIGS.keys()),
                    value="fast",
                    label="Model Type",
                    info="Select model variant (e.g., 'fast' for quick generation)"
                )

                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                    lines=3
                )

                resolution = gr.Radio(
                    choices=RESOLUTION_OPTIONS,
                    value=RESOLUTION_OPTIONS[0],
                    label="Resolution",
                    info="Select image resolution"
                )

                seed = gr.Number(
                    label="Seed (use -1 for random)",
                    value=-1,
                    precision=0
                )

                scheduler = gr.Radio(
                    choices=SCHEDULER_OPTIONS,
                    value="FlashFlowMatchEulerDiscreteScheduler",
                    label="Scheduler",
                    info="Select scheduler type. Flow-matching schedulers are optimized for HiDream, providing stable, high-quality, prompt-relevant images."
                )

                guidance_scale = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.0,
                    label="Guidance Scale",
                    info="Controls prompt adherence. Use 2.0–5.0; increase to 4.0–5.0 for stronger prompt following."
                )

                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=25,
                    label="Number of Inference Steps",
                    info="Controls denoising steps. Use 25–50; increase to 40–50 for sharper images."
                )

                shift = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.0,
                    label="Shift",
                    info="Scheduler shift parameter for flow-matching schedulers. Use 1.0–5.0; 3.0 is a good default."
                )

                image_format = gr.Radio(
                    choices=IMAGE_FORMAT_OPTIONS,
                    value="PNG",
                    label="Image Format",
                    info="Select the format for the saved and downloaded image."
                )

                generate_btn = gr.Button("Generate Image")
                cleanup_btn = gr.Button("Clean Temporary Files")

            with gr.Column():
                status_message = gr.Textbox(label="Status", value="Ready", interactive=False)
                output_image = gr.Image(label="Generated Image", type="pil")
                seed_used = gr.Number(label="Seed Used", interactive=False)
                save_path = gr.Textbox(label="Saved Image Path", interactive=False)
                download_file = gr.File(label="Download Image", interactive=False, file_types=[".png", ".jpeg", ".webp"])

        generate_btn.click(
            fn=gen_img_helper,
            inputs=[model_type, prompt, resolution, seed, scheduler, guidance_scale, num_inference_steps, shift, image_format],
            outputs=[output_image, seed_used, save_path, download_file, status_message]
        )
        cleanup_btn.click(
            fn=clean_all_temp_files,
            inputs=[],
            outputs=[status_message]
        )

    demo.launch(share=True, pwa=True)
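One detail worth calling out in `hdi1/web.py`: `parse_resolution` derives the `(width, height)` tuple by splitting the radio-button label before the parenthesized suffix and then on the ` × ` separator. A quick hypothetical REPL check (not part of the repo):

```python
from hdi1.web import parse_resolution

assert parse_resolution("1024 × 1024 (Square)") == (1024, 1024)
assert parse_resolution("880 × 1168 (Portrait)") == (880, 1168)  # width 880, height 1168

# Malformed labels raise rather than silently falling back
try:
    parse_resolution("whatever")
except ValueError as e:
    print(e)  # "Invalid resolution format"
```

Unlike the deleted `gradio_demo.py` version, which silently fell back to 1024 × 1024 for unknown labels, this helper raises on malformed input.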
`inference.py` (deleted, -138 lines)

@@ -1,138 +0,0 @@
import torch
import argparse
from hi_diffusers import HiDreamImagePipeline
from hi_diffusers import HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default="dev")
args = parser.parse_args()
model_type = args.model_type

MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}

# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]

# Load models
def load_models(model_type):
    config = MODEL_CONFIGS[model_type]
    pretrained_model_name_or_path = config["path"]
    scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
        LLAMA_MODEL_NAME,
        use_fast=False)

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16).to("cuda")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16).to("cuda")

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model_name_or_path,
        scheduler=scheduler,
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16
    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer

    return pipe, config

# Parse resolution string to get height and width
def parse_resolution(resolution_str):
    if "1024 × 1024" in resolution_str:
        return 1024, 1024
    elif "768 × 1360" in resolution_str:
        return 768, 1360
    elif "1360 × 768" in resolution_str:
        return 1360, 768
    elif "880 × 1168" in resolution_str:
        return 880, 1168
    elif "1168 × 880" in resolution_str:
        return 1168, 880
    elif "1248 × 832" in resolution_str:
        return 1248, 832
    elif "832 × 1248" in resolution_str:
        return 832, 1248
    else:
        return 1024, 1024  # Default fallback

# Generate image function
def generate_image(pipe, model_type, prompt, resolution, seed):
    # Get configuration for current model
    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]

    # Parse resolution
    height, width = parse_resolution(resolution)

    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()

    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images

    return images[0], seed

# Initialize with default model
print("Loading default model (full)...")
pipe, _ = load_models(model_type)
print("Model loaded successfully!")

prompt = "A cat holding a sign that says \"Hi-Dreams.ai\"."
resolution = "1024 × 1024 (Square)"
seed = -1

image, seed = generate_image(pipe, model_type, prompt, resolution, seed)
image.save("output.png")
`pyproject.toml` (new file, +37)

@@ -0,0 +1,37 @@
[project]
name = "hdi1"
version = "1.0.4"
description = "HiDream-I1 is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds."
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "accelerate>=1.2.1",
    "bitsandbytes>=0.45.5",
    "datasets>=3.5.0",
    "device-smi>=0.4.1",
    "diffusers>=0.32.1",
    "einops>=0.7.0",
    "flash-attn>=2.7.0",
    "gptqmodel>=2.2.0",
    "gradio>=5.24.0",
    "hf-transfer>=0.1.9",
    "logbar>=0.0.4",
    "optimum>=1.24.0",
    "pillow>=11.1.0",
    "protobuf>=6.30.2",
    "sentencepiece>=0.2.0",
    "setuptools>=78.1.0",
    "threadpoolctl>=3.6.0",
    "tokenicer>=0.0.4",
    "torch>=2.5.1",
    "torchvision>=0.20.1",
    "transformers>=4.47.1",
]

[tool.setuptools.packages.find]
include = ["hdi1*"]

[project.urls]
Homepage = "https://github.com/hykilpikonna/HiDream-I1-nf4"
Repository = "https://github.com/hykilpikonna/HiDream-I1-nf4"
Issues = "https://github.com/hykilpikonna/HiDream-I1-nf4/issues"
`requirements.txt` (+3, -1)

@@ -2,4 +2,6 @@ torch>=2.5.1
 torchvision>=0.20.1
 diffusers>=0.32.1
 transformers>=4.47.1
-einops>=0.7.0
+einops>=0.7.0
+accelerate>=1.2.1
+sentencepiece

(The paired removal and re-addition of `einops>=0.7.0` most likely reflects a missing trailing newline on the old last line.)