update readme and fix pipeline text length
@@ -2,4 +2,6 @@ __pycache__
 tmp
 *_local.py
 *.jpg
 *.png
+*.tar
+*.txt

@@ -1,5 +1,7 @@
 # HiDream-I1
 
+
+
 `HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds.
 
 ## Project Updates

@@ -176,10 +176,10 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.text_encoder_3.model_max_length - 1 : -1])
+            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
-                f" {self.text_encoder_3.model_max_length} tokens: {removed_text}"
+                f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
             )
 
         prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
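
The hunk above (and the matching one for `tokenizer_4` below) only changes which limit is used when reporting truncated text; the surrounding code is the usual "warn about dropped tokens" check. Below is a minimal, self-contained sketch of that check, with a hypothetical `fake_tokenize` helper and dummy token ids standing in for a real tokenizer; only the `min(...)`, the equality test, and the slice mirror the pipeline code.

```python
import torch

EOS = 0  # stand-in end-of-sequence id

def fake_tokenize(tokens, max_length=None):
    # Stand-in for a tokenizer call: optionally truncate to `max_length`,
    # keeping an EOS token in the last kept slot (as real tokenizers do).
    ids = list(tokens)
    if max_length is not None and len(ids) > max_length:
        ids = ids[: max_length - 1] + [EOS]
    return torch.tensor([ids])

prompt_tokens = [101, 102, 103, 104, 105, 106, 107, EOS]  # 8 dummy token ids
max_sequence_length = 5
model_max_length = 77  # illustrative value, not taken from the repository
limit = min(max_sequence_length, model_max_length)

text_input_ids = fake_tokenize(prompt_tokens, max_length=limit)
untruncated_ids = fake_tokenize(prompt_tokens)

# Same check as the pipeline: torch.equal() is False for differently sized
# tensors, so any truncation triggers the warning branch.
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
    # Positions limit-1 .. -2 are the tokens that truncation dropped
    # (the final slot of the untruncated ids is its own EOS).
    removed = untruncated_ids[:, limit - 1 : -1]
    print(f"truncated to {limit} tokens, removed ids: {removed.tolist()}")  # [[105, 106, 107]]
```
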
@@ -262,10 +262,10 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, self.text_encoder_4.model_max_length - 1 : -1])
+            removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because `max_sequence_length` is set to "
-                f" {self.text_encoder_4.model_max_length} tokens: {removed_text}"
+                f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
             )
 
         outputs = self.text_encoder_4(
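
Why the `min(...)` matters in both hunks: `text_input_ids` is produced by truncating the prompt to `min(max_sequence_length, tokenizer.model_max_length)`, so slicing the untruncated ids at `text_encoder.model_max_length - 1`, as the old code did, points at the wrong offset whenever the two limits differ, and the warning then reports the wrong limit and the wrong removed text. A small numeric sketch with illustrative, assumed values (not taken from the repository):

```python
import torch

max_sequence_length = 128  # caller-supplied limit (assumed)
model_max_length = 218     # encoder/tokenizer limit (assumed)

untruncated_len = 150  # prompt longer than the effective 128-token limit
untruncated_ids = torch.arange(untruncated_len).unsqueeze(0)

# Old offset: starts at 217, past the end of the sequence -> nothing to report.
old_report = untruncated_ids[:, model_max_length - 1 : -1]
# New offset: starts at 127, exactly where truncation began.
new_report = untruncated_ids[:, min(max_sequence_length, model_max_length) - 1 : -1]

print(old_report.shape[-1])  # 0   -> the old warning would show an empty string
print(new_report.shape[-1])  # 22  -> the tokens that were actually dropped
```
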