```python
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torchvision.transforms as T
import argparse
```
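For orientation, here is a minimal sketch of how these imports are typically wired into the pipeline object whose methods appear below. The model ID and the exact attribute wiring are assumptions, not part of the original excerpt:

```python
# Assumed setup sketch: load the Stable Diffusion submodules that the
# methods below access as self.tokenizer, self.text_encoder, self.unet, ...
device = torch.device('cuda')
model_key = 'runwayml/stable-diffusion-v1-5'  # hypothetical model ID

tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder='text_encoder').to(device)
vae = AutoencoderKL.from_pretrained(model_key, subfolder='vae').to(device)
unet = UNet2DConditionModel.from_pretrained(model_key, subfolder='unet').to(device)
scheduler = DDIMScheduler.from_pretrained(model_key, subfolder='scheduler')
```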
```python
def get_text_embeds(self, prompt, negative_prompt):
    # prompt, negative_prompt: [str]

    # Tokenize text and get embeddings
    text_input = self.tokenizer(prompt, padding='max_length',
                                max_length=self.tokenizer.model_max_length,
                                truncation=True, return_tensors='pt')
    text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

    # Do the same for unconditional embeddings
    uncond_input = self.tokenizer(negative_prompt, padding='max_length',
                                  max_length=self.tokenizer.model_max_length,
                                  return_tensors='pt')
    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

    # Cat for final embeddings
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    return text_embeddings
```
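The sampling loop below slices the panorama latent into overlapping windows via a `get_views` helper that is not shown in this excerpt. A minimal sketch, assuming the common MultiDiffusion convention of 64×64 latent windows swept with a stride of 8 (the `window_size` and `stride` defaults are assumptions):

```python
def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    # Work in latent space: one latent cell covers an 8x8 pixel patch.
    panorama_height //= 8
    panorama_width //= 8
    num_blocks_height = (panorama_height - window_size) // stride + 1
    num_blocks_width = (panorama_width - window_size) // stride + 1
    views = []
    for i in range(num_blocks_height * num_blocks_width):
        h_start = (i // num_blocks_width) * stride
        h_end = h_start + window_size
        w_start = (i % num_blocks_width) * stride
        w_end = w_start + window_size
        views.append((h_start, h_end, w_start, w_end))
    return views
```

For a 512×2048 panorama this yields 25 overlapping 64×64 latent windows along the width.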
```python
# Define panorama grid and get views
latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
views = get_views(height, width)
count = torch.zeros_like(latent)
value = torch.zeros_like(latent)

self.scheduler.set_timesteps(num_inference_steps)

with torch.autocast('cuda'):
    for i, t in enumerate(self.scheduler.timesteps):
        count.zero_()
        value.zero_()

        for h_start, h_end, w_start, w_end in views:
            # TODO we can support batches, and pass multiple views at once to the unet
            latent_view = latent[:, :, h_start:h_end, w_start:w_end]

            # expand the latents if we are doing classifier-free guidance
            # to avoid doing two forward passes
            latent_model_input = torch.cat([latent_view] * 2)

            # predict the noise residual (text_embeds is the [uncond, cond]
            # embedding pair returned by get_text_embeds above)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
```
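The excerpt stops at the noise prediction. For completeness, a hedged sketch of how such a loop is typically finished: split the stacked prediction for classifier-free guidance, step the scheduler on each window, and average the overlapping denoised views through the `value`/`count` buffers zeroed at the top of each timestep (`guidance_scale` as a parameter of the enclosing function is an assumption):

```python
            # --- sketch of the typical continuation, not part of the excerpt ---

            # perform classifier-free guidance: the first half of the batch is
            # the unconditional prediction, the second half the conditional one
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # denoise this window and accumulate it into the panorama buffers
            latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
            value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
            count[:, :, h_start:h_end, w_start:w_end] += 1

        # MultiDiffusion step: average all overlapping window predictions
        latent = torch.where(count > 0, value / count, value)
```

Averaging the per-window predictions is what reconciles the overlapping views into one coherent panorama latent at every denoising step.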