update readme and fix pipeline text length
This commit is contained in:
+3
-1
@@ -2,4 +2,6 @@ __pycache__
|
||||
tmp
|
||||
*_local.py
|
||||
*.jpg
|
||||
*.png
|
||||
*.png
|
||||
*.tar
|
||||
*.txt
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# HiDream-I1
|
||||
|
||||

|
||||
|
||||
`HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds.
|
||||
|
||||
## Project Updates
|
||||
|
||||
@@ -176,10 +176,10 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.text_encoder_3.model_max_length - 1 : -1])
|
||||
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {self.text_encoder_3.model_max_length} tokens: {removed_text}"
|
||||
f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
|
||||
@@ -262,10 +262,10 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, self.text_encoder_4.model_max_length - 1 : -1])
|
||||
removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {self.text_encoder_4.model_max_length} tokens: {removed_text}"
|
||||
f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
outputs = self.text_encoder_4(
|
||||
|
||||
Reference in New Issue
Block a user