Adding docstrings

Author: Shivam Mehta
Date: 2023-09-17 06:50:46 +00:00
Parent: c079e5254a
Commit: a9251ed984
10 changed files with 47 additions and 32 deletions


@@ -2,8 +2,13 @@ from typing import Any, Dict, Optional
 import torch
 import torch.nn as nn
-from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
-                                        AdaLayerNormZero, ApproximateGELU)
+from diffusers.models.attention import (
+    GEGLU,
+    GELU,
+    AdaLayerNorm,
+    AdaLayerNormZero,
+    ApproximateGELU,
+)
 from diffusers.models.attention_processor import Attention
 from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.utils.torch_utils import maybe_allow_in_graph
@@ -38,7 +43,7 @@ class SnakeBeta(nn.Module):
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
-        super(SnakeBeta, self).__init__()
+        super().__init__()
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)
@@ -73,8 +78,8 @@ class SnakeBeta(nn.Module):
        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
        return x
 class FeedForward(nn.Module):
    r"""
    A feed-forward layer.
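
Note: the activation that SnakeBeta.forward applies in the -73/+78 hunk above reduces to x + (1 / (beta + eps)) * sin^2(alpha * x). A minimal standalone sketch of that formula follows; the scalar alpha and beta stand in for the module's trainable parameters, and the tensor shape is invented for illustration.

import torch

def snake_beta(x, alpha, beta, eps=1e-9):
    # x + (1 / (beta + eps)) * sin^2(alpha * x); eps plays the role of
    # self.no_div_by_zero in the class above.
    return x + (1.0 / (beta + eps)) * torch.pow(torch.sin(x * alpha), 2)

x = torch.linspace(-3.0, 3.0, steps=7)
print(snake_beta(x, alpha=torch.tensor(1.0), beta=torch.tensor(1.0)))
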
@@ -127,8 +132,7 @@ class FeedForward(nn.Module):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
@@ -217,7 +221,7 @@ class BasicTransformerBlock(nn.Module):
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
-                # scale_qk=False, # uncomment this to not to use flash attention
+                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
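
Note: the -217/+221 hunk above passes these keyword arguments to diffusers' Attention; per the inline comment, uncommenting scale_qk=False is the switch for opting out of the flash-attention path. A rough construction sketch follows, with made-up dimensions (query_dim, heads, dim_head are illustrative, not the values used in this file).

import torch
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=256,        # illustrative embedding size
    heads=4,
    dim_head=64,
    dropout=0.0,
    bias=False,
    upcast_attention=False,
    # scale_qk=False,     # per the comment in the hunk above, uncomment to avoid flash attention
)
hidden_states = torch.randn(2, 100, 256)
out = attn(hidden_states)  # self-attention, since encoder_hidden_states is None
print(out.shape)           # torch.Size([2, 100, 256])
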
@@ -309,4 +313,4 @@ class BasicTransformerBlock(nn.Module):
        hidden_states = ff_output + hidden_states
-        return hidden_states
+        return hidden_states
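
Note: taken together, the FeedForward and final BasicTransformerBlock hunks amount to running the hidden states through the feed-forward stack module by module, then adding the result back onto the input. A toy sketch, with an assumed nn.Sequential standing in for self.net and invented shapes:

import torch
import torch.nn as nn

batch, seq_len, dim = 2, 16, 64  # invented shapes
hidden_states = torch.randn(batch, seq_len, dim)

# Stand-in for FeedForward.net: modules are applied one after another,
# mirroring the loop in the -127/+132 hunk.
net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

ff_output = hidden_states
for module in net:
    ff_output = module(ff_output)

hidden_states = ff_output + hidden_states  # residual add from the final hunk
print(hidden_states.shape)                 # torch.Size([2, 16, 64])
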