update readme and fix pipeline text length

This commit is contained in:
Qi Cai
2025-04-07 22:54:41 +08:00
parent 9bf3ac0e4f
commit 35197292da
3 changed files with 9 additions and 5 deletions
+2
View File
@@ -3,3 +3,5 @@ tmp
*_local.py *_local.py
*.jpg *.jpg
*.png *.png
*.tar
*.txt
+2
View File
@@ -1,5 +1,7 @@
# HiDream-I1 # HiDream-I1
![HiDream-I1 Demo](assets/demo.jpg)
`HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds. `HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds.
## Project Updates ## Project Updates
@@ -176,10 +176,10 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.text_encoder_3.model_max_length - 1 : -1]) removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1])
logger.warning( logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to " "The following part of your input was truncated because `max_sequence_length` is set to "
f" {self.text_encoder_3.model_max_length} tokens: {removed_text}" f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
) )
prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
@@ -262,10 +262,10 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, self.text_encoder_4.model_max_length - 1 : -1]) removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1])
logger.warning( logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to " "The following part of your input was truncated because `max_sequence_length` is set to "
f" {self.text_encoder_4.model_max_length} tokens: {removed_text}" f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
) )
outputs = self.text_encoder_4( outputs = self.text_encoder_4(