From f4058373897e074b467df716182b9b51ffb6e902 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Thu, 30 Jun 2022 13:47:46 -0700
Subject: [PATCH 01/26] create filestructure for science application

---
 src/diffusers/models/dualencoder_gfn.py | 1018 +++++++++++++++++++++++
 1 file changed, 1018 insertions(+)
 create mode 100644 src/diffusers/models/dualencoder_gfn.py

diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py
new file mode 100644
index 000000000000..288c329526c1
--- /dev/null
+++ b/src/diffusers/models/dualencoder_gfn.py
@@ -0,0 +1,1018 @@
+# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch import Tensor
+from typing import Callable, Union
+from torch.nn import Module, Sequential, ModuleList, Linear, Embedding
+from torch_scatter import scatter_add, scatter_mean
+from torch_geometric.nn import MessagePassing, radius_graph, radius
+from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
+from torch_geometric.utils import to_dense_adj, dense_to_sparse
+from torch_sparse import SparseTensor, coalesce
+import numpy as np
+from tqdm.auto import tqdm
+from rdkit.Chem.rdchem import BondType as BT
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
+
+
+BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
+
+
+class MultiLayerPerceptron(nn.Module):
+    """
+    Multi-layer Perceptron.
+    Note there is no activation or dropout in the last layer.
+    Parameters:
+        input_dim (int): input dimension
+        hidden_dims (list of int): hidden dimensions
+        activation (str or function, optional): activation function
+        dropout (float, optional): dropout rate
+    """
+
+    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
+        super(MultiLayerPerceptron, self).__init__()
+
+        self.dims = [input_dim] + hidden_dims
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = None
+        if dropout:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+        self.layers = nn.ModuleList()
+        for i in range(len(self.dims) - 1):
+            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
+
+    def forward(self, input):
+        """"""
+        x = input
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+            if i < len(self.layers) - 1:
+                if self.activation:
+                    x = self.activation(x)
+                if self.dropout:
+                    x = self.dropout(x)
+        return x
+
+
+class ShiftedSoftplus(torch.nn.Module):
+    def __init__(self):
+        super(ShiftedSoftplus, self).__init__()
+        self.shift = torch.log(torch.tensor(2.0)).item()
+
+    def forward(self, x):
+        return F.softplus(x) - self.shift
+
+
+class CFConv(MessagePassing):
+    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff, smooth):
+        super(CFConv, self).__init__(aggr="add")
+        self.lin1 = Linear(in_channels, num_filters, bias=False)
+        self.lin2 = Linear(num_filters, out_channels)
+        self.nn = nn
+        self.cutoff = cutoff
+        self.smooth = smooth
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.lin1.weight)
+        torch.nn.init.xavier_uniform_(self.lin2.weight)
+        self.lin2.bias.data.fill_(0)
+
+    def forward(self, x, edge_index, edge_length, edge_attr):
+        if self.smooth:
+            C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)
+            C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0)  # Modification: cutoff
+        else:
+            C = (edge_length <= self.cutoff).float()
+        W = self.nn(edge_attr) * C.view(-1, 1)
+
+        x = self.lin1(x)
+        x = self.propagate(edge_index, x=x, W=W)
+        x = self.lin2(x)
+        return x
+
+    def message(self, x_j, W):
+        return x_j * W
+
+
+class InteractionBlock(torch.nn.Module):
+    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):
+        super(InteractionBlock, self).__init__()
+        mlp = Sequential(
+            Linear(num_gaussians, num_filters),
+            ShiftedSoftplus(),
+            Linear(num_filters, num_filters),
+        )
+        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)
+        self.act = ShiftedSoftplus()
+        self.lin = Linear(hidden_channels, hidden_channels)
+
+    def forward(self, x, edge_index, edge_length, edge_attr):
+        x = self.conv(x, edge_index, edge_length, edge_attr)
+        x = self.act(x)
+        x = self.lin(x)
+        return x
+
+
+class SchNetEncoder(Module):
+    def __init__(
+        self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False
+    ):
+        super().__init__()
+
+        self.hidden_channels = hidden_channels
+        self.num_filters = num_filters
+        self.num_interactions = num_interactions
+        self.cutoff = cutoff
+
+        self.embedding = Embedding(100, hidden_channels, max_norm=10.0)
+
+        self.interactions = ModuleList()
+        for _ in range(num_interactions):
+            block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)
+            self.interactions.append(block)
+
+    def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):
+        if embed_node:
+            assert z.dim() == 1 and z.dtype == torch.long
+            h = self.embedding(z)
+        else:
+            h = z
+        for interaction in self.interactions:
+            h = h + interaction(h, edge_index, edge_length, edge_attr)
+
+        return h
+
+
+class GINEConv(MessagePassing):
+    def __init__(self, nn: Callable, eps: float = 0.0, train_eps: bool = False, activation="softplus", **kwargs):
+        super(GINEConv, self).__init__(aggr="add", **kwargs)
+        self.nn = nn
+        self.initial_eps = eps
+
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = None
+
+        if train_eps:
+            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
+        else:
+            self.register_buffer("eps", torch.Tensor([eps]))
+
+    def forward(
+        self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None
+    ) -> torch.Tensor:
+        """"""
+        if isinstance(x, torch.Tensor):
+            x: OptPairTensor = (x, x)
+
+        # Node and edge feature dimensionalities need to match.
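+        # (The `message` function below computes x_j + edge_attr, so node and edge
+        # embeddings must share the same feature dimension for the sum to be valid.)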
+        if isinstance(edge_index, torch.Tensor):
+            assert edge_attr is not None
+            assert x[0].size(-1) == edge_attr.size(-1)
+        elif isinstance(edge_index, SparseTensor):
+            assert x[0].size(-1) == edge_index.size(-1)
+
+        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
+        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
+
+        x_r = x[1]
+        if x_r is not None:
+            out += (1 + self.eps) * x_r
+
+        return self.nn(out)
+
+    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
+        if self.activation:
+            return self.activation(x_j + edge_attr)
+        else:
+            return x_j + edge_attr
+
+    def __repr__(self):
+        return "{}(nn={})".format(self.__class__.__name__, self.nn)
+
+
+class GINEncoder(torch.nn.Module):
+    def __init__(self, hidden_dim, num_convs=3, activation="relu", short_cut=True, concat_hidden=False):
+        super().__init__()
+
+        self.hidden_dim = hidden_dim
+        self.num_convs = num_convs
+        self.short_cut = short_cut
+        self.concat_hidden = concat_hidden
+        self.node_emb = nn.Embedding(100, hidden_dim)
+
+        if isinstance(activation, str):
+            self.activation = getattr(F, activation)
+        else:
+            self.activation = None
+
+        self.convs = nn.ModuleList()
+        for i in range(self.num_convs):
+            self.convs.append(
+                GINEConv(
+                    MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),
+                    activation=activation,
+                )
+            )
+
+    def forward(self, z, edge_index, edge_attr):
+        """
+        Input:
+            data: (torch_geometric.data.Data): batched graph
+            node_attr: node feature tensor with shape (num_node, hidden)
+            edge_attr: edge feature tensor with shape (num_edge, hidden)
+        Output:
+            node_attr
+            graph feature
+        """
+
+        node_attr = self.node_emb(z)  # (num_node, hidden)
+
+        hiddens = []
+        conv_input = node_attr  # (num_node, hidden)
+
+        for conv_idx, conv in enumerate(self.convs):
+            hidden = conv(conv_input, edge_index, edge_attr)
+            if conv_idx < len(self.convs) - 1 and self.activation is not None:
+                hidden = self.activation(hidden)
+            assert hidden.shape == conv_input.shape
+            if self.short_cut and hidden.shape == conv_input.shape:
+                hidden += conv_input
+
+            hiddens.append(hidden)
+            conv_input = hidden
+
+        if self.concat_hidden:
+            node_feature = torch.cat(hiddens, dim=-1)
+        else:
+            node_feature = hiddens[-1]
+
+        return node_feature
+
+
+class MLPEdgeEncoder(Module):
+    def __init__(self, hidden_dim=100, activation="relu"):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)
+        self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)
+
+    @property
+    def out_channels(self):
+        return self.hidden_dim
+
+    def forward(self, edge_length, edge_type):
+        """
+        Input:
+            edge_length: The length of edges, shape=(E, 1).
+            edge_type: The type of edges, shape=(E,)
+        Returns:
+            edge_attr: The representation of edges. (E, hidden_dim)
+        """
+        d_emb = self.mlp(edge_length)  # (num_edge, hidden_dim)
+        edge_attr = self.bond_emb(edge_type)  # (num_edge, hidden_dim)
+        return d_emb * edge_attr  # (num_edge, hidden)
+
+
+def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):
+    h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]
+    h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1)  # (E, 2H)
+    return h_pair
+
+
+def generate_symmetric_edge_noise(num_nodes_per_graph, edge_index, edge2graph, device):
+    num_cum_nodes = num_nodes_per_graph.cumsum(0)  # (G, )
+    node_offset = num_cum_nodes - num_nodes_per_graph  # (G, )
+    edge_offset = node_offset[edge2graph]  # (E, )
+
+    num_nodes_square = num_nodes_per_graph**2  # (G, )
+    num_nodes_square_cumsum = num_nodes_square.cumsum(-1)  # (G, )
+    edge_start = num_nodes_square_cumsum - num_nodes_square  # (G, )
+    edge_start = edge_start[edge2graph]
+
+    all_len = num_nodes_square_cumsum[-1]
+
+    node_index = edge_index.t() - edge_offset.unsqueeze(-1)
+    node_large = node_index.max(dim=-1)[0]
+    node_small = node_index.min(dim=-1)[0]
+    undirected_edge_id = node_large * (node_large + 1) + node_small + edge_start
+
+    symm_noise = torch.zeros(size=[all_len.item()], device=device)
+    symm_noise.normal_()
+    d_noise = symm_noise[undirected_edge_id].unsqueeze(-1)  # (E, 1)
+    return d_noise
+
+
+def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):
+    """
+    Args:
+        num_nodes: Number of atoms.
+        edge_index: Bond indices of the original graph.
+        edge_type: Bond types of the original graph.
+        order: Extension order.
+    Returns:
+        new_edge_index: Extended edge indices.
+        new_edge_type: Extended edge types.
+    """
+
+    def binarize(x):
+        return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
+
+    def get_higher_order_adj_matrix(adj, order):
+        """
+        Args:
+            adj: (N, N)
+            type_mat: (N, N)
+        Returns:
+            Following attributes will be updated:
+                - edge_index
+                - edge_type
+            Following attributes will be added to the data object:
+                - bond_edge_index: Original edge_index.
+        """
+        adj_mats = [
+            torch.eye(adj.size(0), dtype=torch.long, device=adj.device),
+            binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),
+        ]
+
+        for i in range(2, order + 1):
+            adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))
+        order_mat = torch.zeros_like(adj)
+
+        for i in range(1, order + 1):
+            order_mat += (adj_mats[i] - adj_mats[i - 1]) * i
+
+        return order_mat
+
+    num_types = len(BOND_TYPES)
+
+    N = num_nodes
+    adj = to_dense_adj(edge_index).squeeze(0)
+    adj_order = get_higher_order_adj_matrix(adj, order)  # (N, N)
+
+    type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0)  # (N, N)
+    type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))
+    assert (type_mat * type_highorder == 0).all()
+    type_new = type_mat + type_highorder
+
+    new_edge_index, new_edge_type = dense_to_sparse(type_new)
+    _, edge_order = dense_to_sparse(adj_order)
+
+    # data.bond_edge_index = data.edge_index  # Save original edges
+    new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N)  # modify data
+
+    # [Note] This is not necessary
+    # data.is_bond = (data.edge_type < num_types)
+
+    # [Note] In earlier versions, `edge_order` attribute will be added.
+    # However, it doesn't seem to be necessary anymore so I removed it.
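+    # Worked example of the extension above: in a path graph 0-1-2-3 with
+    # order=3, adj_order[i, j] is the hop distance between i and j (capped at
+    # `order`), so the two-hop pair (0, 2) gains a new edge of type
+    # num_types + 1 and the three-hop pair (0, 3) one of type num_types + 2,
+    # while original one-hop bonds keep their RDKit bond types.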
+ # edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N) # modify data + # assert (data.edge_index == edge_index_1).all() + + return new_edge_index, new_edge_type + + +def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None): + assert edge_type.dim() == 1 + N = pos.size(0) + + bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N])) + + if is_sidechain is None: + rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r) + else: + # fetch sidechain and its batch index + is_sidechain = is_sidechain.bool() + dummy_index = torch.arange(pos.size(0), device=pos.device) + sidechain_pos = pos[is_sidechain] + sidechain_index = dummy_index[is_sidechain] + sidechain_batch = batch[is_sidechain] + + assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch) + r_edge_index_x = assign_index[1] + r_edge_index_y = assign_index[0] + r_edge_index_y = sidechain_index[r_edge_index_y] + + rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E) + rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E) + rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E) + # delete self loop + rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])] + + rgraph_adj = torch.sparse.LongTensor( + rgraph_edge_index, + torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number, + torch.Size([N, N]), + ) + + composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T) + # edge_index = composed_adj.indices() + # dist = (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1) + + new_edge_index = composed_adj.indices() + new_edge_type = composed_adj.values().long() + + return new_edge_index, new_edge_type + + +def extend_graph_order_radius( + num_nodes, + pos, + edge_index, + edge_type, + batch, + order=3, + cutoff=10.0, + extend_order=True, + extend_radius=True, + is_sidechain=None, +): + if extend_order: + edge_index, edge_type = _extend_graph_order( + num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order + ) + # edge_index_order = edge_index + # edge_type_order = edge_type + + if extend_radius: + edge_index, edge_type = _extend_to_radius_graph( + pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain + ) + + return edge_index, edge_type + + +def get_distance(pos, edge_index): + return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1) + + +def eq_transform(score_d, pos, edge_index, edge_length): + N = pos.size(0) + dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3) + score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add( + -dd_dr * score_d, edge_index[1], dim=0, dim_size=N + ) # (N, 3) + return score_pos + + +def convert_cluster_score_d(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length, subgraph_index): + """ + Args: + cluster_score_d: (E_c, 1) + subgraph_index: (N, ) + """ + cluster_score_pos = eq_transform(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length) # (C, 3) + score_pos = cluster_score_pos[subgraph_index] + return score_pos + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + def sigmoid(x): + return 1 / (np.exp(-x) + 1) + + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + 
beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "sigmoid": + betas = np.linspace(-6, 6, num_diffusion_timesteps) + betas = sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +class DualEncoderEpsNetwork(ModelMixin, ConfigMixin): + def __init__(self, config): + super().__init__() + self.config = config + + """ + edge_encoder: Takes both edge type and edge length as input and outputs a vector + [Note]: node embedding is done in SchNetEncoder + """ + self.edge_encoder_global = MLPEdgeEncoder(config.hidden_dim, config.mlp_act) # get_edge_encoder(config) + self.edge_encoder_local = MLPEdgeEncoder(config.hidden_dim, config.mlp_act) # get_edge_encoder(config) + + """ + The graph neural network that extracts node-wise features. + """ + self.encoder_global = SchNetEncoder( + hidden_channels=config.hidden_dim, + num_filters=config.hidden_dim, + num_interactions=config.num_convs, + edge_channels=self.edge_encoder_global.out_channels, + cutoff=config.cutoff, + smooth=config.smooth_conv, + ) + self.encoder_local = GINEncoder( + hidden_dim=config.hidden_dim, + num_convs=config.num_convs_local, + ) + + """ + `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs + gradients w.r.t. edge_length (out_dim = 1). + """ + self.grad_global_dist_mlp = MultiLayerPerceptron( + 2 * config.hidden_dim, [config.hidden_dim, config.hidden_dim // 2, 1], activation=config.mlp_act + ) + + self.grad_local_dist_mlp = MultiLayerPerceptron( + 2 * config.hidden_dim, [config.hidden_dim, config.hidden_dim // 2, 1], activation=config.mlp_act + ) + + """ + Incorporate parameters together + """ + self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp]) + self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp]) + + self.model_type = config.type # config.type # 'diffusion'; 'dsm' + + # denoising diffusion + ## betas + betas = get_beta_schedule( + beta_schedule=config.beta_schedule, + beta_start=config.beta_start, + beta_end=config.beta_end, + num_diffusion_timesteps=config.num_diffusion_timesteps, + ) + betas = torch.from_numpy(betas).float() + self.betas = nn.Parameter(betas, requires_grad=False) + ## variances + alphas = (1.0 - betas).cumprod(dim=0) + self.alphas = nn.Parameter(alphas, requires_grad=False) + self.num_timesteps = self.betas.size(0) + + def forward( + self, + atom_type, + pos, + bond_index, + bond_type, + batch, + time_step, + edge_index=None, + edge_type=None, + edge_length=None, + return_edges=False, + extend_order=True, + extend_radius=True, + is_sidechain=None, + ): + """ + Args: + atom_type: Types of atoms, (N, ). + bond_index: Indices of bonds (not extended, not radius-graph), (2, E). + bond_type: Bond types, (E, ). + batch: Node index to graph index, (N, ). 
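+            pos: Atom positions, (N, 3).
+            time_step: Diffusion timesteps, one per graph, (G, ).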
+ """ + N = atom_type.size(0) + if edge_index is None or edge_type is None or edge_length is None: + edge_index, edge_type = extend_graph_order_radius( + num_nodes=N, + pos=pos, + edge_index=bond_index, + edge_type=bond_type, + batch=batch, + order=self.config.edge_order, + cutoff=self.config.cutoff, + extend_order=extend_order, + extend_radius=extend_radius, + is_sidechain=is_sidechain, + ) + edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1) + local_edge_mask = is_local_edge(edge_type) # (E, ) + + # with the parameterization of NCSNv2 + # DDPM loss implicit handle the noise variance scale conditioning + sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1) + + # Encoding global + edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges + # edge_attr += temb_edge + + # Global + node_attr_global = self.encoder_global( + z=atom_type, + edge_index=edge_index, + edge_length=edge_length, + edge_attr=edge_attr_global, + ) + ## Assemble pairwise features + h_pair_global = assemble_atom_pair_feature( + node_attr=node_attr_global, + edge_index=edge_index, + edge_attr=edge_attr_global, + ) # (E_global, 2H) + ## Invariant features of edges (radius graph, global) + edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1) + + # Encoding local + edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges + # edge_attr += temb_edge + + # Local + node_attr_local = self.encoder_local( + z=atom_type, + edge_index=edge_index[:, local_edge_mask], + edge_attr=edge_attr_local[local_edge_mask], + ) + ## Assemble pairwise features + h_pair_local = assemble_atom_pair_feature( + node_attr=node_attr_local, + edge_index=edge_index[:, local_edge_mask], + edge_attr=edge_attr_local[local_edge_mask], + ) # (E_local, 2H) + ## Invariant features of edges (bond graph, local) + edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge[local_edge_mask]) # (E_local, 1) + + if return_edges: + return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask + else: + return edge_inv_global, edge_inv_local + + def get_loss( + self, + atom_type, + pos, + bond_index, + bond_type, + batch, + num_nodes_per_graph, + num_graphs, + anneal_power=2.0, + return_unreduced_loss=False, + return_unreduced_edge_loss=False, + extend_order=True, + extend_radius=True, + is_sidechain=None, + ): + return self.get_loss_diffusion( + atom_type, + pos, + bond_index, + bond_type, + batch, + num_nodes_per_graph, + num_graphs, + anneal_power, + return_unreduced_loss, + return_unreduced_edge_loss, + extend_order, + extend_radius, + is_sidechain, + ) + + def get_loss_diffusion( + self, + atom_type, + pos, + bond_index, + bond_type, + batch, + num_nodes_per_graph, + num_graphs, + anneal_power=2.0, + return_unreduced_loss=False, + return_unreduced_edge_loss=False, + extend_order=True, + extend_radius=True, + is_sidechain=None, + ): + N = atom_type.size(0) + node2graph = batch + + # Four elements for DDPM: original_data(pos), gaussian_noise(pos_noise), beta(sigma), time_step + # Sample noise levels + time_step = torch.randint(0, self.num_timesteps, size=(num_graphs // 2 + 1,), device=pos.device) + time_step = torch.cat([time_step, self.num_timesteps - time_step - 1], dim=0)[:num_graphs] + a = self.alphas.index_select(0, time_step) # (G, ) + # Perterb pos + a_pos = a.index_select(0, node2graph).unsqueeze(-1) # (N, 1) + pos_noise = 
torch.zeros(size=pos.size(), device=pos.device) + pos_noise.normal_() + pos_perturbed = pos + pos_noise * (1.0 - a_pos).sqrt() / a_pos.sqrt() + + # Update invariant edge features, as shown in equation 5-7 + edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self( + atom_type=atom_type, + pos=pos_perturbed, + bond_index=bond_index, + bond_type=bond_type, + batch=batch, + time_step=time_step, + return_edges=True, + extend_order=extend_order, + extend_radius=extend_radius, + is_sidechain=is_sidechain, + ) # (E_global, 1), (E_local, 1) + + edge2graph = node2graph.index_select(0, edge_index[0]) + # Compute sigmas_edge + a_edge = a.index_select(0, edge2graph).unsqueeze(-1) # (E, 1) + + # Compute original and perturbed distances + d_gt = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1) + d_perturbed = edge_length + # Filtering for protein + train_edge_mask = is_train_edge(edge_index, is_sidechain) + d_perturbed = torch.where(train_edge_mask.unsqueeze(-1), d_perturbed, d_gt) + + if self.config.edge_encoder == "gaussian": + # Distances must be greater than 0 + d_sgn = torch.sign(d_perturbed) + d_perturbed = torch.clamp(d_perturbed * d_sgn, min=0.01, max=float("inf")) + d_target = (d_gt - d_perturbed) / (1.0 - a_edge).sqrt() * a_edge.sqrt() # (E_global, 1), denoising direction + + global_mask = torch.logical_and( + torch.logical_or(d_perturbed <= self.config.cutoff, local_edge_mask.unsqueeze(-1)), + ~local_edge_mask.unsqueeze(-1), + ) + target_d_global = torch.where(global_mask, d_target, torch.zeros_like(d_target)) + edge_inv_global = torch.where(global_mask, edge_inv_global, torch.zeros_like(edge_inv_global)) + target_pos_global = eq_transform(target_d_global, pos_perturbed, edge_index, edge_length) + node_eq_global = eq_transform(edge_inv_global, pos_perturbed, edge_index, edge_length) + loss_global = (node_eq_global - target_pos_global) ** 2 + loss_global = 2 * torch.sum(loss_global, dim=-1, keepdim=True) + + target_pos_local = eq_transform( + d_target[local_edge_mask], pos_perturbed, edge_index[:, local_edge_mask], edge_length[local_edge_mask] + ) + node_eq_local = eq_transform( + edge_inv_local, pos_perturbed, edge_index[:, local_edge_mask], edge_length[local_edge_mask] + ) + loss_local = (node_eq_local - target_pos_local) ** 2 + loss_local = 5 * torch.sum(loss_local, dim=-1, keepdim=True) + + # loss for atomic eps regression + loss = loss_global + loss_local + + if return_unreduced_edge_loss: + pass + elif return_unreduced_loss: + return loss, loss_global, loss_local + else: + return loss + + def langevin_dynamics_sample( + self, + atom_type, + pos_init, + bond_index, + bond_type, + batch, + num_graphs, + extend_order, + extend_radius=True, + n_steps=100, + step_lr=0.0000010, + clip=1000, + clip_local=None, + clip_pos=None, + min_sigma=0, + is_sidechain=None, + global_start_sigma=float("inf"), + w_global=0.2, + w_reg=1.0, + **kwargs, + ): + return self.langevin_dynamics_sample_diffusion( + atom_type, + pos_init, + bond_index, + bond_type, + batch, + num_graphs, + extend_order, + extend_radius, + n_steps, + step_lr, + clip, + clip_local, + clip_pos, + min_sigma, + is_sidechain, + global_start_sigma, + w_global, + w_reg, + sampling_type=kwargs.get("sampling_type", "ddpm_noisy"), + eta=kwargs.get("eta", 1.0), + ) + + def langevin_dynamics_sample_diffusion( + self, + atom_type, + pos_init, + bond_index, + bond_type, + batch, + num_graphs, + extend_order, + extend_radius=True, + n_steps=100, + step_lr=0.0000010, + clip=1000, + clip_local=None, + 
clip_pos=None, + min_sigma=0, + is_sidechain=None, + global_start_sigma=float("inf"), + w_global=0.2, + w_reg=1.0, + **kwargs, + ): + def compute_alpha(beta, t): + beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) + a = (1 - beta).cumprod(dim=0).index_select(0, t + 1) # .view(-1, 1, 1, 1) + return a + + sigmas = (1.0 - self.alphas).sqrt() / self.alphas.sqrt() + pos_traj = [] + if is_sidechain is not None: + assert pos_gt is not None, "need crd of backbone for sidechain prediction" + with torch.no_grad(): + # skip = self.num_timesteps // n_steps + # seq = range(0, self.num_timesteps, skip) + + ## to test sampling with less intermediate diffusion steps + # n_steps: the num of steps + seq = range(self.num_timesteps - n_steps, self.num_timesteps) + seq_next = [-1] + list(seq[:-1]) + + pos = pos_init * sigmas[-1] + if is_sidechain is not None: + pos[~is_sidechain] = pos_gt[~is_sidechain] + for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), desc="sample"): + t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=pos.device) + + edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self( + atom_type=atom_type, + pos=pos, + bond_index=bond_index, + bond_type=bond_type, + batch=batch, + time_step=t, + return_edges=True, + extend_order=extend_order, + extend_radius=extend_radius, + is_sidechain=is_sidechain, + ) # (E_global, 1), (E_local, 1) + + # Local + node_eq_local = eq_transform( + edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask] + ) + if clip_local is not None: + node_eq_local = clip_norm(node_eq_local, limit=clip_local) + # Global + if sigmas[i] < global_start_sigma: + edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) + node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) + node_eq_global = clip_norm(node_eq_global, limit=clip) + else: + node_eq_global = 0 + # Sum + eps_pos = node_eq_local + node_eq_global * w_global # + eps_pos_reg * w_reg + + # Update + + sampling_type = kwargs.get("sampling_type", "ddpm_noisy") # types: generalized, ddpm_noisy, ld + + noise = torch.randn_like(pos) # center_pos(torch.randn_like(pos), batch) + if sampling_type == "generalized" or sampling_type == "ddpm_noisy": + b = self.betas + t = t[0] + next_t = (torch.ones(1) * j).to(pos.device) + at = compute_alpha(b, t.long()) + at_next = compute_alpha(b, next_t.long()) + if sampling_type == "generalized": + eta = kwargs.get("eta", 1.0) + et = -eps_pos + ## original + # pos0_t = (pos - et * (1 - at).sqrt()) / at.sqrt() + ## reweighted + # pos0_t = pos - et * (1 - at).sqrt() / at.sqrt() + c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() + c2 = ((1 - at_next) - c1**2).sqrt() + # pos_next = at_next.sqrt() * pos0_t + c1 * noise + c2 * et + # pos_next = pos0_t + c1 * noise / at_next.sqrt() + c2 * et / at_next.sqrt() + + # pos_next = pos + et * (c2 / at_next.sqrt() - (1 - at).sqrt() / at.sqrt()) + noise * c1 / at_next.sqrt() + step_size_pos_ld = step_lr * (sigmas[i] / 0.01) ** 2 / sigmas[i] + step_size_pos_generalized = 5 * ((1 - at).sqrt() / at.sqrt() - c2 / at_next.sqrt()) + step_size_pos = ( + step_size_pos_ld + if step_size_pos_ld < step_size_pos_generalized + else step_size_pos_generalized + ) + + step_size_noise_ld = torch.sqrt((step_lr * (sigmas[i] / 0.01) ** 2) * 2) + step_size_noise_generalized = 3 * (c1 / at_next.sqrt()) + step_size_noise = ( + step_size_noise_ld + if step_size_noise_ld < step_size_noise_generalized + else 
step_size_noise_generalized + ) + + pos_next = pos - et * step_size_pos + noise * step_size_noise + + elif sampling_type == "ddpm_noisy": + atm1 = at_next + beta_t = 1 - at / atm1 + e = -eps_pos + pos0_from_e = (1.0 / at).sqrt() * pos - (1.0 / at - 1).sqrt() * e + mean_eps = ( + (atm1.sqrt() * beta_t) * pos0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * pos + ) / (1.0 - at) + mean = mean_eps + mask = 1 - (t == 0).float() + logvar = beta_t.log() + pos_next = mean + mask * torch.exp(0.5 * logvar) * noise + elif sampling_type == "ld": + step_size = step_lr * (sigmas[i] / 0.01) ** 2 + pos_next = pos + step_size * eps_pos / sigmas[i] + noise * torch.sqrt(step_size * 2) + + pos = pos_next + + if is_sidechain is not None: + pos[~is_sidechain] = pos_gt[~is_sidechain] + + if torch.isnan(pos).any(): + print("NaN detected. Please restart.") + raise FloatingPointError() + pos = center_pos(pos, batch) + if clip_pos is not None: + pos = torch.clamp(pos, min=-clip_pos, max=clip_pos) + pos_traj.append(pos.clone().cpu()) + + return pos, pos_traj + + +def is_bond(edge_type): + return torch.logical_and(edge_type < len(BOND_TYPES), edge_type > 0) + + +def is_angle_edge(edge_type): + return edge_type == len(BOND_TYPES) + 1 - 1 + + +def is_dihedral_edge(edge_type): + return edge_type == len(BOND_TYPES) + 2 - 1 + + +def is_radius_edge(edge_type): + return edge_type == 0 + + +def is_local_edge(edge_type): + return edge_type > 0 + + +def is_train_edge(edge_index, is_sidechain): + if is_sidechain is None: + return torch.ones(edge_index.size(1), device=edge_index.device).bool() + else: + is_sidechain = is_sidechain.bool() + return torch.logical_or(is_sidechain[edge_index[0]], is_sidechain[edge_index[1]]) + + +def regularize_bond_length(edge_type, edge_length, rng=5.0): + mask = is_bond(edge_type).float().reshape(-1, 1) + d = -torch.clamp(edge_length - rng, min=0.0, max=float("inf")) * mask + return d + + +def center_pos(pos, batch): + pos_center = pos - scatter_mean(pos, batch, dim=0)[batch] + return pos_center + + +def clip_norm(vec, limit, p=2): + norm = torch.norm(vec, dim=-1, p=2, keepdim=True) + denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) + return vec * denom From e4a2ddf7a44ea649d001de132e58c679e923b331 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 14:27:58 -0400 Subject: [PATCH 02/26] rebase molecule gen --- src/diffusers/models/dualencoder_gfn.py | 45 +++++++++++++------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 288c329526c1..1024adfb1148 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -506,31 +506,32 @@ def sigmoid(x): class DualEncoderEpsNetwork(ModelMixin, ConfigMixin): - def __init__(self, config): + def __init__(self, hidden_dim, num_convs, num_convs_local, cutoff, mlp_act, beta_schedule, beta_start, beta_end, num_diffusion_timesteps, edge_order, edge_encoder, smooth_conv): super().__init__() - self.config = config + self.cutoff = cutoff + self.edge_encoder = edge_encoder """ edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done in SchNetEncoder """ - self.edge_encoder_global = MLPEdgeEncoder(config.hidden_dim, config.mlp_act) # get_edge_encoder(config) - self.edge_encoder_local = MLPEdgeEncoder(config.hidden_dim, config.mlp_act) # get_edge_encoder(config) + self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # 
get_edge_encoder(config) + self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config) """ The graph neural network that extracts node-wise features. """ self.encoder_global = SchNetEncoder( - hidden_channels=config.hidden_dim, - num_filters=config.hidden_dim, - num_interactions=config.num_convs, + hidden_channels=hidden_dim, + num_filters=hidden_dim, + num_interactions=num_convs, edge_channels=self.edge_encoder_global.out_channels, - cutoff=config.cutoff, - smooth=config.smooth_conv, + cutoff=cutoff, + smooth=smooth_conv, ) self.encoder_local = GINEncoder( - hidden_dim=config.hidden_dim, - num_convs=config.num_convs_local, + hidden_dim=hidden_dim, + num_convs=num_convs_local, ) """ @@ -538,11 +539,11 @@ def __init__(self, config): gradients w.r.t. edge_length (out_dim = 1). """ self.grad_global_dist_mlp = MultiLayerPerceptron( - 2 * config.hidden_dim, [config.hidden_dim, config.hidden_dim // 2, 1], activation=config.mlp_act + 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act ) self.grad_local_dist_mlp = MultiLayerPerceptron( - 2 * config.hidden_dim, [config.hidden_dim, config.hidden_dim // 2, 1], activation=config.mlp_act + 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act ) """ @@ -551,15 +552,15 @@ def __init__(self, config): self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp]) self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp]) - self.model_type = config.type # config.type # 'diffusion'; 'dsm' + self.model_type = type # type # 'diffusion'; 'dsm' # denoising diffusion ## betas betas = get_beta_schedule( - beta_schedule=config.beta_schedule, - beta_start=config.beta_start, - beta_end=config.beta_end, - num_diffusion_timesteps=config.num_diffusion_timesteps, + beta_schedule=beta_schedule, + beta_start=beta_start, + beta_end=beta_end, + num_diffusion_timesteps=num_diffusion_timesteps, ) betas = torch.from_numpy(betas).float() self.betas = nn.Parameter(betas, requires_grad=False) @@ -599,8 +600,8 @@ def forward( edge_index=bond_index, edge_type=bond_type, batch=batch, - order=self.config.edge_order, - cutoff=self.config.cutoff, + order=self.edge_order, + cutoff=self.cutoff, extend_order=extend_order, extend_radius=extend_radius, is_sidechain=is_sidechain, @@ -743,14 +744,14 @@ def get_loss_diffusion( train_edge_mask = is_train_edge(edge_index, is_sidechain) d_perturbed = torch.where(train_edge_mask.unsqueeze(-1), d_perturbed, d_gt) - if self.config.edge_encoder == "gaussian": + if self.edge_encoder == "gaussian": # Distances must be greater than 0 d_sgn = torch.sign(d_perturbed) d_perturbed = torch.clamp(d_perturbed * d_sgn, min=0.01, max=float("inf")) d_target = (d_gt - d_perturbed) / (1.0 - a_edge).sqrt() * a_edge.sqrt() # (E_global, 1), denoising direction global_mask = torch.logical_and( - torch.logical_or(d_perturbed <= self.config.cutoff, local_edge_mask.unsqueeze(-1)), + torch.logical_or(d_perturbed <= self.cutoff, local_edge_mask.unsqueeze(-1)), ~local_edge_mask.unsqueeze(-1), ) target_d_global = torch.where(global_mask, d_target, torch.zeros_like(d_target)) From 71753ef0c518ec93f8779a53c789dbeace4329d0 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Fri, 1 Jul 2022 09:24:13 -0700 Subject: [PATCH 03/26] make style --- src/diffusers/models/dualencoder_gfn.py | 63 +++++++++++++++---------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py 
b/src/diffusers/models/dualencoder_gfn.py index 1024adfb1148..b5f1750766f9 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -1,18 +1,20 @@ # Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff +from typing import Callable, Union + +import numpy as np import torch -from torch import nn import torch.nn.functional as F -from torch import Tensor -from typing import Callable, Union -from torch.nn import Module, Sequential, ModuleList, Linear, Embedding +from torch import Tensor, nn +from torch.nn import Embedding, Linear, Module, ModuleList, Sequential + +from rdkit.Chem.rdchem import BondType as BT +from torch_geometric.nn import MessagePassing, radius, radius_graph +from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size +from torch_geometric.utils import dense_to_sparse, to_dense_adj from torch_scatter import scatter_add, scatter_mean -from torch_geometric.nn import MessagePassing, radius_graph, radius -from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size -from torch_geometric.utils import to_dense_adj, dense_to_sparse from torch_sparse import SparseTensor, coalesce -import numpy as np from tqdm.auto import tqdm -from rdkit.Chem.rdchem import BondType as BT + from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin @@ -22,9 +24,8 @@ class MultiLayerPerceptron(nn.Module): """ - Multi-layer Perceptron. - Note there is no activation or dropout in the last layer. Parameters: + Multi-layer Perceptron. Note there is no activation or dropout in the last layer. input_dim (int): input dimension hidden_dim (list of int): hidden dimensions activation (str or function, optional): activation function @@ -228,12 +229,10 @@ def __init__(self, hidden_dim, num_convs=3, activation="relu", short_cut=True, c def forward(self, z, edge_index, edge_attr): """ Input: - data: (torch_geometric.data.Data): batched graph - node_attr: node feature tensor with shape (num_node, hidden) - edge_attr: edge feature tensor with shape (num_edge, hidden) + data: (torch_geometric.data.Data): batched graph node_attr: node feature tensor with shape (num_node, + hidden) edge_attr: edge feature tensor with shape (num_edge, hidden) Output: - node_attr - graph feature + node_attr graph feature """ node_attr = self.node_emb(z) # (num_node, hidden) @@ -274,10 +273,9 @@ def out_channels(self): def forward(self, edge_length, edge_type): """ Input: - edge_length: The length of edges, shape=(E, 1). - edge_type: The type pf edges, shape=(E,) + edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,) Returns: - edge_attr: The representation of edges. (E, 2 * num_gaussians) + edge_attr: The representation of edges. (E, 2 * num_gaussians) """ d_emb = self.mlp(edge_length) # (num_edge, hidden_dim) edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim) @@ -321,8 +319,7 @@ def _extend_graph_order(num_nodes, edge_index, edge_type, order=3): edge_type: Bond types of the original graph. order: Extension order. Returns: - new_edge_index: Extended edge indices. - new_edge_type: Extended edge types. + new_edge_index: Extended edge indices. new_edge_type: Extended edge types. """ def binarize(x): @@ -338,7 +335,7 @@ def get_higher_order_adj_matrix(adj, order): - edge_index - edge_type Following attributes will be added to the data object: - - bond_edge_index: Original edge_index. + - bond_edge_index: Original edge_index. 
""" adj_mats = [ torch.eye(adj.size(0), dtype=torch.long, device=adj.device), @@ -506,14 +503,28 @@ def sigmoid(x): class DualEncoderEpsNetwork(ModelMixin, ConfigMixin): - def __init__(self, hidden_dim, num_convs, num_convs_local, cutoff, mlp_act, beta_schedule, beta_start, beta_end, num_diffusion_timesteps, edge_order, edge_encoder, smooth_conv): + def __init__( + self, + hidden_dim, + num_convs, + num_convs_local, + cutoff, + mlp_act, + beta_schedule, + beta_start, + beta_end, + num_diffusion_timesteps, + edge_order, + edge_encoder, + smooth_conv, + ): super().__init__() self.cutoff = cutoff self.edge_encoder = edge_encoder """ - edge_encoder: Takes both edge type and edge length as input and outputs a vector - [Note]: node embedding is done in SchNetEncoder + edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done + in SchNetEncoder """ self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config) self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config) @@ -535,7 +546,7 @@ def __init__(self, hidden_dim, num_convs, num_convs_local, cutoff, mlp_act, beta ) """ - `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs + `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs gradients w.r.t. edge_length (out_dim = 1). """ self.grad_global_dist_mlp = MultiLayerPerceptron( From 9064edaaf44a2aa1a42c9221682ff736e76c1773 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 5 Jul 2022 13:43:50 -0700 Subject: [PATCH 04/26] add property to self in init for colab --- src/diffusers/models/dualencoder_gfn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index b5f1750766f9..6176e30a61a8 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -521,6 +521,7 @@ def __init__( super().__init__() self.cutoff = cutoff self.edge_encoder = edge_encoder + self.edge_order = edge_order """ edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done From 7c15d6ba12e6ddd43eeaddd656454ddb72e13422 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 11 Jul 2022 09:25:45 -0700 Subject: [PATCH 05/26] small fix to types in forward() --- src/diffusers/models/dualencoder_gfn.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 6176e30a61a8..3361e74320ad 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -662,8 +662,12 @@ def forward( edge_attr=edge_attr_local[local_edge_mask], ) # (E_local, 2H) ## Invariant features of edges (bond graph, local) - edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge[local_edge_mask]) # (E_local, 1) - + if isinstance(sigma_edge, torch.Tensor): + edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * ( + 1.0 / sigma_edge[local_edge_mask]) # (E_local, 1) + else: + edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1) + if return_edges: return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask else: From 120af84c98164ec11b8cf90b345b62316d4972d0 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 14:34:02 -0400 Subject: [PATCH 06/26] rebase main, 
small updates --- src/diffusers/models/dualencoder_gfn.py | 31 +++++++++++---------- src/diffusers/schedulers/scheduling_ddpm.py | 10 +++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 3361e74320ad..5279b7f887be 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -24,8 +24,9 @@ class MultiLayerPerceptron(nn.Module): """ - Parameters: Multi-layer Perceptron. Note there is no activation or dropout in the last layer. + + Args: input_dim (int): input dimension hidden_dim (list of int): hidden dimensions activation (str or function, optional): activation function @@ -664,10 +665,11 @@ def forward( ## Invariant features of edges (bond graph, local) if isinstance(sigma_edge, torch.Tensor): edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * ( - 1.0 / sigma_edge[local_edge_mask]) # (E_local, 1) + 1.0 / sigma_edge[local_edge_mask] + ) # (E_local, 1) else: edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1) - + if return_edges: return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask else: @@ -905,12 +907,12 @@ def compute_alpha(beta, t): edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask] ) if clip_local is not None: - node_eq_local = clip_norm(node_eq_local, limit=clip_local) + node_eq_local = self.clip_norm(node_eq_local, limit=clip_local) # Global if sigmas[i] < global_start_sigma: edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) - node_eq_global = clip_norm(node_eq_global, limit=clip) + node_eq_global = self.clip_norm(node_eq_global, limit=clip) else: node_eq_global = 0 # Sum @@ -982,13 +984,18 @@ def compute_alpha(beta, t): if torch.isnan(pos).any(): print("NaN detected. 
Please restart.") raise FloatingPointError() - pos = center_pos(pos, batch) + pos = pos - scatter_mean(pos, batch, dim=0)[batch] # center_pos(pos, batch) if clip_pos is not None: pos = torch.clamp(pos, min=-clip_pos, max=clip_pos) pos_traj.append(pos.clone().cpu()) return pos, pos_traj + def clip_norm(vec, limit, p=2): + norm = torch.norm(vec, dim=-1, p=2, keepdim=True) + denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) + return vec * denom + def is_bond(edge_type): return torch.logical_and(edge_type < len(BOND_TYPES), edge_type > 0) @@ -1024,12 +1031,6 @@ def regularize_bond_length(edge_type, edge_length, rng=5.0): return d -def center_pos(pos, batch): - pos_center = pos - scatter_mean(pos, batch, dim=0)[batch] - return pos_center - - -def clip_norm(vec, limit, p=2): - norm = torch.norm(vec, dim=-1, p=2, keepdim=True) - denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) - return vec * denom +# def center_pos(pos, batch): +# pos_center = pos - scatter_mean(pos, batch, dim=0)[batch] +# return pos_center diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index cc17cee4c810..0b8b39935f4c 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -133,7 +133,17 @@ def __init__( ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule +<<<<<<< HEAD self.betas = betas_for_alpha_bar(num_train_timesteps) +======= + self.betas = betas_for_alpha_bar(timesteps) + elif beta_schedule == "sigmoid": + def sigmoid(x): + return 1 / (np.exp(-x) + 1) + + betas = np.linspace(-6, 6, timesteps) + self.betas = sigmoid(betas) * (beta_end - beta_start) + beta_start +>>>>>>> beaa1e0 (add sigmoid beta schedule to ddpm) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") From 9dd023a90993c1fc79aa746f2ddcc70e961305c5 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 11 Jul 2022 19:08:31 -0700 Subject: [PATCH 07/26] add helper function to trim colab --- src/diffusers/models/dualencoder_gfn.py | 96 ++++++++++++++++----- src/diffusers/schedulers/scheduling_ddpm.py | 1 + 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 5279b7f887be..10ff7412ef2b 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -463,17 +463,6 @@ def eq_transform(score_d, pos, edge_index, edge_length): return score_pos -def convert_cluster_score_d(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length, subgraph_index): - """ - Args: - cluster_score_d: (E_c, 1) - subgraph_index: (N, ) - """ - cluster_score_pos = eq_transform(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length) # (C, 3) - score_pos = cluster_score_pos[subgraph_index] - return score_pos - - def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): def sigmoid(x): return 1 / (np.exp(-x) + 1) @@ -843,6 +832,75 @@ def langevin_dynamics_sample( eta=kwargs.get("eta", 1.0), ) + def get_residual_params( + self, + t, + batch, + extend_order=False, + extend_radius=True, + step_lr=0.0000010, + clip=1000, + clip_local=None, + clip_pos=None, + min_sigma=0, + is_sidechain=None, + global_start_sigma=0.5, + w_global=1.0, + **kwargs, + ): + atom_type = batch.atom_type + bond_index = batch.edge_index + bond_type = batch.edge_type + num_graphs = batch.num_graphs + pos = batch.pos + + 
timesteps = torch.full(size=(num_graphs,), fill_value=t, dtype=torch.long, device=pos.device) + + edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self.forward( + atom_type=atom_type, + pos=batch.pos, + bond_index=bond_index, + bond_type=bond_type, + batch=batch.batch, + time_step=timesteps, + return_edges=True, + extend_order=extend_order, + extend_radius=extend_radius, + ) # (E_global, 1), (E_local, 1) + + # Local + node_eq_local = eq_transform(edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]) + if clip_local is not None: + node_eq_local = clip_norm(node_eq_local, limit=clip_local) + + return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_eq_local + + def get_residual( + self, + pos, + sigma, + edge_inv_global, + local_edge_mask, + edge_index, + edge_length, + node_eq_local, + global_start_sigma=0.5, + w_global=1.0, + clip=1000.0, + ): + + # Global + if sigma < global_start_sigma: + edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) + node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) + node_eq_global = clip_norm(node_eq_global, limit=clip) + else: + node_eq_global = 0 + + # Sum + eps_pos = node_eq_local + node_eq_global * w_global + return -eps_pos + def langevin_dynamics_sample_diffusion( self, atom_type, @@ -854,12 +912,9 @@ def langevin_dynamics_sample_diffusion( extend_order, extend_radius=True, n_steps=100, - step_lr=0.0000010, clip=1000, clip_local=None, clip_pos=None, - min_sigma=0, - is_sidechain=None, global_start_sigma=float("inf"), w_global=0.2, w_reg=1.0, @@ -907,12 +962,12 @@ def compute_alpha(beta, t): edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask] ) if clip_local is not None: - node_eq_local = self.clip_norm(node_eq_local, limit=clip_local) + node_eq_local = clip_norm(node_eq_local, limit=clip_local) # Global if sigmas[i] < global_start_sigma: edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) - node_eq_global = self.clip_norm(node_eq_global, limit=clip) + node_eq_global = clip_norm(node_eq_global, limit=clip) else: node_eq_global = 0 # Sum @@ -991,10 +1046,11 @@ def compute_alpha(beta, t): return pos, pos_traj - def clip_norm(vec, limit, p=2): - norm = torch.norm(vec, dim=-1, p=2, keepdim=True) - denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) - return vec * denom + +def clip_norm(vec, limit, p=2): + norm = torch.norm(vec, dim=-1, p=2, keepdim=True) + denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) + return vec * denom def is_bond(edge_type): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 0b8b39935f4c..3c92c5af1f9c 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -138,6 +138,7 @@ def __init__( ======= self.betas = betas_for_alpha_bar(timesteps) elif beta_schedule == "sigmoid": + def sigmoid(x): return 1 / (np.exp(-x) + 1) From 2d1f748303bf5c07408d396189c759c3e2d2a2e2 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 11 Jul 2022 19:39:36 -0700 Subject: [PATCH 08/26] remove unused code --- src/diffusers/models/dualencoder_gfn.py | 269 +----------------------- 1 file changed, 10 insertions(+), 259 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py 
b/src/diffusers/models/dualencoder_gfn.py index 10ff7412ef2b..63c229447042 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -679,38 +679,6 @@ def get_loss( extend_order=True, extend_radius=True, is_sidechain=None, - ): - return self.get_loss_diffusion( - atom_type, - pos, - bond_index, - bond_type, - batch, - num_nodes_per_graph, - num_graphs, - anneal_power, - return_unreduced_loss, - return_unreduced_edge_loss, - extend_order, - extend_radius, - is_sidechain, - ) - - def get_loss_diffusion( - self, - atom_type, - pos, - bond_index, - bond_type, - batch, - num_nodes_per_graph, - num_graphs, - anneal_power=2.0, - return_unreduced_loss=False, - return_unreduced_edge_loss=False, - extend_order=True, - extend_radius=True, - is_sidechain=None, ): N = atom_type.size(0) node2graph = batch @@ -787,66 +755,13 @@ def get_loss_diffusion( else: return loss - def langevin_dynamics_sample( - self, - atom_type, - pos_init, - bond_index, - bond_type, - batch, - num_graphs, - extend_order, - extend_radius=True, - n_steps=100, - step_lr=0.0000010, - clip=1000, - clip_local=None, - clip_pos=None, - min_sigma=0, - is_sidechain=None, - global_start_sigma=float("inf"), - w_global=0.2, - w_reg=1.0, - **kwargs, - ): - return self.langevin_dynamics_sample_diffusion( - atom_type, - pos_init, - bond_index, - bond_type, - batch, - num_graphs, - extend_order, - extend_radius, - n_steps, - step_lr, - clip, - clip_local, - clip_pos, - min_sigma, - is_sidechain, - global_start_sigma, - w_global, - w_reg, - sampling_type=kwargs.get("sampling_type", "ddpm_noisy"), - eta=kwargs.get("eta", 1.0), - ) - def get_residual_params( self, t, batch, extend_order=False, extend_radius=True, - step_lr=0.0000010, - clip=1000, clip_local=None, - clip_pos=None, - min_sigma=0, - is_sidechain=None, - global_start_sigma=0.5, - w_global=1.0, - **kwargs, ): atom_type = batch.atom_type bond_index = batch.edge_index @@ -879,15 +794,20 @@ def get_residual( self, pos, sigma, - edge_inv_global, - local_edge_mask, - edge_index, - edge_length, - node_eq_local, + model_outputs, global_start_sigma=0.5, w_global=1.0, clip=1000.0, ): + ( + edge_inv_global, + edge_inv_local, + edge_index, + edge_type, + edge_length, + local_edge_mask, + node_eq_local, + ) = model_outputs # Global if sigma < global_start_sigma: @@ -901,151 +821,6 @@ def get_residual( eps_pos = node_eq_local + node_eq_global * w_global return -eps_pos - def langevin_dynamics_sample_diffusion( - self, - atom_type, - pos_init, - bond_index, - bond_type, - batch, - num_graphs, - extend_order, - extend_radius=True, - n_steps=100, - clip=1000, - clip_local=None, - clip_pos=None, - global_start_sigma=float("inf"), - w_global=0.2, - w_reg=1.0, - **kwargs, - ): - def compute_alpha(beta, t): - beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) - a = (1 - beta).cumprod(dim=0).index_select(0, t + 1) # .view(-1, 1, 1, 1) - return a - - sigmas = (1.0 - self.alphas).sqrt() / self.alphas.sqrt() - pos_traj = [] - if is_sidechain is not None: - assert pos_gt is not None, "need crd of backbone for sidechain prediction" - with torch.no_grad(): - # skip = self.num_timesteps // n_steps - # seq = range(0, self.num_timesteps, skip) - - ## to test sampling with less intermediate diffusion steps - # n_steps: the num of steps - seq = range(self.num_timesteps - n_steps, self.num_timesteps) - seq_next = [-1] + list(seq[:-1]) - - pos = pos_init * sigmas[-1] - if is_sidechain is not None: - pos[~is_sidechain] = pos_gt[~is_sidechain] - for i, j in 
tqdm(zip(reversed(seq), reversed(seq_next)), desc="sample"): - t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=pos.device) - - edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self( - atom_type=atom_type, - pos=pos, - bond_index=bond_index, - bond_type=bond_type, - batch=batch, - time_step=t, - return_edges=True, - extend_order=extend_order, - extend_radius=extend_radius, - is_sidechain=is_sidechain, - ) # (E_global, 1), (E_local, 1) - - # Local - node_eq_local = eq_transform( - edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask] - ) - if clip_local is not None: - node_eq_local = clip_norm(node_eq_local, limit=clip_local) - # Global - if sigmas[i] < global_start_sigma: - edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) - node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) - node_eq_global = clip_norm(node_eq_global, limit=clip) - else: - node_eq_global = 0 - # Sum - eps_pos = node_eq_local + node_eq_global * w_global # + eps_pos_reg * w_reg - - # Update - - sampling_type = kwargs.get("sampling_type", "ddpm_noisy") # types: generalized, ddpm_noisy, ld - - noise = torch.randn_like(pos) # center_pos(torch.randn_like(pos), batch) - if sampling_type == "generalized" or sampling_type == "ddpm_noisy": - b = self.betas - t = t[0] - next_t = (torch.ones(1) * j).to(pos.device) - at = compute_alpha(b, t.long()) - at_next = compute_alpha(b, next_t.long()) - if sampling_type == "generalized": - eta = kwargs.get("eta", 1.0) - et = -eps_pos - ## original - # pos0_t = (pos - et * (1 - at).sqrt()) / at.sqrt() - ## reweighted - # pos0_t = pos - et * (1 - at).sqrt() / at.sqrt() - c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() - c2 = ((1 - at_next) - c1**2).sqrt() - # pos_next = at_next.sqrt() * pos0_t + c1 * noise + c2 * et - # pos_next = pos0_t + c1 * noise / at_next.sqrt() + c2 * et / at_next.sqrt() - - # pos_next = pos + et * (c2 / at_next.sqrt() - (1 - at).sqrt() / at.sqrt()) + noise * c1 / at_next.sqrt() - step_size_pos_ld = step_lr * (sigmas[i] / 0.01) ** 2 / sigmas[i] - step_size_pos_generalized = 5 * ((1 - at).sqrt() / at.sqrt() - c2 / at_next.sqrt()) - step_size_pos = ( - step_size_pos_ld - if step_size_pos_ld < step_size_pos_generalized - else step_size_pos_generalized - ) - - step_size_noise_ld = torch.sqrt((step_lr * (sigmas[i] / 0.01) ** 2) * 2) - step_size_noise_generalized = 3 * (c1 / at_next.sqrt()) - step_size_noise = ( - step_size_noise_ld - if step_size_noise_ld < step_size_noise_generalized - else step_size_noise_generalized - ) - - pos_next = pos - et * step_size_pos + noise * step_size_noise - - elif sampling_type == "ddpm_noisy": - atm1 = at_next - beta_t = 1 - at / atm1 - e = -eps_pos - pos0_from_e = (1.0 / at).sqrt() * pos - (1.0 / at - 1).sqrt() * e - mean_eps = ( - (atm1.sqrt() * beta_t) * pos0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * pos - ) / (1.0 - at) - mean = mean_eps - mask = 1 - (t == 0).float() - logvar = beta_t.log() - pos_next = mean + mask * torch.exp(0.5 * logvar) * noise - elif sampling_type == "ld": - step_size = step_lr * (sigmas[i] / 0.01) ** 2 - pos_next = pos + step_size * eps_pos / sigmas[i] + noise * torch.sqrt(step_size * 2) - - pos = pos_next - - if is_sidechain is not None: - pos[~is_sidechain] = pos_gt[~is_sidechain] - - if torch.isnan(pos).any(): - print("NaN detected. 
Please restart.") - raise FloatingPointError() - pos = pos - scatter_mean(pos, batch, dim=0)[batch] # center_pos(pos, batch) - if clip_pos is not None: - pos = torch.clamp(pos, min=-clip_pos, max=clip_pos) - pos_traj.append(pos.clone().cpu()) - - return pos, pos_traj - def clip_norm(vec, limit, p=2): norm = torch.norm(vec, dim=-1, p=2, keepdim=True) @@ -1053,20 +828,6 @@ def clip_norm(vec, limit, p=2): return vec * denom -def is_bond(edge_type): - return torch.logical_and(edge_type < len(BOND_TYPES), edge_type > 0) - - -def is_angle_edge(edge_type): - return edge_type == len(BOND_TYPES) + 1 - 1 - - -def is_dihedral_edge(edge_type): - return edge_type == len(BOND_TYPES) + 2 - 1 - - -def is_radius_edge(edge_type): - return edge_type == 0 def is_local_edge(edge_type): @@ -1080,13 +841,3 @@ def is_train_edge(edge_index, is_sidechain): is_sidechain = is_sidechain.bool() return torch.logical_or(is_sidechain[edge_index[0]], is_sidechain[edge_index[1]]) - -def regularize_bond_length(edge_type, edge_length, rng=5.0): - mask = is_bond(edge_type).float().reshape(-1, 1) - d = -torch.clamp(edge_length - rng, min=0.0, max=float("inf")) * mask - return d - - -# def center_pos(pos, batch): -# pos_center = pos - scatter_mean(pos, batch, dim=0)[batch] -# return pos_center From a4513e22d98c21647570686ed7e7bef209673775 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 13 Jul 2022 08:46:57 -0700 Subject: [PATCH 09/26] clean API for colab --- src/diffusers/models/dualencoder_gfn.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 63c229447042..4a1c96e25494 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -783,7 +783,7 @@ def get_residual_params( extend_radius=extend_radius, ) # (E_global, 1), (E_local, 1) - # Local + # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff node_eq_local = eq_transform(edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]) if clip_local is not None: node_eq_local = clip_norm(node_eq_local, limit=clip_local) @@ -828,8 +828,6 @@ def clip_norm(vec, limit, p=2): return vec * denom - - def is_local_edge(edge_type): return edge_type > 0 @@ -840,4 +838,3 @@ def is_train_edge(edge_index, is_sidechain): else: is_sidechain = is_sidechain.bool() return torch.logical_or(is_sidechain[edge_index[0]], is_sidechain[edge_index[1]]) - From ce71e2fc87a81e80d5fa7bf7ff29062c305ecb32 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 21 Jul 2022 14:37:11 -0700 Subject: [PATCH 10/26] remove unused code --- src/diffusers/models/dualencoder_gfn.py | 146 +----------------------- 1 file changed, 2 insertions(+), 144 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 4a1c96e25494..6cac281c2b45 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -13,7 +13,6 @@ from torch_geometric.utils import dense_to_sparse, to_dense_adj from torch_scatter import scatter_add, scatter_mean from torch_sparse import SparseTensor, coalesce -from tqdm.auto import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin @@ -463,35 +462,6 @@ def eq_transform(score_d, pos, edge_index, edge_length): return score_pos -def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): - def sigmoid(x): - return 1 / (np.exp(-x) + 1) - - if 
beta_schedule == "quad": - betas = ( - np.linspace( - beta_start**0.5, - beta_end**0.5, - num_diffusion_timesteps, - dtype=np.float64, - ) - ** 2 - ) - elif beta_schedule == "linear": - betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) - elif beta_schedule == "const": - betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) - elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 - betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) - elif beta_schedule == "sigmoid": - betas = np.linspace(-6, 6, num_diffusion_timesteps) - betas = sigmoid(betas) * (beta_end - beta_start) + beta_start - else: - raise NotImplementedError(beta_schedule) - assert betas.shape == (num_diffusion_timesteps,) - return betas - - class DualEncoderEpsNetwork(ModelMixin, ConfigMixin): def __init__( self, @@ -500,10 +470,6 @@ def __init__( num_convs_local, cutoff, mlp_act, - beta_schedule, - beta_start, - beta_end, - num_diffusion_timesteps, edge_order, edge_encoder, smooth_conv, @@ -554,23 +520,6 @@ def __init__( self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp]) self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp]) - self.model_type = type # type # 'diffusion'; 'dsm' - - # denoising diffusion - ## betas - betas = get_beta_schedule( - beta_schedule=beta_schedule, - beta_start=beta_start, - beta_end=beta_end, - num_diffusion_timesteps=num_diffusion_timesteps, - ) - betas = torch.from_numpy(betas).float() - self.betas = nn.Parameter(betas, requires_grad=False) - ## variances - alphas = (1.0 - betas).cumprod(dim=0) - self.alphas = nn.Parameter(alphas, requires_grad=False) - self.num_timesteps = self.betas.size(0) - def forward( self, atom_type, @@ -578,7 +527,7 @@ def forward( bond_index, bond_type, batch, - time_step, + time_step, # NOTE, model trained without timestep performed best edge_index=None, edge_type=None, edge_length=None, @@ -617,7 +566,6 @@ def forward( # Encoding global edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges - # edge_attr += temb_edge # Global node_attr_global = self.encoder_global( @@ -651,6 +599,7 @@ def forward( edge_index=edge_index[:, local_edge_mask], edge_attr=edge_attr_local[local_edge_mask], ) # (E_local, 2H) + ## Invariant features of edges (bond graph, local) if isinstance(sigma_edge, torch.Tensor): edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * ( @@ -664,97 +613,6 @@ def forward( else: return edge_inv_global, edge_inv_local - def get_loss( - self, - atom_type, - pos, - bond_index, - bond_type, - batch, - num_nodes_per_graph, - num_graphs, - anneal_power=2.0, - return_unreduced_loss=False, - return_unreduced_edge_loss=False, - extend_order=True, - extend_radius=True, - is_sidechain=None, - ): - N = atom_type.size(0) - node2graph = batch - - # Four elements for DDPM: original_data(pos), gaussian_noise(pos_noise), beta(sigma), time_step - # Sample noise levels - time_step = torch.randint(0, self.num_timesteps, size=(num_graphs // 2 + 1,), device=pos.device) - time_step = torch.cat([time_step, self.num_timesteps - time_step - 1], dim=0)[:num_graphs] - a = self.alphas.index_select(0, time_step) # (G, ) - # Perterb pos - a_pos = a.index_select(0, node2graph).unsqueeze(-1) # (N, 1) - pos_noise = torch.zeros(size=pos.size(), device=pos.device) - pos_noise.normal_() - pos_perturbed = pos + pos_noise * (1.0 - a_pos).sqrt() 
/ a_pos.sqrt() - - # Update invariant edge features, as shown in equation 5-7 - edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self( - atom_type=atom_type, - pos=pos_perturbed, - bond_index=bond_index, - bond_type=bond_type, - batch=batch, - time_step=time_step, - return_edges=True, - extend_order=extend_order, - extend_radius=extend_radius, - is_sidechain=is_sidechain, - ) # (E_global, 1), (E_local, 1) - - edge2graph = node2graph.index_select(0, edge_index[0]) - # Compute sigmas_edge - a_edge = a.index_select(0, edge2graph).unsqueeze(-1) # (E, 1) - - # Compute original and perturbed distances - d_gt = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1) - d_perturbed = edge_length - # Filtering for protein - train_edge_mask = is_train_edge(edge_index, is_sidechain) - d_perturbed = torch.where(train_edge_mask.unsqueeze(-1), d_perturbed, d_gt) - - if self.edge_encoder == "gaussian": - # Distances must be greater than 0 - d_sgn = torch.sign(d_perturbed) - d_perturbed = torch.clamp(d_perturbed * d_sgn, min=0.01, max=float("inf")) - d_target = (d_gt - d_perturbed) / (1.0 - a_edge).sqrt() * a_edge.sqrt() # (E_global, 1), denoising direction - - global_mask = torch.logical_and( - torch.logical_or(d_perturbed <= self.cutoff, local_edge_mask.unsqueeze(-1)), - ~local_edge_mask.unsqueeze(-1), - ) - target_d_global = torch.where(global_mask, d_target, torch.zeros_like(d_target)) - edge_inv_global = torch.where(global_mask, edge_inv_global, torch.zeros_like(edge_inv_global)) - target_pos_global = eq_transform(target_d_global, pos_perturbed, edge_index, edge_length) - node_eq_global = eq_transform(edge_inv_global, pos_perturbed, edge_index, edge_length) - loss_global = (node_eq_global - target_pos_global) ** 2 - loss_global = 2 * torch.sum(loss_global, dim=-1, keepdim=True) - - target_pos_local = eq_transform( - d_target[local_edge_mask], pos_perturbed, edge_index[:, local_edge_mask], edge_length[local_edge_mask] - ) - node_eq_local = eq_transform( - edge_inv_local, pos_perturbed, edge_index[:, local_edge_mask], edge_length[local_edge_mask] - ) - loss_local = (node_eq_local - target_pos_local) ** 2 - loss_local = 5 * torch.sum(loss_local, dim=-1, keepdim=True) - - # loss for atomic eps regression - loss = loss_global + loss_local - - if return_unreduced_edge_loss: - pass - elif return_unreduced_loss: - return loss, loss_global, loss_local - else: - return loss - def get_residual_params( self, t, From 38658929732f66f9b83fc2641f294743e358f971 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 14:35:20 -0400 Subject: [PATCH 11/26] weird rebase --- src/diffusers/models/dualencoder_gfn.py | 32 +- tests/test_modeling_utils.py | 967 ++++++++++++++++++++++++ 2 files changed, 983 insertions(+), 16 deletions(-) create mode 100755 tests/test_modeling_utils.py diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 6cac281c2b45..367a64aa30c1 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -1,5 +1,5 @@ # Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff -from typing import Callable, Union +from typing import Callable, Union, Dict import numpy as np import torch @@ -520,7 +520,7 @@ def __init__( self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp]) self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp]) - def forward( + def _forward( self, 
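        # these positional inputs mirror the GeoDiff graph batch fields; the
        # public forward() below unpacks them from a torch_geometric-style sample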
atom_type, pos, @@ -613,28 +613,28 @@ def forward( else: return edge_inv_global, edge_inv_local - def get_residual_params( + def forward( self, - t, - batch, + sample, + timestep: Union[torch.Tensor, float, int], extend_order=False, extend_radius=True, clip_local=None, - ): - atom_type = batch.atom_type - bond_index = batch.edge_index - bond_type = batch.edge_type - num_graphs = batch.num_graphs - pos = batch.pos + )-> Dict[str, torch.FloatTensor]: + atom_type = sample.atom_type + bond_index = sample.edge_index + bond_type = sample.edge_type + num_graphs = sample.num_graphs + pos = sample.pos - timesteps = torch.full(size=(num_graphs,), fill_value=t, dtype=torch.long, device=pos.device) + timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device) - edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self.forward( + edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward( atom_type=atom_type, - pos=batch.pos, + pos=sample.pos, bond_index=bond_index, bond_type=bond_type, - batch=batch.batch, + batch=sample.batch, time_step=timesteps, return_edges=True, extend_order=extend_order, @@ -677,7 +677,7 @@ def get_residual( # Sum eps_pos = node_eq_local + node_eq_global * w_global - return -eps_pos + return {"sample": -eps_pos} def clip_norm(vec, limit, p=2): diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py new file mode 100755 index 000000000000..a791a565cc2a --- /dev/null +++ b/tests/test_modeling_utils.py @@ -0,0 +1,967 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
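+
+# The DualEncoderEpsNetwork tests below exercise the refactored,
+# diffusers-style entry point from this patch. A rough sketch of the call
+# convention they assume, with `data` standing in for a torch_geometric-style
+# batch exposing atom_type / edge_index / edge_type / pos / batch / num_graphs:
+#
+#     outputs = model(sample=data, timestep=10)
+#
+# forward() broadcasts the scalar timestep to a per-graph long tensor and
+# delegates to _forward(); the position residual itself is returned by
+# get_residual() as {"sample": -eps_pos}.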
+ +import inspect +import math +import pdb +import tempfile +import unittest + +import numpy as np +import torch + +import PIL +from diffusers import UNet2DConditionModel # noqa: F401 TODO(Patrick) - need to write tests with it +from diffusers import ( + AutoencoderKL, + DDIMPipeline, + DDIMScheduler, + DDPMPipeline, + DDPMScheduler, + DualEncoderEpsNetwork, + LDMPipeline, + LDMTextToImagePipeline, + PNDMPipeline, + PNDMScheduler, + ScoreSdeVePipeline, + ScoreSdeVeScheduler, + UNet2DModel, + VQModel, +) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.testing_utils import floats_tensor, slow, torch_device +from diffusers.training_utils import EMAModel + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class SampleObject(ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + ): + pass + + +class ConfigTester(unittest.TestCase): + def test_load_not_from_mixin(self): + with self.assertRaises(ValueError): + ConfigMixin.from_config("dummy_path") + + def test_register_to_config(self): + obj = SampleObject() + config = obj.config + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == (2, 5) + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + # init ignore private arguments + obj = SampleObject(_name_or_path="lalala") + config = obj.config + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == (2, 5) + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + # can override default + obj = SampleObject(c=6) + config = obj.config + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == 6 + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + # can use positional arguments. 
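+        # (positional arguments bind in __init__ order, so the 1 below sets `a`)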
+ obj = SampleObject(1, c=6) + config = obj.config + assert config["a"] == 1 + assert config["b"] == 5 + assert config["c"] == 6 + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + def test_save_load(self): + obj = SampleObject() + config = obj.config + + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == (2, 5) + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + new_obj = SampleObject.from_config(tmpdirname) + new_config = new_obj.config + + # unfreeze configs + config = dict(config) + new_config = dict(new_config) + + assert config.pop("c") == (2, 5) # instantiated as tuple + assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json + assert config == new_config + + +class ModelTesterMixin: + def test_from_pretrained_save_pretrained(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + with torch.no_grad(): + image = model(**inputs_dict) + if isinstance(image, dict): + image = image["sample"] + + new_image = new_model(**inputs_dict) + + if isinstance(new_image, dict): + new_image = new_image["sample"] + + max_diff = (image - new_image).abs().sum().item() + self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + + def test_determinism(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + with torch.no_grad(): + first = model(**inputs_dict) + if isinstance(first, dict): + first = first["sample"] + + second = model(**inputs_dict) + if isinstance(second, dict): + second = second["sample"] + + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_forward_signature(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["sample", "timestep"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + def test_model_from_config(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # test if the model can be loaded from the config + # and has all the expected shape + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_config(tmpdirname) + new_model = self.model_class.from_config(tmpdirname) + 
new_model.to(torch_device) + new_model.eval() + + # check if all paramters shape are the same + for param_name in model.state_dict().keys(): + param_1 = model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + self.assertEqual(param_1.shape, param_2.shape) + + with torch.no_grad(): + output_1 = model(**inputs_dict) + + if isinstance(output_1, dict): + output_1 = output_1["sample"] + + output_2 = new_model(**inputs_dict) + + if isinstance(output_2, dict): + output_2 = output_2["sample"] + + self.assertEqual(output_1.shape, output_2.shape) + + def test_training(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + + def test_ema_training(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + ema_model = EMAModel(model, device=torch_device) + + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + ema_model.step(model) + + +class UnetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet2DModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ("DownBlock2D", "AttnDownBlock2D"), + "up_block_types": ("AttnUpBlock2D", "UpBlock2D"), + "attention_head_dim": None, + "out_channels": 3, + "in_channels": 3, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + +# TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints +# def test_output_pretrained(self): +# model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet") +# model.eval() +# +# torch.manual_seed(0) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed_all(0) +# +# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) +# time_step = torch.tensor([10]) +# +# with torch.no_grad(): +# output = model(noise, time_step)["sample"] +# +# output_slice = output[0, -1, -3:, -3:].flatten() +# fmt: off +# expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) +# fmt: on +# self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet2DModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = 
torch.tensor([10]).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 32, + "in_channels": 4, + "out_channels": 4, + "layers_per_block": 2, + "block_out_channels": (32, 64), + "attention_head_dim": 32, + "down_block_types": ("DownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "UpBlock2D"), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) + + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input)["sample"] + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + time_step = torch.tensor([10] * noise.shape[0]) + + with torch.no_grad(): + output = model(noise, time_step)["sample"] + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + +# TODO(Patrick) - Re-add this test after having cleaned up LDM +# def test_output_pretrained_spatial_transformer(self): +# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial") +# model.eval() +# +# torch.manual_seed(0) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed_all(0) +# +# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) +# context = torch.ones((1, 16, 64), dtype=torch.float32) +# time_step = torch.tensor([10] * noise.shape[0]) +# +# with torch.no_grad(): +# output = model(noise, time_step, context=context) +# +# output_slice = output[0, -1, -3:, -3:].flatten() +# fmt: off +# expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890]) +# fmt: on +# +# self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +# + + +class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet2DModel + + @property + def dummy_input(self, sizes=(32, 32)): + batch_size = 4 + num_channels = 3 + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor(batch_size * [10]).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": [32, 64, 64, 64], + "in_channels": 3, + "layers_per_block": 1, + "out_channels": 3, + "time_embedding_type": "fourier", + "norm_eps": 1e-6, + "mid_block_scale_factor": math.sqrt(2.0), + "norm_num_groups": None, + "down_block_types": [ + "SkipDownBlock2D", + "AttnSkipDownBlock2D", + "SkipDownBlock2D", + "SkipDownBlock2D", + ], + "up_block_types": [ + 
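+                # mirrors down_block_types, placing attention at the matching resolution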
"SkipUpBlock2D", + "SkipUpBlock2D", + "AttnSkipUpBlock2D", + "SkipUpBlock2D", + ], + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + inputs = self.dummy_input + noise = floats_tensor((4, 3) + (256, 256)).to(torch_device) + inputs["sample"] = noise + image = model(**inputs) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained_ve_mid(self): + model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256") + model.to(torch_device) + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + batch_size = 4 + num_channels = 3 + sizes = (256, 256) + + noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) + + with torch.no_grad(): + output = model(noise, time_step)["sample"] + + output_slice = output[0, -3:, -3:, -1].flatten().cpu() + # fmt: off + expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + def test_output_pretrained_ve_large(self): + model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") + model.to(torch_device) + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) + + with torch.no_grad(): + output = model(noise, time_step)["sample"] + + output_slice = output[0, -3:, -3:, -1].flatten().cpu() + # fmt: off + expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class VQModelTests(ModelTesterMixin, unittest.TestCase): + model_class = VQModel + + @property + def dummy_input(self, sizes=(32, 32)): + batch_size = 4 + num_channels = 3 + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "out_ch": 3, + "num_res_blocks": 1, + "in_channels": 3, + "attn_resolutions": [], + "resolution": 32, + "z_channels": 3, + "n_embed": 256, + "embed_dim": 3, + "sane_index_shape": False, + "ch_mult": (1,), + "double_z": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.eval() + + 
torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class DualEncoderEpsNetworkTests(unittest.TestCase): + model_class = DualEncoderEpsNetwork + + @property + def dummy_input(self): + batch_size = 4 + + time_step = torch.tensor([10] * batch_size).to(torch_device) + + class GeoDiffData: + num_nodes = 26 + atom_type = torch.randint(0, 6, (num_nodes,)).to(torch_device) + bond_edge_index = torch.randint(0, 20, (2,54,)).to(torch_device) + edge_type = torch.randint(0, 20, (238,)).to(torch_device) + num_graphs = 1 + pos = torch.randn(num_nodes, 3).to(torch_device) + + torch.manual_seed(0) + noise = GeoDiffData() + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 16, 14) + + @property + def output_shape(self): + return (4, 16, 14) + + def prepare_init_args_and_inputs_for_common(self): + init_dict ={ + "hidden_dim": 128, + "num_convs": 6, + "num_convs_local": 4, + "cutoff": 10.0, + "mlp_act": "relu", + "edge_order": 3, + "edge_encoder": "mlp", + "smooth_conv": True + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = DualEncoderEpsNetwork.from_pretrained( + "fusing/gfn-molecule-gen-drugs", output_loading_info=True + ) + import ipdb; pdb.set_trace() + print() + self.assertIsNotNone(model) + import ipdb; pdb.set_trace() + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = DualEncoderEpsNetwork.from_pretrained("fusing/gfn-molecule-gen-drugs") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + seq_len = 16 + input = self.dummy_input + sample, time_step = input["sample"], input["timestep"] + import ipdb; pdb.set_trace() + with torch.no_grad(): + output = model(sample, time_step) + import ipdb; pdb.set_trace() + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + +class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): + model_class = AutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "ch_mult": (1,), + "embed_dim": 4, + "in_channels": 3, + "attn_resolutions": [], + "num_res_blocks": 1, + "out_ch": 3, + "resolution": 32, + "z_channels": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + 
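+    # Note: test_forward_signature and test_training above are no-ops, as for
+    # VQModel: the autoencoder forward takes only `sample` (plus flags such as
+    # sample_posterior), not the (sample, timestep) pair the mixin asserts.
+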
def test_from_pretrained_hub(self): + model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image, sample_posterior=True) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + +class PipelineTesterMixin(unittest.TestCase): + def test_from_pretrained_save_pretrained(self): + # 1. Load models + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + schedular = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, schedular) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + + generator = torch.manual_seed(0) + + image = ddpm(generator=generator, output_type="numpy")["sample"] + generator = generator.manual_seed(0) + new_image = new_ddpm(generator=generator, output_type="numpy")["sample"] + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + + @slow + def test_from_pretrained_hub(self): + model_path = "google/ddpm-cifar10-32" + + ddpm = DDPMPipeline.from_pretrained(model_path) + ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) + + ddpm.scheduler.num_timesteps = 10 + ddpm_from_hub.scheduler.num_timesteps = 10 + + generator = torch.manual_seed(0) + + image = ddpm(generator=generator, output_type="numpy")["sample"] + generator = generator.manual_seed(0) + new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"] + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + + @slow + def test_output_format(self): + model_path = "google/ddpm-cifar10-32" + + pipe = DDIMPipeline.from_pretrained(model_path) + + generator = torch.manual_seed(0) + images = pipe(generator=generator, output_type="numpy")["sample"] + assert images.shape == (1, 32, 32, 3) + assert isinstance(images, np.ndarray) + + images = pipe(generator=generator, output_type="pil")["sample"] + assert isinstance(images, list) + assert len(images) == 1 + assert isinstance(images[0], PIL.Image.Image) + + # use PIL by default + images = pipe(generator=generator)["sample"] + assert isinstance(images, list) + assert isinstance(images[0], PIL.Image.Image) + + @slow + def test_ddpm_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = DDPMScheduler.from_config(model_id) + scheduler = scheduler.set_format("pt") + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + + generator = torch.manual_seed(0) + image = ddpm(generator=generator, 
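+            # output_type="numpy" returns an ndarray batch (see test_output_format
+            # above), so the slice arithmetic below applies directly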
output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow + def test_ddim_lsun(self): + model_id = "google/ddpm-ema-bedroom-256" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = DDIMScheduler.from_config(model_id) + + ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) + + generator = torch.manual_seed(0) + image = ddpm(generator=generator, output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow + def test_ddim_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = DDIMScheduler(tensor_format="pt") + + ddim = DDIMPipeline(unet=unet, scheduler=scheduler) + + generator = torch.manual_seed(0) + image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow + def test_pndm_cifar10(self): + model_id = "google/ddpm-cifar10-32" + + unet = UNet2DModel.from_pretrained(model_id) + scheduler = PNDMScheduler(tensor_format="pt") + + pndm = PNDMPipeline(unet=unet, scheduler=scheduler) + generator = torch.manual_seed(0) + image = pndm(generator=generator, output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow + def test_ldm_text2img(self): + ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[ + "sample" + ] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow + def test_ldm_text2img_fast(self): + ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow + def test_score_sde_ve_pipeline(self): + model = UNet2DModel.from_pretrained("google/ncsnpp-church-256") + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256") + + sde_ve = 
ScoreSdeVePipeline(model=model, scheduler=scheduler) + + torch.manual_seed(0) + image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow + def test_ldm_uncond(self): + ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256") + + generator = torch.manual_seed(0) + image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"] + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From 127f72a63afa5d2b56538fb13987f23305b67147 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Fri, 22 Jul 2022 15:25:57 -0700 Subject: [PATCH 12/26] tests pass --- src/diffusers/models/dualencoder_gfn.py | 16 +-- tests/test_modeling_utils.py | 143 ++++++++++++++++++++---- 2 files changed, 129 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index 367a64aa30c1..bc3b5d707482 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -465,14 +465,14 @@ def eq_transform(score_d, pos, edge_index, edge_length): class DualEncoderEpsNetwork(ModelMixin, ConfigMixin): def __init__( self, - hidden_dim, - num_convs, - num_convs_local, - cutoff, - mlp_act, - edge_order, - edge_encoder, - smooth_conv, + hidden_dim=128, + num_convs=6, + num_convs_local=4, + cutoff=10.0, + mlp_act="relu", + edge_order=3, + edge_encoder="mlp", + smooth_conv=True, ): super().__init__() self.cutoff = cutoff diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index a791a565cc2a..ae6593090bba 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -608,35 +608,74 @@ def test_output_pretrained(self): self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) -class DualEncoderEpsNetworkTests(unittest.TestCase): +class DualEncoderEpsNetworkTests(ModelTesterMixin, unittest.TestCase): model_class = DualEncoderEpsNetwork @property def dummy_input(self): - batch_size = 4 - - time_step = torch.tensor([10] * batch_size).to(torch_device) + batch_size = 2 + time_step = 10 class GeoDiffData: + # constants corresponding to a molecule num_nodes = 26 - atom_type = torch.randint(0, 6, (num_nodes,)).to(torch_device) - bond_edge_index = torch.randint(0, 20, (2,54,)).to(torch_device) - edge_type = torch.randint(0, 20, (238,)).to(torch_device) + num_edges = 54 num_graphs = 1 - pos = torch.randn(num_nodes, 3).to(torch_device) + + # sampling + torch.Generator(device=torch_device) + torch.manual_seed(0) + + # molecule components / properties + atom_type = torch.randint(0, 6, (num_nodes*batch_size,)).to(torch_device) + edge_index = torch.randint(0, 20, (2, num_edges*batch_size,)).to(torch_device) + edge_type = torch.randint(0, 20, (num_edges*batch_size,)).to(torch_device) + pos = torch.randn(num_nodes*batch_size, 3).to(torch_device) + batch = torch.tensor([*range(batch_size)]).repeat_interleave(num_nodes) + nx = batch_size torch.manual_seed(0) noise = GeoDiffData() return {"sample": noise, "timestep": time_step} - @property - def input_shape(self): - return (4, 16, 14) - @property def 
output_shape(self): - return (4, 16, 14) + # subset of shapes for dummy input + class GeoDiffShapes: + shape_0 = torch.Size([1305, 1]) + shape_1 = torch.Size([92, 1]) + return GeoDiffShapes() + + # training not implemented for this model yet + def test_training(self): + pass + + def test_ema_training(self): + pass + + def test_determinism(self): + # TODO + pass + + def test_output(self): + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + shapes = self.output_shapes() + self.assertEqual(output[0].shape, shapes.shape_0, "Input and output shapes do not match") + self.assertEqual(output[1].shape, shapes.shape_1, "Input and output shapes do not match") def prepare_init_args_and_inputs_for_common(self): init_dict ={ @@ -657,10 +696,7 @@ def test_from_pretrained_hub(self): model, loading_info = DualEncoderEpsNetwork.from_pretrained( "fusing/gfn-molecule-gen-drugs", output_loading_info=True ) - import ipdb; pdb.set_trace() - print() self.assertIsNotNone(model) - import ipdb; pdb.set_trace() self.assertEqual(len(loading_info["missing_keys"]), 0) model.to(torch_device) @@ -676,20 +712,83 @@ def test_output_pretrained(self): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - seq_len = 16 input = self.dummy_input sample, time_step = input["sample"], input["timestep"] - import ipdb; pdb.set_trace() with torch.no_grad(): output = model(sample, time_step) - import ipdb; pdb.set_trace() - output_slice = output[0, -3:, -3:].flatten() + # outputs correspond to molecule conformation + # variables: edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_eq_local + + output_slice_0 = output[0][-6:].flatten() # fmt: off - expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584]) + expected_output_slice_0 = torch.tensor([10633.3818, + 20670.0996, + 14251.5283, + 19087.3828, + 18654.5586, + 18116.2617]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + output_slice_1 = output[1][-6:].flatten() + # fmt: off + expected_output_slice_1 = torch.tensor([-946.4217, + 4.1009, + -60.8591, + 4.5929, + 192.5891, + -17.7297]) + # fmt: on + + + output_slice_4 = output[4][-6:].flatten() + # fmt: off + expected_output_slice_4 = torch.tensor([1.9213, + 2.2776, + 1.1385, + 0.8868, + 1.0347, + 0.8127]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice_0, expected_output_slice_0, rtol=1e-3)) + self.assertTrue(torch.allclose(output_slice_1, expected_output_slice_1, rtol=1e-3)) + self.assertTrue(torch.allclose(output_slice_4, expected_output_slice_4, rtol=1e-3)) + + def test_model_from_config(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # test if the model can be loaded from the config + # and has all the expected shape + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_config(tmpdirname) + new_model = self.model_class.from_config(tmpdirname) + new_model.to(torch_device) + new_model.eval() + + # check if all paramters shape are the same + for param_name in model.state_dict().keys(): + param_1 = 
model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + self.assertEqual(param_1.shape, param_2.shape) + + with torch.no_grad(): + output_1 = model(**inputs_dict) + + if isinstance(output_1, dict): + output_1 = output_1["sample"] + + output_2 = new_model(**inputs_dict) + + if isinstance(output_2, dict): + output_2 = output_2["sample"] + + self.assertEqual(output_1[0].shape, output_2[0].shape) + self.assertEqual(output_1[1].shape, output_2[1].shape) class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): From 2f0ac21fc52f617f51491ce64e6447bb4a222bc3 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Fri, 22 Jul 2022 15:29:53 -0700 Subject: [PATCH 13/26] make style and fix-copies --- src/diffusers/__init__.py | 2 +- src/diffusers/models/dualencoder_gfn.py | 6 ++--- tests/test_modeling_utils.py | 35 +++++++++++++++---------- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1cf64a4a2ebf..90edee7b77c9 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, DualEncoderEpsNetwork, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/dualencoder_gfn.py index bc3b5d707482..09525dd57f7b 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/dualencoder_gfn.py @@ -1,5 +1,5 @@ # Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff -from typing import Callable, Union, Dict +from typing import Callable, Dict, Union import numpy as np import torch @@ -527,7 +527,7 @@ def _forward( bond_index, bond_type, batch, - time_step, # NOTE, model trained without timestep performed best + time_step, # NOTE, model trained without timestep performed best edge_index=None, edge_type=None, edge_length=None, @@ -620,7 +620,7 @@ def forward( extend_order=False, extend_radius=True, clip_local=None, - )-> Dict[str, torch.FloatTensor]: + ) -> Dict[str, torch.FloatTensor]: atom_type = sample.atom_type bond_index = sample.edge_index bond_type = sample.edge_type diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ae6593090bba..19ecc62bb936 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -627,10 +627,17 @@ class GeoDiffData: torch.manual_seed(0) # molecule components / properties - atom_type = torch.randint(0, 6, (num_nodes*batch_size,)).to(torch_device) - edge_index = torch.randint(0, 20, (2, num_edges*batch_size,)).to(torch_device) - edge_type = torch.randint(0, 20, (num_edges*batch_size,)).to(torch_device) - pos = torch.randn(num_nodes*batch_size, 3).to(torch_device) + atom_type = torch.randint(0, 6, (num_nodes * batch_size,)).to(torch_device) + edge_index = torch.randint( + 0, + 20, + ( + 2, + num_edges * batch_size, + ), + ).to(torch_device) + edge_type = torch.randint(0, 20, (num_edges * batch_size,)).to(torch_device) + pos = torch.randn(num_nodes * batch_size, 3).to(torch_device) batch = torch.tensor([*range(batch_size)]).repeat_interleave(num_nodes) nx = batch_size @@ -645,6 +652,7 @@ def output_shape(self): class GeoDiffShapes: shape_0 = torch.Size([1305, 1]) shape_1 = torch.Size([92, 1]) + return GeoDiffShapes() # training not implemented for this model yet @@ 
-678,15 +686,15 @@ def test_output(self): self.assertEqual(output[1].shape, shapes.shape_1, "Input and output shapes do not match") def prepare_init_args_and_inputs_for_common(self): - init_dict ={ - "hidden_dim": 128, - "num_convs": 6, - "num_convs_local": 4, - "cutoff": 10.0, - "mlp_act": "relu", - "edge_order": 3, - "edge_encoder": "mlp", - "smooth_conv": True + init_dict = { + "hidden_dim": 128, + "num_convs": 6, + "num_convs_local": 4, + "cutoff": 10.0, + "mlp_act": "relu", + "edge_order": 3, + "edge_encoder": "mlp", + "smooth_conv": True, } inputs_dict = self.dummy_input @@ -740,7 +748,6 @@ def test_output_pretrained(self): -17.7297]) # fmt: on - output_slice_4 = output[4][-6:].flatten() # fmt: off expected_output_slice_4 = torch.tensor([1.9213, From 25ec89d7dcb9f4b1483a984cfa0711f55ed386a4 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 25 Jul 2022 09:22:54 -0700 Subject: [PATCH 14/26] rename model and file --- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 1 + .../models/{dualencoder_gfn.py => molecule_gnn.py} | 2 +- tests/test_modeling_utils.py | 10 +++++----- 4 files changed, 8 insertions(+), 7 deletions(-) rename src/diffusers/models/{dualencoder_gfn.py => molecule_gnn.py} (99%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 90edee7b77c9..5fecc0554ef6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, DualEncoderEpsNetwork, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, MoleculeGNN, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 1242ad6fca7f..e04037fa434e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,6 +19,7 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .vae import AutoencoderKL, VQModel + from .molecule_gnn import MoleculeGNN if is_flax_available(): from .unet_2d_condition_flax import FlaxUNet2DConditionModel diff --git a/src/diffusers/models/dualencoder_gfn.py b/src/diffusers/models/molecule_gnn.py similarity index 99% rename from src/diffusers/models/dualencoder_gfn.py rename to src/diffusers/models/molecule_gnn.py index 09525dd57f7b..8cc517cbc58b 100644 --- a/src/diffusers/models/dualencoder_gfn.py +++ b/src/diffusers/models/molecule_gnn.py @@ -462,7 +462,7 @@ def eq_transform(score_d, pos, edge_index, edge_length): return score_pos -class DualEncoderEpsNetwork(ModelMixin, ConfigMixin): +class MoleculeGNN(ModelMixin, ConfigMixin): def __init__( self, hidden_dim=128, diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 19ecc62bb936..42f159fc911c 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -30,7 +30,7 @@ DDIMScheduler, DDPMPipeline, DDPMScheduler, - DualEncoderEpsNetwork, + MoleculeGNN, LDMPipeline, LDMTextToImagePipeline, PNDMPipeline, @@ -608,8 +608,8 @@ def test_output_pretrained(self): self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) -class DualEncoderEpsNetworkTests(ModelTesterMixin, unittest.TestCase): - model_class = DualEncoderEpsNetwork +class MoleculeGNNTests(ModelTesterMixin, unittest.TestCase): + model_class = MoleculeGNN @property def dummy_input(self): @@ -701,7 +701,7 @@ def 
prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_from_pretrained_hub(self): - model, loading_info = DualEncoderEpsNetwork.from_pretrained( + model, loading_info = MoleculeGNN.from_pretrained( "fusing/gfn-molecule-gen-drugs", output_loading_info=True ) self.assertIsNotNone(model) @@ -713,7 +713,7 @@ def test_from_pretrained_hub(self): assert image is not None, "Make sure output is not None" def test_output_pretrained(self): - model = DualEncoderEpsNetwork.from_pretrained("fusing/gfn-molecule-gen-drugs") + model = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs") model.eval() torch.manual_seed(0) From 79f25d6db8087e383f146a17d985f8fa86315f9c Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 25 Jul 2022 14:32:14 -0700 Subject: [PATCH 15/26] update API, update tests, rename class --- src/diffusers/models/molecule_gnn.py | 63 ++++++++++++++++++---------- tests/test_modeling_utils.py | 54 ++++++------------------ 2 files changed, 54 insertions(+), 63 deletions(-) diff --git a/src/diffusers/models/molecule_gnn.py b/src/diffusers/models/molecule_gnn.py index 8cc517cbc58b..68be6e106a4e 100644 --- a/src/diffusers/models/molecule_gnn.py +++ b/src/diffusers/models/molecule_gnn.py @@ -617,10 +617,16 @@ def forward( self, sample, timestep: Union[torch.Tensor, float, int], + sigma=1.0, + global_start_sigma=0.5, + w_global=1.0, extend_order=False, extend_radius=True, clip_local=None, + clip_global=1000.0, ) -> Dict[str, torch.FloatTensor]: + + # unpack sample atom_type = sample.atom_type bond_index = sample.edge_index bond_type = sample.edge_type @@ -646,32 +652,11 @@ def forward( if clip_local is not None: node_eq_local = clip_norm(node_eq_local, limit=clip_local) - return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_eq_local - - def get_residual( - self, - pos, - sigma, - model_outputs, - global_start_sigma=0.5, - w_global=1.0, - clip=1000.0, - ): - ( - edge_inv_global, - edge_inv_local, - edge_index, - edge_type, - edge_length, - local_edge_mask, - node_eq_local, - ) = model_outputs - # Global if sigma < global_start_sigma: edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) - node_eq_global = clip_norm(node_eq_global, limit=clip) + node_eq_global = clip_norm(node_eq_global, limit=clip_global) else: node_eq_global = 0 @@ -680,6 +665,40 @@ def get_residual( return {"sample": -eps_pos} + # return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_eq_local + + # def get_residual( + # self, + # pos, + # sigma, + # model_outputs, + # global_start_sigma=0.5, + # w_global=1.0, + # clip=1000.0, + # ): + # ( + # edge_inv_global, + # edge_inv_local, + # edge_index, + # edge_type, + # edge_length, + # local_edge_mask, + # node_eq_local, + # ) = model_outputs + # + # # Global + # if sigma < global_start_sigma: + # edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) + # node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) + # node_eq_global = clip_norm(node_eq_global, limit=clip) + # else: + # node_eq_global = 0 + # + # # Sum + # eps_pos = node_eq_local + node_eq_global * w_global + # return {"sample": -eps_pos} + + def clip_norm(vec, limit, p=2): norm = torch.norm(vec, dim=-1, p=2, keepdim=True) denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) diff --git a/tests/test_modeling_utils.py 
b/tests/test_modeling_utils.py index 42f159fc911c..81fe93f060f4 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -618,33 +618,33 @@ def dummy_input(self): class GeoDiffData: # constants corresponding to a molecule - num_nodes = 26 - num_edges = 54 + num_nodes = 6 + num_edges = 10 num_graphs = 1 # sampling torch.Generator(device=torch_device) - torch.manual_seed(0) + torch.manual_seed(3) # molecule components / properties atom_type = torch.randint(0, 6, (num_nodes * batch_size,)).to(torch_device) edge_index = torch.randint( 0, - 20, + num_edges, ( 2, num_edges * batch_size, ), ).to(torch_device) - edge_type = torch.randint(0, 20, (num_edges * batch_size,)).to(torch_device) - pos = torch.randn(num_nodes * batch_size, 3).to(torch_device) + edge_type = torch.randint(0, 5, (num_edges * batch_size,)).to(torch_device) + pos = .001*torch.randn(num_nodes * batch_size, 3).to(torch_device) batch = torch.tensor([*range(batch_size)]).repeat_interleave(num_nodes) nx = batch_size torch.manual_seed(0) noise = GeoDiffData() - return {"sample": noise, "timestep": time_step} + return {"sample": noise, "timestep": time_step, "sigma": 1.0} @property def output_shape(self): @@ -721,46 +721,18 @@ def test_output_pretrained(self): torch.cuda.manual_seed_all(0) input = self.dummy_input - sample, time_step = input["sample"], input["timestep"] + sample, time_step, sigma = input["sample"], input["timestep"], input["sigma"] with torch.no_grad(): - output = model(sample, time_step) - - # outputs correspond to molecule conformation - # variables: edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_eq_local + output = model(sample, time_step, sigma=sigma)["sample"] - output_slice_0 = output[0][-6:].flatten() - # fmt: off - expected_output_slice_0 = torch.tensor([10633.3818, - 20670.0996, - 14251.5283, - 19087.3828, - 18654.5586, - 18116.2617]) - # fmt: on - - output_slice_1 = output[1][-6:].flatten() - # fmt: off - expected_output_slice_1 = torch.tensor([-946.4217, - 4.1009, - -60.8591, - 4.5929, - 192.5891, - -17.7297]) - # fmt: on - output_slice_4 = output[4][-6:].flatten() + output_slice = output[:3][:].flatten() # fmt: off - expected_output_slice_4 = torch.tensor([1.9213, - 2.2776, - 1.1385, - 0.8868, - 1.0347, - 0.8127]) + expected_output_slice = torch.tensor([ -3.7335, -7.4622, -29.5600, 16.9646, -11.2205, -32.5315, 1.2303, + 4.2985, 8.8828]) # fmt: on - self.assertTrue(torch.allclose(output_slice_0, expected_output_slice_0, rtol=1e-3)) - self.assertTrue(torch.allclose(output_slice_1, expected_output_slice_1, rtol=1e-3)) - self.assertTrue(torch.allclose(output_slice_4, expected_output_slice_4, rtol=1e-3)) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) def test_model_from_config(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 7a85d046b5e66ae0a0eecb9218edc4beb919e2b4 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 25 Jul 2022 14:51:42 -0700 Subject: [PATCH 16/26] clean model & tests --- src/diffusers/models/molecule_gnn.py | 87 ++-------------------------- tests/test_modeling_utils.py | 10 +--- 2 files changed, 9 insertions(+), 88 deletions(-) diff --git a/src/diffusers/models/molecule_gnn.py b/src/diffusers/models/molecule_gnn.py index 68be6e106a4e..0e681890bb9c 100644 --- a/src/diffusers/models/molecule_gnn.py +++ b/src/diffusers/models/molecule_gnn.py @@ -7,20 +7,16 @@ from torch import Tensor, nn from torch.nn import Embedding, Linear, Module, ModuleList, Sequential 
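To make the reworked forward API from PATCH 15 concrete, here is a minimal usage sketch. The field names mirror the GeoDiffData dummy input in the tests, and the checkpoint id is the one the tests pull; treat everything else (exact timestep handling, required fields) as an assumption rather than a documented contract:

import torch
from diffusers import MoleculeGNN

class Molecule:
    # mimics the batched torch_geometric-style object the tests feed in
    num_nodes, num_edges, num_graphs, nx = 6, 10, 1, 1
    atom_type = torch.randint(0, 6, (num_nodes,))
    edge_index = torch.randint(0, num_edges, (2, num_edges))
    edge_type = torch.randint(0, 5, (num_edges,))
    pos = 0.001 * torch.randn(num_nodes, 3)
    batch = torch.zeros(num_nodes, dtype=torch.long)

model = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs")
model.eval()
with torch.no_grad():
    # sigma gates the global (radius-graph) score term; the returned sample is -eps_pos
    eps_pos = model(Molecule(), 10, sigma=1.0)["sample"]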
-from rdkit.Chem.rdchem import BondType as BT from torch_geometric.nn import MessagePassing, radius, radius_graph from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size from torch_geometric.utils import dense_to_sparse, to_dense_adj -from torch_scatter import scatter_add, scatter_mean +from torch_scatter import scatter_add from torch_sparse import SparseTensor, coalesce -from ..configuration_utils import ConfigMixin +from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin -BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} - - class MultiLayerPerceptron(nn.Module): """ Multi-layer Perceptron. Note there is no activation or dropout in the last layer. @@ -288,29 +284,6 @@ def assemble_atom_pair_feature(node_attr, edge_index, edge_attr): return h_pair -def generate_symmetric_edge_noise(num_nodes_per_graph, edge_index, edge2graph, device): - num_cum_nodes = num_nodes_per_graph.cumsum(0) # (G, ) - node_offset = num_cum_nodes - num_nodes_per_graph # (G, ) - edge_offset = node_offset[edge2graph] # (E, ) - - num_nodes_square = num_nodes_per_graph**2 # (G, ) - num_nodes_square_cumsum = num_nodes_square.cumsum(-1) # (G, ) - edge_start = num_nodes_square_cumsum - num_nodes_square # (G, ) - edge_start = edge_start[edge2graph] - - all_len = num_nodes_square_cumsum[-1] - - node_index = edge_index.t() - edge_offset.unsqueeze(-1) - node_large = node_index.max(dim=-1)[0] - node_small = node_index.min(dim=-1)[0] - undirected_edge_id = node_large * (node_large + 1) + node_small + edge_start - - symm_noise = torch.zeros(size=[all_len.item()], device=device) - symm_noise.normal_() - d_noise = symm_noise[undirected_edge_id].unsqueeze(-1) # (E, 1) - return d_noise - - def _extend_graph_order(num_nodes, edge_index, edge_type, order=3): """ Args: @@ -351,8 +324,9 @@ def get_higher_order_adj_matrix(adj, order): return order_mat - num_types = len(BOND_TYPES) - + num_types = 22 + # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} + # from rdkit.Chem.rdchem import BondType as BT N = num_nodes adj = to_dense_adj(edge_index).squeeze(0) adj_order = get_higher_order_adj_matrix(adj, order) # (N, N) @@ -368,14 +342,6 @@ def get_higher_order_adj_matrix(adj, order): # data.bond_edge_index = data.edge_index # Save original edges new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data - # [Note] This is not necessary - # data.is_bond = (data.edge_type < num_types) - - # [Note] In earlier versions, `edge_order` attribute will be added. - # However, it doesn't seem to be necessary anymore so I removed it. 
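The hard-coded num_types above can be sanity-checked against rdkit when it is available; an illustrative snippet (the constant was inlined precisely so that the model file itself no longer needs this import):

from rdkit.Chem.rdchem import BondType as BT

BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
print(len(BOND_TYPES))  # 22 on the rdkit versions this model was written against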
- # edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N) # modify data - # assert (data.edge_index == edge_index_1).all() - return new_edge_index, new_edge_type @@ -463,6 +429,7 @@ def eq_transform(score_d, pos, edge_index, edge_length): class MoleculeGNN(ModelMixin, ConfigMixin): + @register_to_config def __init__( self, hidden_dim=128, @@ -665,40 +632,6 @@ def forward( return {"sample": -eps_pos} - # return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_eq_local - - # def get_residual( - # self, - # pos, - # sigma, - # model_outputs, - # global_start_sigma=0.5, - # w_global=1.0, - # clip=1000.0, - # ): - # ( - # edge_inv_global, - # edge_inv_local, - # edge_index, - # edge_type, - # edge_length, - # local_edge_mask, - # node_eq_local, - # ) = model_outputs - # - # # Global - # if sigma < global_start_sigma: - # edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) - # node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) - # node_eq_global = clip_norm(node_eq_global, limit=clip) - # else: - # node_eq_global = 0 - # - # # Sum - # eps_pos = node_eq_local + node_eq_global * w_global - # return {"sample": -eps_pos} - - def clip_norm(vec, limit, p=2): norm = torch.norm(vec, dim=-1, p=2, keepdim=True) denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) @@ -707,11 +640,3 @@ def clip_norm(vec, limit, p=2): def is_local_edge(edge_type): return edge_type > 0 - - -def is_train_edge(edge_index, is_sidechain): - if is_sidechain is None: - return torch.ones(edge_index.size(1), device=edge_index.device).bool() - else: - is_sidechain = is_sidechain.bool() - return torch.logical_or(is_sidechain[edge_index[0]], is_sidechain[edge_index[1]]) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 81fe93f060f4..bf01bf0f8719 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -15,7 +15,6 @@ import inspect import math -import pdb import tempfile import unittest @@ -30,9 +29,9 @@ DDIMScheduler, DDPMPipeline, DDPMScheduler, - MoleculeGNN, LDMPipeline, LDMTextToImagePipeline, + MoleculeGNN, PNDMPipeline, PNDMScheduler, ScoreSdeVePipeline, @@ -637,7 +636,7 @@ class GeoDiffData: ), ).to(torch_device) edge_type = torch.randint(0, 5, (num_edges * batch_size,)).to(torch_device) - pos = .001*torch.randn(num_nodes * batch_size, 3).to(torch_device) + pos = 0.001 * torch.randn(num_nodes * batch_size, 3).to(torch_device) batch = torch.tensor([*range(batch_size)]).repeat_interleave(num_nodes) nx = batch_size @@ -701,9 +700,7 @@ def prepare_init_args_and_inputs_for_common(self): return init_dict, inputs_dict def test_from_pretrained_hub(self): - model, loading_info = MoleculeGNN.from_pretrained( - "fusing/gfn-molecule-gen-drugs", output_loading_info=True - ) + model, loading_info = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs", output_loading_info=True) self.assertIsNotNone(model) self.assertEqual(len(loading_info["missing_keys"]), 0) @@ -725,7 +722,6 @@ def test_output_pretrained(self): with torch.no_grad(): output = model(sample, time_step, sigma=sigma)["sample"] - output_slice = output[:3][:].flatten() # fmt: off expected_output_slice = torch.tensor([ -3.7335, -7.4622, -29.5600, 16.9646, -11.2205, -32.5315, 1.2303, From a90d1be482f766781ae87e0948020d0ade8f50c9 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 25 Jul 2022 17:10:52 -0700 Subject: [PATCH 17/26] add checking for imports --- src/diffusers/__init__.py | 
8 + src/diffusers/utils/__init__.py | 145 ++++++++++++++++++ .../utils/dummy_torch_geometric_objects.py | 10 ++ tests/test_modeling_utils.py | 9 +- 4 files changed, 171 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/utils/dummy_torch_geometric_objects.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5fecc0554ef6..e39e3976382d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -4,6 +4,8 @@ is_onnx_available, is_scipy_available, is_torch_available, + is_inflect_available, + is_torch_geometric_available, is_transformers_available, is_unidecode_available, ) @@ -83,3 +85,9 @@ from .pipelines import FlaxStableDiffusionPipeline else: from .utils.dummy_flax_and_transformers_objects import * # noqa F403 + from .utils.dummy_transformers_objects import * + +if is_torch_geometric_available(): + from .models import MoleculeGNN +else: + from .utils.dummy_torch_geometric_objects import * diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b63dbd2b285c..48006173243e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -54,3 +54,148 @@ DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) +<<<<<<< HEAD +======= + + +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + + +_modelcards_available = importlib.util.find_spec("modelcards") is not None +try: + _modelcards_version = importlib_metadata.version("modelcards") + logger.debug(f"Successfully imported modelcards version {_modelcards_version}") +except importlib_metadata.PackageNotFoundError: + _modelcards_available = False + +_torch_scatter_available = importlib.util.find_spec("torch_scatter") is not None +try: + _torch_scatter_version = importlib_metadata.version("torch_scatter") + logger.debug(f"Successfully imported torch_scatter version {_torch_scatter_version}") +except importlib_metadata.PackageNotFoundError: + _torch_scatter_available = False + +_torch_scatter_available = importlib.util.find_spec("torch_geometric") is not None +try: + _torch_geometric_version = importlib_metadata.version("torch_geometric") + logger.debug(f"Successfully imported torch_geometric version {_torch_geometric_version}") +except importlib_metadata.PackageNotFoundError: + _torch_geometric_available = False + + +def is_transformers_available(): + return _transformers_available + + +def is_inflect_available(): + return _inflect_available + + +def is_unidecode_available(): + return _unidecode_available + + +def is_modelcards_available(): + return _modelcards_available + + 
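Each probe above follows the same shape: importlib.util.find_spec to test for presence, then importlib_metadata.version to confirm a working install. (Note that the torch_geometric probe assigns its result to _torch_scatter_available, a copy-paste slip; the probes that later move into import_utils.py assign to _torch_geometric_available.) A generic helper in that style would look roughly like this sketch, hypothetical since the file deliberately spells the pattern out per package:

import importlib.util

import importlib_metadata  # stdlib importlib.metadata on newer Pythons

def _probe(package: str):
    # returns (available, version); mirrors the per-package probes above
    available = importlib.util.find_spec(package) is not None
    version = "N/A"
    if available:
        try:
            version = importlib_metadata.version(package)
        except importlib_metadata.PackageNotFoundError:
            available = False
    return available, version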
+def is_torch_scatter_available(): + return _torch_scatter_available + + +def is_torch_geometric_available(): + # the model source of the Molecule Generation GNN requires a specific torch geometric version + # for more info, see the original repo https://github.com/MinkaiXu/GeoDiff or our colab in readme + return _torch_geometric_version == "1.7.2" + + +class RepositoryNotFoundError(HTTPError): + """ + Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does + not have access to. + """ + + +class EntryNotFoundError(HTTPError): + """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename.""" + + +class RevisionNotFoundError(HTTPError): + """Raised when trying to access a hf.co URL with a valid repository but an invalid revision.""" + + +TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` +""" + + +UNIDECODE_IMPORT_ERROR = """ +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` +""" + + +INFLECT_IMPORT_ERROR = """ +{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install +inflect` +""" + +TORCH_GEOMETRIC_IMPORT_ERROR = """ +{0} requires version 1.7.2 of torch_geometric but it was not found in your environment. You can install it with conda: +`conda install -c rusty1s pytorch-geometric=1.7.2`, given pytorch 1.8 +""" + + +BACKENDS_MAPPING = OrderedDict( + [ + ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) +>>>>>>> bf87817 (add checking for imports) diff --git a/src/diffusers/utils/dummy_torch_geometric_objects.py b/src/diffusers/utils/dummy_torch_geometric_objects.py new file mode 100644 index 000000000000..adb24fe6730c --- /dev/null +++ b/src/diffusers/utils/dummy_torch_geometric_objects.py @@ -0,0 +1,10 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
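The DummyObject metaclass above converts any attribute access on a stand-in class into the matching backend ImportError. A small illustration, with a hypothetical class, assuming the definitions above are in scope and transformers is not installed:

class _NeedsTransformers(metaclass=DummyObject):
    _backends = ["transformers"]

try:
    _NeedsTransformers.from_pretrained("any/repo")  # __getattr__ fires before the call
except ImportError as err:
    print(err)  # TRANSFORMERS_IMPORT_ERROR formatted with the class name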
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class MoleculeGNN(metaclass=DummyObject):
+    _backends = ["torch_geometric"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch_geometric"])
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index bf01bf0f8719..7a394c07cf4d 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -31,7 +31,6 @@
     DDPMScheduler,
     LDMPipeline,
     LDMTextToImagePipeline,
-    MoleculeGNN,
     PNDMPipeline,
     PNDMScheduler,
     ScoreSdeVePipeline,
@@ -39,6 +38,14 @@
     UNet2DModel,
     VQModel,
 )
+from diffusers.utils import is_torch_geometric_available
+
+
+if is_torch_geometric_available():
+    from diffusers import MoleculeGNN
+else:
+    from diffusers.utils.dummy_torch_geometric_objects import *
+
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.testing_utils import floats_tensor, slow, torch_device

From 4d2397681744cb2198e609e5b8109f3becd1ecc7 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Mon, 25 Jul 2022 17:11:50 -0700
Subject: [PATCH 18/26] minor formatting nit

---
 tests/test_modeling_utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 7a394c07cf4d..f4e961d6ecd6 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -39,8 +39,6 @@
     VQModel,
 )
 from diffusers.utils import is_torch_geometric_available
-
-
 if is_torch_geometric_available():
     from diffusers import MoleculeGNN
 else:

From 506eb3c9d53cd48f92f6bee3cc3ab52abbc2129f Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Wed, 27 Jul 2022 08:41:28 -0700
Subject: [PATCH 19/26] add attribution of original codebase

---
 src/diffusers/models/molecule_gnn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/models/molecule_gnn.py b/src/diffusers/models/molecule_gnn.py
index 0e681890bb9c..c93480fc5a19 100644
--- a/src/diffusers/models/molecule_gnn.py
+++ b/src/diffusers/models/molecule_gnn.py
@@ -1,4 +1,5 @@
 # Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
+# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models
 from typing import Callable, Dict, Union

From 4d158a3f2264916ca1f9bbae4ab70c34c75739ce Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Mon, 1 Aug 2022 16:11:06 -0700
Subject: [PATCH 20/26] style and readability improvements

---
 src/diffusers/models/molecule_gnn.py | 44 ++++++++++++++++------------
 tests/test_modeling_utils.py         |  1 +
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/models/molecule_gnn.py b/src/diffusers/models/molecule_gnn.py
index c93480fc5a19..2d6956d2dea7 100644
--- a/src/diffusers/models/molecule_gnn.py
+++ b/src/diffusers/models/molecule_gnn.py
@@ -36,8 +36,9 @@ def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
         if isinstance(activation, str):
             self.activation = getattr(F, activation)
         else:
+            print(f"Warning: activation {activation} is not a string and will be ignored")
             self.activation = None
-        if dropout:
+        if dropout > 0:
             self.dropout = nn.Dropout(dropout)
         else:
             self.dropout = None
@@ -46,9 +47,8 @@ def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
         for i in range(len(self.dims) - 1):
             self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))

-    def forward(self, input):
+    def forward(self, x):
         """"""
-        x = input
         for i, layer in enumerate(self.layers):
             x = 
layer(x) if i < len(self.layers) - 1: @@ -69,11 +69,11 @@ def forward(self, x): class CFConv(MessagePassing): - def __init__(self, in_channels, out_channels, num_filters, nn, cutoff, smooth): + def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth): super(CFConv, self).__init__(aggr="add") self.lin1 = Linear(in_channels, num_filters, bias=False) self.lin2 = Linear(num_filters, out_channels) - self.nn = nn + self.nn = mlp self.cutoff = cutoff self.smooth = smooth @@ -97,7 +97,7 @@ def forward(self, x, edge_index, edge_length, edge_attr): x = self.lin2(x) return x - def message(self, x_j, W): + def message(self, x_j: torch.Tensor, W) -> torch.Tensor: return x_j * W @@ -151,9 +151,14 @@ def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True): class GINEConv(MessagePassing): - def __init__(self, nn: Callable, eps: float = 0.0, train_eps: bool = False, activation="softplus", **kwargs): + """ + Custom class of the graph isomorphism operator from the "How Powerful are Graph Neural Networks? + https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation. + """ + + def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation="softplus", **kwargs): super(GINEConv, self).__init__(aggr="add", **kwargs) - self.nn = nn + self.nn = mlp self.initial_eps = eps if isinstance(activation, str): @@ -226,10 +231,11 @@ def __init__(self, hidden_dim, num_convs=3, activation="relu", short_cut=True, c def forward(self, z, edge_index, edge_attr): """ Input: - data: (torch_geometric.data.Data): batched graph node_attr: node feature tensor with shape (num_node, - hidden) edge_attr: edge feature tensor with shape (num_edge, hidden) + data: (torch_geometric.data.Data): batched graph + edge_index: bond indices of the original graph (num_node, hidden) + edge_attr: edge feature tensor with shape (num_edge, hidden) Output: - node_attr graph feature + node_feature: graph feature """ node_attr = self.node_emb(z) # (num_node, hidden) @@ -380,8 +386,6 @@ def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecifi ) composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T) - # edge_index = composed_adj.indices() - # dist = (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1) new_edge_index = composed_adj.indices() new_edge_type = composed_adj.values().long() @@ -405,8 +409,6 @@ def extend_graph_order_radius( edge_index, edge_type = _extend_graph_order( num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order ) - # edge_index_order = edge_index - # edge_type_order = edge_type if extend_radius: edge_index, edge_type = _extend_to_radius_graph( @@ -420,7 +422,11 @@ def get_distance(pos, edge_index): return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1) -def eq_transform(score_d, pos, edge_index, edge_length): +def graph_field_network(score_d, pos, edge_index, edge_length): + """ + Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. 
+ See equations 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf + """ N = pos.size(0) dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3) score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add( @@ -616,14 +622,16 @@ def forward( ) # (E_global, 1), (E_local, 1) # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff - node_eq_local = eq_transform(edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]) + node_eq_local = graph_field_network( + edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask] + ) if clip_local is not None: node_eq_local = clip_norm(node_eq_local, limit=clip_local) # Global if sigma < global_start_sigma: edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float()) - node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length) + node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length) node_eq_global = clip_norm(node_eq_global, limit=clip_global) else: node_eq_global = 0 diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f4e961d6ecd6..e896103c4afb 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -39,6 +39,7 @@ VQModel, ) from diffusers.utils import is_torch_geometric_available + if is_torch_geometric_available(): from diffusers import MoleculeGNN else: From 7e73190f61317659c38577e3820bc95651a2f981 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 14:52:59 -0400 Subject: [PATCH 21/26] fixes post large rebase --- src/diffusers/schedulers/scheduling_ddpm.py | 6 +- src/diffusers/utils/__init__.py | 148 +-------------- src/diffusers/utils/import_utils.py | 29 +++ tests/test_modeling_utils.py | 156 ---------------- tests/test_models_gnn.py | 190 ++++++++++++++++++++ 5 files changed, 222 insertions(+), 307 deletions(-) create mode 100644 tests/test_models_gnn.py diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 3c92c5af1f9c..3d54211e5f78 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -133,18 +133,14 @@ def __init__( ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule -<<<<<<< HEAD self.betas = betas_for_alpha_bar(num_train_timesteps) -======= - self.betas = betas_for_alpha_bar(timesteps) elif beta_schedule == "sigmoid": def sigmoid(x): return 1 / (np.exp(-x) + 1) - betas = np.linspace(-6, 6, timesteps) + betas = np.linspace(-6, 6, num_train_timesteps) self.betas = sigmoid(betas) * (beta_end - beta_start) + beta_start ->>>>>>> beaa1e0 (add sigmoid beta schedule to ddpm) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 48006173243e..39b66f25c1a2 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ is_tf_available, is_torch_available, is_transformers_available, + is_torch_geometric_available, is_unidecode_available, requires_backends, ) @@ -53,149 +54,4 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" -HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) -<<<<<<< HEAD -======= - - -_transformers_available = importlib.util.find_spec("transformers") 
is not None -try: - _transformers_version = importlib_metadata.version("transformers") - logger.debug(f"Successfully imported transformers version {_transformers_version}") -except importlib_metadata.PackageNotFoundError: - _transformers_available = False - - -_inflect_available = importlib.util.find_spec("inflect") is not None -try: - _inflect_version = importlib_metadata.version("inflect") - logger.debug(f"Successfully imported inflect version {_inflect_version}") -except importlib_metadata.PackageNotFoundError: - _inflect_available = False - - -_unidecode_available = importlib.util.find_spec("unidecode") is not None -try: - _unidecode_version = importlib_metadata.version("unidecode") - logger.debug(f"Successfully imported unidecode version {_unidecode_version}") -except importlib_metadata.PackageNotFoundError: - _unidecode_available = False - - -_modelcards_available = importlib.util.find_spec("modelcards") is not None -try: - _modelcards_version = importlib_metadata.version("modelcards") - logger.debug(f"Successfully imported modelcards version {_modelcards_version}") -except importlib_metadata.PackageNotFoundError: - _modelcards_available = False - -_torch_scatter_available = importlib.util.find_spec("torch_scatter") is not None -try: - _torch_scatter_version = importlib_metadata.version("torch_scatter") - logger.debug(f"Successfully imported torch_scatter version {_torch_scatter_version}") -except importlib_metadata.PackageNotFoundError: - _torch_scatter_available = False - -_torch_scatter_available = importlib.util.find_spec("torch_geometric") is not None -try: - _torch_geometric_version = importlib_metadata.version("torch_geometric") - logger.debug(f"Successfully imported torch_geometric version {_torch_geometric_version}") -except importlib_metadata.PackageNotFoundError: - _torch_geometric_available = False - - -def is_transformers_available(): - return _transformers_available - - -def is_inflect_available(): - return _inflect_available - - -def is_unidecode_available(): - return _unidecode_available - - -def is_modelcards_available(): - return _modelcards_available - - -def is_torch_scatter_available(): - return _torch_scatter_available - - -def is_torch_geometric_available(): - # the model source of the Molecule Generation GNN requires a specific torch geometric version - # for more info, see the original repo https://github.com/MinkaiXu/GeoDiff or our colab in readme - return _torch_geometric_version == "1.7.2" - - -class RepositoryNotFoundError(HTTPError): - """ - Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does - not have access to. - """ - - -class EntryNotFoundError(HTTPError): - """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename.""" - - -class RevisionNotFoundError(HTTPError): - """Raised when trying to access a hf.co URL with a valid repository but an invalid revision.""" - - -TRANSFORMERS_IMPORT_ERROR = """ -{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip -install transformers` -""" - - -UNIDECODE_IMPORT_ERROR = """ -{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install -Unidecode` -""" - - -INFLECT_IMPORT_ERROR = """ -{0} requires the inflect library but it was not found in your environment. 
You can install it with pip: `pip install -inflect` -""" - -TORCH_GEOMETRIC_IMPORT_ERROR = """ -{0} requires version 1.7.2 of torch_geometric but it was not found in your environment. You can install it with conda: -`conda install -c rusty1s pytorch-geometric=1.7.2`, given pytorch 1.8 -""" - - -BACKENDS_MAPPING = OrderedDict( - [ - ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ] -) - - -def requires_backends(obj, backends): - if not isinstance(backends, (list, tuple)): - backends = [backends] - - name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] - if failed: - raise ImportError("".join(failed)) - - -class DummyObject(type): - """ - Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by - `requires_backend` each time a user tries to access any method of that class. - """ - - def __getattr__(cls, key): - if key.startswith("_"): - return super().__getattr__(cls, key) - requires_backends(cls, cls._backends) ->>>>>>> bf87817 (add checking for imports) +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) \ No newline at end of file diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index de344d074da0..d81cca75241d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -159,6 +159,20 @@ except importlib_metadata.PackageNotFoundError: _scipy_available = False +_torch_scatter_available = importlib.util.find_spec("torch_scatter") is not None +try: + _torch_scatter_version = importlib_metadata.version("torch_scatter") + logger.debug(f"Successfully imported torch_scatter version {_torch_scatter_version}") +except importlib_metadata.PackageNotFoundError: + _torch_scatter_available = False + +_torch_geometric_available = importlib.util.find_spec("torch_geometric") is not None +try: + _torch_geometric_version = importlib_metadata.version("torch_geometric") + logger.debug(f"Successfully imported torch_geometric version {_torch_geometric_version}") +except importlib_metadata.PackageNotFoundError: + _torch_geometric_available = False + def is_torch_available(): return _torch_available @@ -195,6 +209,14 @@ def is_onnx_available(): def is_scipy_available(): return _scipy_available +def is_torch_scatter_available(): + return _torch_scatter_available + + +def is_torch_geometric_available(): + # the model source of the Molecule Generation GNN requires a specific torch geometric version + # for more info, see the original repo https://github.com/MinkaiXu/GeoDiff or our colab in readme + return _torch_geometric_version == "1.7.2" # docstyle-ignore FLAX_IMPORT_ERROR = """ @@ -244,6 +266,13 @@ def is_scipy_available(): Unidecode` """ +# docstyle-ignore +TORCH_GEOMETRIC_IMPORT_ERROR = """ +{0} requires version 1.7.2 of torch_geometric but it was not found in your environment. 
You can install it with conda: +`conda install -c rusty1s pytorch-geometric=1.7.2`, given pytorch 1.8 +""" + + BACKENDS_MAPPING = OrderedDict( [ diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index e896103c4afb..327d593cebf7 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -613,163 +613,7 @@ def test_output_pretrained(self): self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) -class MoleculeGNNTests(ModelTesterMixin, unittest.TestCase): - model_class = MoleculeGNN - @property - def dummy_input(self): - batch_size = 2 - time_step = 10 - - class GeoDiffData: - # constants corresponding to a molecule - num_nodes = 6 - num_edges = 10 - num_graphs = 1 - - # sampling - torch.Generator(device=torch_device) - torch.manual_seed(3) - - # molecule components / properties - atom_type = torch.randint(0, 6, (num_nodes * batch_size,)).to(torch_device) - edge_index = torch.randint( - 0, - num_edges, - ( - 2, - num_edges * batch_size, - ), - ).to(torch_device) - edge_type = torch.randint(0, 5, (num_edges * batch_size,)).to(torch_device) - pos = 0.001 * torch.randn(num_nodes * batch_size, 3).to(torch_device) - batch = torch.tensor([*range(batch_size)]).repeat_interleave(num_nodes) - nx = batch_size - - torch.manual_seed(0) - noise = GeoDiffData() - - return {"sample": noise, "timestep": time_step, "sigma": 1.0} - - @property - def output_shape(self): - # subset of shapes for dummy input - class GeoDiffShapes: - shape_0 = torch.Size([1305, 1]) - shape_1 = torch.Size([92, 1]) - - return GeoDiffShapes() - - # training not implemented for this model yet - def test_training(self): - pass - - def test_ema_training(self): - pass - - def test_determinism(self): - # TODO - pass - - def test_output(self): - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output["sample"] - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - shapes = self.output_shapes() - self.assertEqual(output[0].shape, shapes.shape_0, "Input and output shapes do not match") - self.assertEqual(output[1].shape, shapes.shape_1, "Input and output shapes do not match") - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "hidden_dim": 128, - "num_convs": 6, - "num_convs_local": 4, - "cutoff": 10.0, - "mlp_act": "relu", - "edge_order": 3, - "edge_encoder": "mlp", - "smooth_conv": True, - } - - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - input = self.dummy_input - sample, time_step, sigma = input["sample"], input["timestep"], input["sigma"] - with torch.no_grad(): - output = model(sample, time_step, sigma=sigma)["sample"] - - output_slice = output[:3][:].flatten() - # fmt: off - expected_output_slice = 
torch.tensor([ -3.7335, -7.4622, -29.5600, 16.9646, -11.2205, -32.5315, 1.2303, - 4.2985, 8.8828]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) - - def test_model_from_config(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - # test if the model can be loaded from the config - # and has all the expected shape - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) - new_model.to(torch_device) - new_model.eval() - - # check if all paramters shape are the same - for param_name in model.state_dict().keys(): - param_1 = model.state_dict()[param_name] - param_2 = new_model.state_dict()[param_name] - self.assertEqual(param_1.shape, param_2.shape) - - with torch.no_grad(): - output_1 = model(**inputs_dict) - - if isinstance(output_1, dict): - output_1 = output_1["sample"] - - output_2 = new_model(**inputs_dict) - - if isinstance(output_2, dict): - output_2 = output_2["sample"] - - self.assertEqual(output_1[0].shape, output_2[0].shape) - self.assertEqual(output_1[1].shape, output_2[1].shape) class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): diff --git a/tests/test_models_gnn.py b/tests/test_models_gnn.py new file mode 100644 index 000000000000..1db7a49e076e --- /dev/null +++ b/tests/test_models_gnn.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
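One property worth pinning in this new module (a hypothetical extra test, not part of this patch) is the roto-translational equivariance that graph_field_network is meant to guarantee per eqs. 5-7 of GeoDiff: rotating the conformation should rotate the per-node scores identically, since edge lengths are invariant. A sketch, assuming the pinned torch_geometric/torch_scatter stack is installed:

import torch
from diffusers.models.molecule_gnn import get_distance, graph_field_network

def check_rotation_equivariance():
    pos = torch.randn(5, 3)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
    score_d = torch.randn(edge_index.size(1), 1)
    rot = torch.linalg.qr(torch.randn(3, 3)).Q  # random orthogonal matrix

    lengths = get_distance(pos, edge_index).unsqueeze(-1)
    lengths_rot = get_distance(pos @ rot.T, edge_index).unsqueeze(-1)
    out = graph_field_network(score_d, pos, edge_index, lengths)
    out_rot = graph_field_network(score_d, pos @ rot.T, edge_index, lengths_rot)

    assert torch.allclose(lengths, lengths_rot, atol=1e-5)  # distances are invariant
    assert torch.allclose(out @ rot.T, out_rot, atol=1e-4)  # scores rotate with pos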
+ +import unittest +import tempfile + +import torch + +from diffusers.testing_utils import floats_tensor, slow, torch_device + +from .test_modeling_common import ModelTesterMixin +from diffusers.utils import is_torch_geometric_available + +if is_torch_geometric_available(): + from diffusers import MoleculeGNN +else: + from diffusers.utils.dummy_torch_geometric_objects import * + + +torch.backends.cuda.matmul.allow_tf32 = False + +class MoleculeGNNTests(ModelTesterMixin, unittest.TestCase): + model_class = MoleculeGNN + + @property + def dummy_input(self): + batch_size = 2 + time_step = 10 + + class GeoDiffData: + # constants corresponding to a molecule + num_nodes = 6 + num_edges = 10 + num_graphs = 1 + + # sampling + torch.Generator(device=torch_device) + torch.manual_seed(3) + + # molecule components / properties + atom_type = torch.randint(0, 6, (num_nodes * batch_size,)).to(torch_device) + edge_index = torch.randint( + 0, + num_edges, + ( + 2, + num_edges * batch_size, + ), + ).to(torch_device) + edge_type = torch.randint(0, 5, (num_edges * batch_size,)).to(torch_device) + pos = 0.001 * torch.randn(num_nodes * batch_size, 3).to(torch_device) + batch = torch.tensor([*range(batch_size)]).repeat_interleave(num_nodes) + nx = batch_size + + torch.manual_seed(0) + noise = GeoDiffData() + + return {"sample": noise, "timestep": time_step, "sigma": 1.0} + + @property + def output_shape(self): + # subset of shapes for dummy input + class GeoDiffShapes: + shape_0 = torch.Size([1305, 1]) + shape_1 = torch.Size([92, 1]) + + return GeoDiffShapes() + + # training not implemented for this model yet + def test_training(self): + pass + + def test_ema_training(self): + pass + + def test_determinism(self): + # TODO + pass + + def test_output(self): + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output["sample"] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + shapes = self.output_shapes() + self.assertEqual(output[0].shape, shapes.shape_0, "Input and output shapes do not match") + self.assertEqual(output[1].shape, shapes.shape_1, "Input and output shapes do not match") + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "hidden_dim": 128, + "num_convs": 6, + "num_convs_local": 4, + "cutoff": 10.0, + "mlp_act": "relu", + "edge_order": 3, + "edge_encoder": "mlp", + "smooth_conv": True, + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + input = self.dummy_input + sample, time_step, sigma = input["sample"], input["timestep"], input["sigma"] + with torch.no_grad(): + output = model(sample, time_step, sigma=sigma)["sample"] + + output_slice = output[:3][:].flatten() + # fmt: off + expected_output_slice = torch.tensor([ -3.7335, -7.4622, 
-29.5600, 16.9646, -11.2205, -32.5315, 1.2303, + 4.2985, 8.8828]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + def test_model_from_config(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # test if the model can be loaded from the config + # and has all the expected shape + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_config(tmpdirname) + new_model = self.model_class.from_config(tmpdirname) + new_model.to(torch_device) + new_model.eval() + + # check if all paramters shape are the same + for param_name in model.state_dict().keys(): + param_1 = model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + self.assertEqual(param_1.shape, param_2.shape) + + with torch.no_grad(): + output_1 = model(**inputs_dict) + + if isinstance(output_1, dict): + output_1 = output_1["sample"] + + output_2 = new_model(**inputs_dict) + + if isinstance(output_2, dict): + output_2 = output_2["sample"] + + self.assertEqual(output_1[0].shape, output_2[0].shape) + self.assertEqual(output_1[1].shape, output_2[1].shape) \ No newline at end of file From 682eb472b3dcb4425fba4edbf8b2f50a53528ec1 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 15:09:13 -0400 Subject: [PATCH 22/26] fix tests --- src/diffusers/models/molecule_gnn.py | 33 ++++++++++++++++++++++++++-- tests/test_models_gnn.py | 6 ++++- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/molecule_gnn.py b/src/diffusers/models/molecule_gnn.py index 2d6956d2dea7..ebc4fcf6692d 100644 --- a/src/diffusers/models/molecule_gnn.py +++ b/src/diffusers/models/molecule_gnn.py @@ -1,6 +1,8 @@ # Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff # Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models from typing import Callable, Dict, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Union import numpy as np import torch @@ -16,6 +18,17 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin +from ..utils import BaseOutput + +@dataclass +class MoleculeGNNOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states output. Output of last layer of model. + """ + + sample: torch.FloatTensor class MultiLayerPerceptron(nn.Module): @@ -591,6 +604,7 @@ def forward( self, sample, timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, sigma=1.0, global_start_sigma=0.5, w_global=1.0, @@ -598,7 +612,18 @@ def forward( extend_radius=True, clip_local=None, clip_global=1000.0, - ) -> Dict[str, torch.FloatTensor]: + ) -> Union[MoleculeGNNOutput, Tuple]: + r""" + Args: + sample: packed torch geometric object + timestep (`torch.FloatTensor` or `float` or `int): TODO verify type and shape (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple. + + Returns: + [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
+ """ # unpack sample atom_type = sample.atom_type @@ -638,7 +663,11 @@ def forward( # Sum eps_pos = node_eq_local + node_eq_global * w_global - return {"sample": -eps_pos} + + if not return_dict: + return (-eps_pos,) + + return MoleculeGNNOutput(sample=torch.FloatTensor(-eps_pos)) def clip_norm(vec, limit, p=2): diff --git a/tests/test_models_gnn.py b/tests/test_models_gnn.py index 1db7a49e076e..486f80c81890 100644 --- a/tests/test_models_gnn.py +++ b/tests/test_models_gnn.py @@ -187,4 +187,8 @@ def test_model_from_config(self): output_2 = output_2["sample"] self.assertEqual(output_1[0].shape, output_2[0].shape) - self.assertEqual(output_1[1].shape, output_2[1].shape) \ No newline at end of file + self.assertEqual(output_1[1].shape, output_2[1].shape) + + def test_forward_with_norm_groups(self): + # not implemented for this model + pass \ No newline at end of file From 2ef3727146a0a6ad19a8c1f58de09f3f0e161d86 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 18:36:07 -0400 Subject: [PATCH 23/26] make quality and style --- src/diffusers/__init__.py | 7 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/molecule_gnn.py | 25 +- src/diffusers/utils/__init__.py | 4 +- src/diffusers/utils/import_utils.py | 3 +- tests/test_modeling_utils.py | 891 --------------------------- tests/test_models_gnn.py | 16 +- 7 files changed, 28 insertions(+), 920 deletions(-) delete mode 100755 tests/test_modeling_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e39e3976382d..d2d20f319149 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -4,7 +4,6 @@ is_onnx_available, is_scipy_available, is_torch_available, - is_inflect_available, is_torch_geometric_available, is_transformers_available, is_unidecode_available, @@ -20,7 +19,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, MoleculeGNN, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -85,9 +84,9 @@ from .pipelines import FlaxStableDiffusionPipeline else: from .utils.dummy_flax_and_transformers_objects import * # noqa F403 - from .utils.dummy_transformers_objects import * + from .utils.dummy_transformers_objects import * # noqa F403 if is_torch_geometric_available(): from .models import MoleculeGNN else: - from .utils.dummy_torch_geometric_objects import * + from .utils.dummy_torch_geometric_objects import * # noqa F403 diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e04037fa434e..a638d8c58468 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,10 +16,10 @@ if is_torch_available(): + from .molecule_gnn import MoleculeGNN from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .vae import AutoencoderKL, VQModel - from .molecule_gnn import MoleculeGNN if is_flax_available(): from .unet_2d_condition_flax import FlaxUNet2DConditionModel diff --git a/src/diffusers/models/molecule_gnn.py b/src/diffusers/models/molecule_gnn.py index ebc4fcf6692d..be2f035ace09 100644 --- a/src/diffusers/models/molecule_gnn.py +++ b/src/diffusers/models/molecule_gnn.py @@ -1,8 +1,7 @@ # Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff # Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models -from typing import Callable, Dict, 
Union from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Tuple, Union import numpy as np import torch @@ -20,6 +19,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput + @dataclass class MoleculeGNNOutput(BaseOutput): """ @@ -244,9 +244,8 @@ def __init__(self, hidden_dim, num_convs=3, activation="relu", short_cut=True, c def forward(self, z, edge_index, edge_attr): """ Input: - data: (torch_geometric.data.Data): batched graph - edge_index: bond indices of the original graph (num_node, hidden) - edge_attr: edge feature tensor with shape (num_edge, hidden) + data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node, + hidden) edge_attr: edge feature tensor with shape (num_edge, hidden) Output: node_feature: graph feature """ @@ -437,8 +436,8 @@ def get_distance(pos, edge_index): def graph_field_network(score_d, pos, edge_index, edge_length): """ - Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. - See equations 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf + Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations + 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf """ N = pos.size(0) dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3) @@ -561,13 +560,13 @@ def _forward( edge_length=edge_length, edge_attr=edge_attr_global, ) - ## Assemble pairwise features + # Assemble pairwise features h_pair_global = assemble_atom_pair_feature( node_attr=node_attr_global, edge_index=edge_index, edge_attr=edge_attr_global, ) # (E_global, 2H) - ## Invariant features of edges (radius graph, global) + # Invariant features of edges (radius graph, global) edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1) # Encoding local @@ -580,14 +579,14 @@ def _forward( edge_index=edge_index[:, local_edge_mask], edge_attr=edge_attr_local[local_edge_mask], ) - ## Assemble pairwise features + # Assemble pairwise features h_pair_local = assemble_atom_pair_feature( node_attr=node_attr_local, edge_index=edge_index[:, local_edge_mask], edge_attr=edge_attr_local[local_edge_mask], ) # (E_local, 2H) - ## Invariant features of edges (bond graph, local) + # Invariant features of edges (bond graph, local) if isinstance(sigma_edge, torch.Tensor): edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * ( 1.0 / sigma_edge[local_edge_mask] @@ -621,8 +620,8 @@ def forward( Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple. Returns: - [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" # unpack sample diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index a58b7b2be2b5..9f38bc215ff6 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,8 +30,8 @@ is_scipy_available, is_tf_available, is_torch_available, - is_transformers_available, is_torch_geometric_available, + is_transformers_available, is_unidecode_available, requires_backends, ) @@ -56,4 +56,4 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" -HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) \ No newline at end of file +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index d81cca75241d..217fb8d1a01c 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -209,6 +209,7 @@ def is_onnx_available(): def is_scipy_available(): return _scipy_available + def is_torch_scatter_available(): return _torch_scatter_available @@ -218,6 +219,7 @@ def is_torch_geometric_available(): # for more info, see the original repo https://github.com/MinkaiXu/GeoDiff or our colab in readme return _torch_geometric_version == "1.7.2" + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -273,7 +275,6 @@ def is_torch_geometric_available(): """ - BACKENDS_MAPPING = OrderedDict( [ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py deleted file mode 100755 index 327d593cebf7..000000000000 --- a/tests/test_modeling_utils.py +++ /dev/null @@ -1,891 +0,0 @@ -# coding=utf-8 -# Copyright 2022 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import inspect -import math -import tempfile -import unittest - -import numpy as np -import torch - -import PIL -from diffusers import UNet2DConditionModel # noqa: F401 TODO(Patrick) - need to write tests with it -from diffusers import ( - AutoencoderKL, - DDIMPipeline, - DDIMScheduler, - DDPMPipeline, - DDPMScheduler, - LDMPipeline, - LDMTextToImagePipeline, - PNDMPipeline, - PNDMScheduler, - ScoreSdeVePipeline, - ScoreSdeVeScheduler, - UNet2DModel, - VQModel, -) -from diffusers.utils import is_torch_geometric_available - -if is_torch_geometric_available(): - from diffusers import MoleculeGNN -else: - from diffusers.utils.dummy_torch_geometric_objects import * - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.testing_utils import floats_tensor, slow, torch_device -from diffusers.training_utils import EMAModel - - -torch.backends.cuda.matmul.allow_tf32 = False - - -class SampleObject(ConfigMixin): - config_name = "config.json" - - @register_to_config - def __init__( - self, - a=2, - b=5, - c=(2, 5), - d="for diffusion", - e=[1, 3], - ): - pass - - -class ConfigTester(unittest.TestCase): - def test_load_not_from_mixin(self): - with self.assertRaises(ValueError): - ConfigMixin.from_config("dummy_path") - - def test_register_to_config(self): - obj = SampleObject() - config = obj.config - assert config["a"] == 2 - assert config["b"] == 5 - assert config["c"] == (2, 5) - assert config["d"] == "for diffusion" - assert config["e"] == [1, 3] - - # init ignore private arguments - obj = SampleObject(_name_or_path="lalala") - config = obj.config - assert config["a"] == 2 - assert config["b"] == 5 - assert config["c"] == (2, 5) - assert config["d"] == "for diffusion" - assert config["e"] == [1, 3] - - # can override default - obj = SampleObject(c=6) - config = obj.config - assert config["a"] == 2 - assert config["b"] == 5 - assert config["c"] == 6 - assert config["d"] == "for diffusion" - assert config["e"] == [1, 3] - - # can use positional arguments. 
- obj = SampleObject(1, c=6) - config = obj.config - assert config["a"] == 1 - assert config["b"] == 5 - assert config["c"] == 6 - assert config["d"] == "for diffusion" - assert config["e"] == [1, 3] - - def test_save_load(self): - obj = SampleObject() - config = obj.config - - assert config["a"] == 2 - assert config["b"] == 5 - assert config["c"] == (2, 5) - assert config["d"] == "for diffusion" - assert config["e"] == [1, 3] - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - new_obj = SampleObject.from_config(tmpdirname) - new_config = new_obj.config - - # unfreeze configs - config = dict(config) - new_config = dict(new_config) - - assert config.pop("c") == (2, 5) # instantiated as tuple - assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json - assert config == new_config - - -class ModelTesterMixin: - def test_from_pretrained_save_pretrained(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - - with torch.no_grad(): - image = model(**inputs_dict) - if isinstance(image, dict): - image = image["sample"] - - new_image = new_model(**inputs_dict) - - if isinstance(new_image, dict): - new_image = new_image["sample"] - - max_diff = (image - new_image).abs().sum().item() - self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") - - def test_determinism(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - with torch.no_grad(): - first = model(**inputs_dict) - if isinstance(first, dict): - first = first["sample"] - - second = model(**inputs_dict) - if isinstance(second, dict): - second = second["sample"] - - out_1 = first.cpu().numpy() - out_2 = second.cpu().numpy() - out_1 = out_1[~np.isnan(out_1)] - out_2 = out_2[~np.isnan(out_2)] - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) - - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output["sample"] - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_forward_signature(self): - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - signature = inspect.signature(model.forward) - # signature.parameters is an OrderedDict => so arg_names order is deterministic - arg_names = [*signature.parameters.keys()] - - expected_arg_names = ["sample", "timestep"] - self.assertListEqual(arg_names[:2], expected_arg_names) - - def test_model_from_config(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - # test if the model can be loaded from the config - # and has all the expected shape - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) - 
new_model.to(torch_device)
-        new_model.eval()
-
-        # check if all parameter shapes are the same
-        for param_name in model.state_dict().keys():
-            param_1 = model.state_dict()[param_name]
-            param_2 = new_model.state_dict()[param_name]
-            self.assertEqual(param_1.shape, param_2.shape)
-
-        with torch.no_grad():
-            output_1 = model(**inputs_dict)
-
-            if isinstance(output_1, dict):
-                output_1 = output_1["sample"]
-
-            output_2 = new_model(**inputs_dict)
-
-            if isinstance(output_2, dict):
-                output_2 = output_2["sample"]
-
-        self.assertEqual(output_1.shape, output_2.shape)
-
-    def test_training(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-        model.train()
-        output = model(**inputs_dict)
-
-        if isinstance(output, dict):
-            output = output["sample"]
-
-        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
-        loss = torch.nn.functional.mse_loss(output, noise)
-        loss.backward()
-
-    def test_ema_training(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-        model.train()
-        ema_model = EMAModel(model, device=torch_device)
-
-        output = model(**inputs_dict)
-
-        if isinstance(output, dict):
-            output = output["sample"]
-
-        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
-        loss = torch.nn.functional.mse_loss(output, noise)
-        loss.backward()
-        ema_model.step(model)
-
-
-class UnetModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNet2DModel
-
-    @property
-    def dummy_input(self):
-        batch_size = 4
-        num_channels = 3
-        sizes = (32, 32)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-
-        return {"sample": noise, "timestep": time_step}
-
-    @property
-    def input_shape(self):
-        return (3, 32, 32)
-
-    @property
-    def output_shape(self):
-        return (3, 32, 32)
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "block_out_channels": (32, 64),
-            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
-            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
-            "attention_head_dim": None,
-            "out_channels": 3,
-            "in_channels": 3,
-            "layers_per_block": 2,
-            "sample_size": 32,
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-
-# TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
-# def test_output_pretrained(self):
-#     model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
-#     model.eval()
-#
-#     torch.manual_seed(0)
-#     if torch.cuda.is_available():
-#         torch.cuda.manual_seed_all(0)
-#
-#     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
-#     time_step = torch.tensor([10])
-#
-#     with torch.no_grad():
-#         output = model(noise, time_step)["sample"]
-#
-#     output_slice = output[0, -1, -3:, -3:].flatten()
-# fmt: off
-#     expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
-# fmt: on
-#     self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
-
-
-class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNet2DModel
-
-    @property
-    def dummy_input(self):
-        batch_size = 4
-        num_channels = 4
-        sizes = (32, 32)
-
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step =
torch.tensor([10]).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (4, 32, 32) - - @property - def output_shape(self): - return (4, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "sample_size": 32, - "in_channels": 4, - "out_channels": 4, - "layers_per_block": 2, - "block_out_channels": (32, 64), - "attention_head_dim": 32, - "down_block_types": ("DownBlock2D", "DownBlock2D"), - "up_block_types": ("UpBlock2D", "UpBlock2D"), - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) - - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input)["sample"] - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) - time_step = torch.tensor([10] * noise.shape[0]) - - with torch.no_grad(): - output = model(noise, time_step)["sample"] - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) - - -# TODO(Patrick) - Re-add this test after having cleaned up LDM -# def test_output_pretrained_spatial_transformer(self): -# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial") -# model.eval() -# -# torch.manual_seed(0) -# if torch.cuda.is_available(): -# torch.cuda.manual_seed_all(0) -# -# noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) -# context = torch.ones((1, 16, 64), dtype=torch.float32) -# time_step = torch.tensor([10] * noise.shape[0]) -# -# with torch.no_grad(): -# output = model(noise, time_step, context=context) -# -# output_slice = output[0, -1, -3:, -3:].flatten() -# fmt: off -# expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890]) -# fmt: on -# -# self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) -# - - -class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): - model_class = UNet2DModel - - @property - def dummy_input(self, sizes=(32, 32)): - batch_size = 4 - num_channels = 3 - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(torch_device) - - return {"sample": noise, "timestep": time_step} - - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": [32, 64, 64, 64], - "in_channels": 3, - "layers_per_block": 1, - "out_channels": 3, - "time_embedding_type": "fourier", - "norm_eps": 1e-6, - "mid_block_scale_factor": math.sqrt(2.0), - "norm_num_groups": None, - "down_block_types": [ - "SkipDownBlock2D", - "AttnSkipDownBlock2D", - "SkipDownBlock2D", - "SkipDownBlock2D", - ], - "up_block_types": [ - 
"SkipUpBlock2D", - "SkipUpBlock2D", - "AttnSkipUpBlock2D", - "SkipUpBlock2D", - ], - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_from_pretrained_hub(self): - model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - inputs = self.dummy_input - noise = floats_tensor((4, 3) + (256, 256)).to(torch_device) - inputs["sample"] = noise - image = model(**inputs) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained_ve_mid(self): - model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256") - model.to(torch_device) - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - batch_size = 4 - num_channels = 3 - sizes = (256, 256) - - noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) - - with torch.no_grad(): - output = model(noise, time_step)["sample"] - - output_slice = output[0, -3:, -3:, -1].flatten().cpu() - # fmt: off - expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - def test_output_pretrained_ve_large(self): - model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") - model.to(torch_device) - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) - - with torch.no_grad(): - output = model(noise, time_step)["sample"] - - output_slice = output[0, -3:, -3:, -1].flatten().cpu() - # fmt: off - expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256]) - # fmt: on - - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - -class VQModelTests(ModelTesterMixin, unittest.TestCase): - model_class = VQModel - - @property - def dummy_input(self, sizes=(32, 32)): - batch_size = 4 - num_channels = 3 - - image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - - return {"sample": image} - - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "ch": 64, - "out_ch": 3, - "num_res_blocks": 1, - "in_channels": 3, - "attn_resolutions": [], - "resolution": 32, - "z_channels": 3, - "n_embed": 256, - "embed_dim": 3, - "sane_index_shape": False, - "ch_mult": (1,), - "double_z": False, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_forward_signature(self): - pass - - def test_training(self): - pass - - def test_from_pretrained_hub(self): - model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = VQModel.from_pretrained("fusing/vqgan-dummy") - model.eval() - - 
torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) - with torch.no_grad(): - output = model(image) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - - - - -class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): - model_class = AutoencoderKL - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - - return {"sample": image} - - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "ch": 64, - "ch_mult": (1,), - "embed_dim": 4, - "in_channels": 3, - "attn_resolutions": [], - "num_res_blocks": 1, - "out_ch": 3, - "resolution": 32, - "z_channels": 4, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_forward_signature(self): - pass - - def test_training(self): - pass - - def test_from_pretrained_hub(self): - model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) - - model.to(torch_device) - image = model(**self.dummy_input) - - assert image is not None, "Make sure output is not None" - - def test_output_pretrained(self): - model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") - model.eval() - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) - with torch.no_grad(): - output = model(image, sample_posterior=True) - - output_slice = output[0, -1, -3:, -3:].flatten() - # fmt: off - expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750]) - # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) - - -class PipelineTesterMixin(unittest.TestCase): - def test_from_pretrained_save_pretrained(self): - # 1. 
Load models
-        model = UNet2DModel(
-            block_out_channels=(32, 64),
-            layers_per_block=2,
-            sample_size=32,
-            in_channels=3,
-            out_channels=3,
-            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
-            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
-        )
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-
-        ddpm = DDPMPipeline(model, scheduler)
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            ddpm.save_pretrained(tmpdirname)
-            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
-
-        generator = torch.manual_seed(0)
-
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
-        generator = generator.manual_seed(0)
-        new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
-
-        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
-
-    @slow
-    def test_from_pretrained_hub(self):
-        model_path = "google/ddpm-cifar10-32"
-
-        ddpm = DDPMPipeline.from_pretrained(model_path)
-        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
-
-        ddpm.scheduler.num_timesteps = 10
-        ddpm_from_hub.scheduler.num_timesteps = 10
-
-        generator = torch.manual_seed(0)
-
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
-        generator = generator.manual_seed(0)
-        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
-
-        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
-
-    @slow
-    def test_output_format(self):
-        model_path = "google/ddpm-cifar10-32"
-
-        pipe = DDIMPipeline.from_pretrained(model_path)
-
-        generator = torch.manual_seed(0)
-        images = pipe(generator=generator, output_type="numpy")["sample"]
-        assert images.shape == (1, 32, 32, 3)
-        assert isinstance(images, np.ndarray)
-
-        images = pipe(generator=generator, output_type="pil")["sample"]
-        assert isinstance(images, list)
-        assert len(images) == 1
-        assert isinstance(images[0], PIL.Image.Image)
-
-        # use PIL by default
-        images = pipe(generator=generator)["sample"]
-        assert isinstance(images, list)
-        assert isinstance(images[0], PIL.Image.Image)
-
-    @slow
-    def test_ddpm_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = DDPMScheduler.from_config(model_id)
-        scheduler = scheduler.set_format("pt")
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-
-        generator = torch.manual_seed(0)
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    @slow
-    def test_ddim_lsun(self):
-        model_id = "google/ddpm-ema-bedroom-256"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = DDIMScheduler.from_config(model_id)
-
-        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
-
-        generator = torch.manual_seed(0)
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    @slow
-    def test_ddim_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = DDIMScheduler(tensor_format="pt")
-
-        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
-
-        generator =
torch.manual_seed(0) - image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"] - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_pndm_cifar10(self): - model_id = "google/ddpm-cifar10-32" - - unet = UNet2DModel.from_pretrained(model_id) - scheduler = PNDMScheduler(tensor_format="pt") - - pndm = PNDMPipeline(unet=unet, scheduler=scheduler) - generator = torch.manual_seed(0) - image = pndm(generator=generator, output_type="numpy")["sample"] - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_ldm_text2img(self): - ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[ - "sample" - ] - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_ldm_text2img_fast(self): - ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) - image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"] - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_score_sde_ve_pipeline(self): - model = UNet2DModel.from_pretrained("google/ncsnpp-church-256") - - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256") - - sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) - - torch.manual_seed(0) - image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"] - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @slow - def test_ldm_uncond(self): - ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256") - - generator = torch.manual_seed(0) - image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"] - - image_slice = image[0, -3:, -3:, -1] - - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/test_models_gnn.py b/tests/test_models_gnn.py index 486f80c81890..a0d87b835b44 100644 --- a/tests/test_models_gnn.py +++ b/tests/test_models_gnn.py @@ -13,24 +13,26 @@ # See the License for the specific language governing permissions 
and # limitations under the License. -import unittest import tempfile +import unittest import torch -from diffusers.testing_utils import floats_tensor, slow, torch_device +from diffusers.utils import is_torch_geometric_available +from diffusers.utils.testing_utils import torch_device from .test_modeling_common import ModelTesterMixin -from diffusers.utils import is_torch_geometric_available + if is_torch_geometric_available(): from diffusers import MoleculeGNN else: - from diffusers.utils.dummy_torch_geometric_objects import * + from diffusers.utils.dummy_torch_geometric_objects import * # noqa F403 torch.backends.cuda.matmul.allow_tf32 = False + class MoleculeGNNTests(ModelTesterMixin, unittest.TestCase): model_class = MoleculeGNN @@ -103,7 +105,6 @@ def test_output(self): output = output["sample"] self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape shapes = self.output_shapes() self.assertEqual(output[0].shape, shapes.shape_0, "Input and output shapes do not match") self.assertEqual(output[1].shape, shapes.shape_1, "Input and output shapes do not match") @@ -148,8 +149,7 @@ def test_output_pretrained(self): output_slice = output[:3][:].flatten() # fmt: off - expected_output_slice = torch.tensor([ -3.7335, -7.4622, -29.5600, 16.9646, -11.2205, -32.5315, 1.2303, - 4.2985, 8.8828]) + expected_output_slice = torch.tensor([-3.7335, -7.4622, -29.5600, 16.9646, -11.2205, -32.5315, 1.2303, 4.2985, 8.8828]) # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) @@ -191,4 +191,4 @@ def test_model_from_config(self): def test_forward_with_norm_groups(self): # not implemented for this model - pass \ No newline at end of file + pass From 47af5ce39cccb51b45f2996116061a7959e92332 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 3 Oct 2022 18:39:28 -0400 Subject: [PATCH 24/26] only import moleculegnn when ready --- src/diffusers/models/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index a638d8c58468..78eb92dba12d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ..utils import is_flax_available, is_torch_available
+from ..utils import is_flax_available, is_torch_available, is_torch_geometric_available


 if is_torch_available():
-    from .molecule_gnn import MoleculeGNN
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .vae import AutoencoderKL, VQModel
@@ -24,3 +23,6 @@
 if is_flax_available():
     from .unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .vae_flax import FlaxAutoencoderKL
+
+if is_torch_geometric_available():
+    from .molecule_gnn import MoleculeGNN

From f5f25769b82b4d7c6c3748ea1339636125d4f266 Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Mon, 3 Oct 2022 18:42:53 -0400
Subject: [PATCH 25/26] fix torch_geometric check

---
 src/diffusers/utils/import_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 217fb8d1a01c..7c651ec8b39e 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -217,6 +217,8 @@ def is_torch_scatter_available():
 def is_torch_geometric_available():
     # the model source of the Molecule Generation GNN requires a specific torch geometric version
     # for more info, see the original repo https://github.com/MinkaiXu/GeoDiff or our colab in readme
+    if not _torch_geometric_available:
+        return False
     return _torch_geometric_version == "1.7.2"


From 104ec262a4fed8554b663e9c9b5298a322ca98ec Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Mon, 3 Oct 2022 18:45:27 -0400
Subject: [PATCH 26/26] remove dummy transformers objects

---
 src/diffusers/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index d2d20f319149..4a39530add05 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -84,7 +84,6 @@
         from .pipelines import FlaxStableDiffusionPipeline
 else:
     from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
-    from .utils.dummy_transformers_objects import *  # noqa F403

 if is_torch_geometric_available():
     from .models import MoleculeGNN
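
The last three patches converge on one pattern: MoleculeGNN is imported only when torch_geometric is both installed and at the exact pinned version ("1.7.2", the version the GeoDiff-derived model source was written against), and dummy placeholder objects are exported otherwise so downstream imports never fail. The sketch below reconstructs that two-step guard in isolation. It is a minimal illustration under stated assumptions, not diffusers' actual implementation: the computation of _torch_geometric_available and _torch_geometric_version here merely stands in for the module-level bookkeeping import_utils.py does at import time; only the check order fixed by PATCH 25 and the "1.7.2" pin come directly from the diff.

# Minimal sketch of the guarded-import pattern from the patches above.
# Assumption: the flag/version computation below stands in for diffusers'
# internal bookkeeping in import_utils.py.
import importlib.util
from importlib.metadata import PackageNotFoundError, version

# Is the module importable at all?
_torch_geometric_available = importlib.util.find_spec("torch_geometric") is not None
try:
    _torch_geometric_version = version("torch-geometric")
except PackageNotFoundError:
    _torch_geometric_available = False
    _torch_geometric_version = None


def is_torch_geometric_available():
    # Check installation first (the PATCH 25 fix: without this, an absent
    # package would fall through to a comparison against a version that was
    # never resolved), then require the exact pinned version.
    if not _torch_geometric_available:
        return False
    return _torch_geometric_version == "1.7.2"


if is_torch_geometric_available():
    from diffusers import MoleculeGNN  # noqa: F401  # real model
else:
    from diffusers.utils.dummy_torch_geometric_objects import *  # noqa: F403  # placeholders

With this guard in place, "from diffusers import MoleculeGNN" resolves either to the real model or to a dummy stand-in, which is why tests/test_models_gnn.py can be imported and collected even in environments that lack the pinned backend.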