| """Utilities for converting Graphein Networks to Geometric Deep Learning formats. |
| """ |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
|
|
| from typing import List, Optional |
|
|
| import networkx as nx |
| import numpy as np |
| import torch |
|
|
| try: |
| from graphein.utils.dependencies import import_message |
| except ImportError: |
| raise Exception('You need to install graphein from source in addition to DSSP to use this model please refer to https://github.com/a-r-j/graphein and https://ssbio.readthedocs.io/en/latest/instructions/dssp.html') |
|
|
| try: |
| import torch_geometric |
| from torch_geometric.data import Data |
| except ImportError: |
| import_message( |
| submodule="graphein.ml.conversion", |
| package="torch_geometric", |
| pip_install=True, |
| conda_channel="rusty1s", |
| ) |
|
|
| try: |
| import dgl |
| except ImportError: |
| import_message( |
| submodule="graphein.ml.conversion", |
| package="dgl", |
| pip_install=True, |
| conda_channel="dglteam", |
| ) |
|
|
| try: |
| import jax.numpy as jnp |
| except ImportError: |
| import_message( |
| submodule="graphein.ml.conversion", |
| package="jax", |
| pip_install=True, |
| conda_channel="conda-forge", |
| ) |
| try: |
| import jraph |
| except ImportError: |
| import_message( |
| submodule="graphein.ml.conversion", |
| package="jraph", |
| pip_install=True, |
| conda_channel="conda-forge", |
| ) |
|
|
|
|
SUPPORTED_FORMATS = [
    "nx",
    "pyg",
    "dgl",
    "jraph",
]
"""Identifiers of the graph formats the convertor accepts.

``"nx"``: NetworkX graph

``"pyg"``: PyTorch Geometric Data object

``"dgl"``: DGL graph

``"jraph"``: Jraph GraphsTuple
"""


SUPPORTED_VERBOSITY = [
    "gnn",
    "default",
    "all_info",
]
"""Verbosity levels controlling which graph features are kept during conversion."""
|
|
|
|
class GraphFormatConvertor:
    """
    Provides conversion utilities between NetworkX Graphs and geometric deep learning library destination formats.
    Currently, we provide support for conversion from ``nx.Graph`` to ``dgl.DGLGraph``,
    ``torch_geometric.data.Data`` and ``jraph.GraphsTuple``, and from ``dgl.DGLGraph`` and
    ``torch_geometric.data.Data`` back to ``nx.Graph``. Supported conversion
    formats can be retrieved from :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`.

    :param src_format: The type of graph you'd like to convert from. Supported formats are available in :const:`~graphein.ml.conversion.SUPPORTED_FORMATS`
    :type src_format: Literal["nx", "pyg", "dgl", "jraph"]
    :param dst_format: The type of graph format you'd like to convert to. Supported formats are available in:
        ``graphein.ml.conversion.SUPPORTED_FORMATS``
    :type dst_format: Literal["nx", "pyg", "dgl", "jraph"]
    :param verbose: Select from ``"gnn"``, ``"default"``, ``"all_info"`` to determine how much information is preserved (features)
        as some are unsupported by various downstream frameworks. Ignored when ``columns`` is given.
    :type verbose: str
    :param columns: List of columns in the node features to retain
    :type columns: List[str], optional
    :raises ValueError: If either format is not in ``SUPPORTED_FORMATS``, or if
        ``columns`` is ``None`` and ``verbose`` is not in ``SUPPORTED_VERBOSITY``.
    """

    def __init__(
        self,
        src_format: str,
        dst_format: str,
        verbose: str = "gnn",
        columns: Optional[List[str]] = None,
    ):
        if (src_format not in SUPPORTED_FORMATS) or (
            dst_format not in SUPPORTED_FORMATS
        ):
            raise ValueError(
                "Please specify from supported format, "
                + "/".join(SUPPORTED_FORMATS)
            )
        self.src_format = src_format
        self.dst_format = dst_format

        # The verbosity level only matters when no explicit column list is given.
        if (columns is None) and (verbose not in SUPPORTED_VERBOSITY):
            raise ValueError(
                "Please specify the supported verbose mode ("
                + "/".join(SUPPORTED_VERBOSITY)
                + ") or specify column names!"
            )

        if columns is None:
            if verbose == "gnn":
                columns = [
                    "edge_index",
                    "coords",
                    "dist_mat",
                    "name",
                    "node_id",
                ]
            elif verbose == "default":
                columns = [
                    "b_factor",
                    "chain_id",
                    "coords",
                    "dist_mat",
                    "edge_index",
                    "kind",
                    "name",
                    "node_id",
                    "residue_name",
                ]
            elif verbose == "all_info":
                columns = [
                    "atom_type",
                    "b_factor",
                    "chain_id",
                    "chain_ids",
                    "config",
                    "coords",
                    "dist_mat",
                    "edge_index",
                    "element_symbol",
                    "kind",
                    "name",
                    "node_id",
                    "node_type",
                    "pdb_df",
                    "raw_pdb_df",
                    "residue_name",
                    "residue_number",
                    "rgroup_df",
                    "sequence_A",
                    "sequence_B",
                ]
        self.columns = columns

        # Declared value kind per feature name; used by the DGL conversion to
        # decide which features can be turned into tensors.
        self.type2form = {
            "atom_type": "str",
            "b_factor": "float",
            "chain_id": "str",
            "coords": "np.array",
            "dist_mat": "np.array",
            "element_symbol": "str",
            "node_id": "str",
            "residue_name": "str",
            "residue_number": "int",
            "edge_index": "torch.tensor",
            "kind": "str",
        }

    @staticmethod
    def _as_list(value):
        """Return ``list(value)``, or ``[value]`` when ``value`` is not iterable."""
        try:
            return list(value)
        except TypeError:
            return [value]

    @staticmethod
    def _tensor_or_list(values):
        """Return ``torch.tensor(values)``, or ``values`` unchanged if it cannot be converted."""
        try:
            return torch.tensor(values)
        except (TypeError, ValueError, RuntimeError):
            # e.g. string or ragged features cannot be represented as a tensor.
            return values

    def _collect_node_features(self, G: nx.Graph) -> dict:
        """Gather the retained node attributes of ``G`` as ``{name: [value per node]}``."""
        feats = {}
        for _, feat_dict in G.nodes(data=True):
            for key, value in feat_dict.items():
                name = str(key)
                if name in self.columns:
                    feats.setdefault(name, []).append(value)
        return feats

    def _collect_edge_features(self, G: nx.Graph) -> dict:
        """Gather the retained edge attributes of ``G`` as ``{name: flat list}``.

        Iterable values are flattened into the running list; non-iterable
        values are appended as a single element.
        """
        feats = {}
        for _, _, feat_dict in G.edges(data=True):
            for key, value in feat_dict.items():
                name = str(key)
                if name in self.columns:
                    feats.setdefault(name, []).extend(self._as_list(value))
        return feats

    def _collect_graph_features(self, G: nx.Graph) -> dict:
        """Gather the retained graph-level attributes of ``G`` as ``{name: [value]}``."""
        return {
            str(feat_name): [G.graph[feat_name]]
            for feat_name in G.graph
            if str(feat_name) in self.columns
        }

    def convert_nx_to_dgl(self, G: nx.Graph) -> dgl.DGLGraph:
        """
        Converts ``NetworkX`` graph to ``DGL``

        Only features that can be expressed as tensors are attached: ``coords``,
        ``dist_mat`` and features declared as ``"float"`` or ``"int"`` in
        ``self.type2form``. String features, features with no declared type and
        graph-level features are not attached to the returned graph.

        :param G: ``nx.Graph`` to convert to ``DGLGraph``
        :type G: nx.Graph
        :return: ``DGLGraph`` object version of input ``NetworkX`` graph
        :rtype: dgl.DGLGraph
        """
        g = dgl.DGLGraph()
        G = nx.convert_node_labels_to_integers(G)

        # Node features.
        node_data = {}
        for name, values in self._collect_node_features(G).items():
            if name == "coords":
                node_data[name] = torch.Tensor(np.asarray(values)).type(
                    "torch.FloatTensor"
                )
            elif name == "dist_mat":
                # Only the first entry is used; ``.values`` implies a
                # DataFrame-like object is stored here.
                node_data[name] = torch.Tensor(
                    np.asarray(values[0].values)
                ).type("torch.FloatTensor")
            elif self.type2form.get(name) in ("float", "int"):
                node_data[name] = torch.Tensor(np.array(values))
        g.add_nodes(G.number_of_nodes(), node_data)

        # Edges and edge features.
        edge_index = torch.LongTensor(list(G.edges)).t().contiguous()
        if edge_index.numel() == 0:
            # An edgeless graph yields a 1-D empty tensor; reshape it so that
            # the source/destination rows can still be indexed.
            edge_index = edge_index.view(2, 0)

        edge_data = {}
        for name, values in self._collect_edge_features(G).items():
            if self.type2form.get(name) in ("float", "int"):
                edge_data[name] = torch.Tensor(np.array(values))
        g.add_edges(edge_index[0], edge_index[1], edge_data)

        return g

    def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
        """
        Converts ``NetworkX`` graph to ``torch_geometric.data.Data`` object. Requires ``PyTorch Geometric`` (https://pytorch-geometric.readthedocs.io/en/latest/) to be installed.

        The original node labels are always stored under ``node_id``. Node, edge
        and graph-level attributes are kept only if their name is in
        ``self.columns``; on a name clash later groups (edge, then graph) replace
        earlier ones. ``edge_index`` is added only if it is in ``self.columns``.

        :param G: ``nx.Graph`` to convert to PyTorch Geometric ``Data`` object
        :type G: nx.Graph
        :return: ``Data`` object containing networkx graph data
        :rtype: torch_geometric.data.Data
        """
        # Remember the original labels before they are replaced by integers.
        data = {"node_id": list(G.nodes())}
        G = nx.convert_node_labels_to_integers(G)

        edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

        data.update(self._collect_node_features(G))
        data.update(self._collect_edge_features(G))
        data.update(self._collect_graph_features(G))

        if "edge_index" in self.columns:
            data["edge_index"] = edge_index.view(2, -1)

        data = Data.from_dict(data)
        data.num_nodes = G.number_of_nodes()
        return data

    @staticmethod
    def convert_nx_to_nx(G: nx.Graph) -> nx.Graph:
        """
        Converts NetworkX graph (``nx.Graph``) to NetworkX graph (``nx.Graph``) object. Redundant - returns itself.

        :param G: NetworkX Graph
        :type G: nx.Graph
        :return: NetworkX Graph
        :rtype: nx.Graph
        """
        return G

    @staticmethod
    def convert_dgl_to_nx(G: dgl.DGLGraph) -> nx.Graph:
        """
        Converts a DGL Graph (``dgl.DGLGraph``) to a NetworkX (``nx.Graph``) object. Preserves node and edge attributes.

        :param G: ``dgl.DGLGraph`` to convert to ``NetworkX`` graph.
        :type G: dgl.DGLGraph
        :return: NetworkX graph object.
        :rtype: nx.Graph
        """
        node_attrs = G.node_attr_schemes().keys()
        edge_attrs = G.edge_attr_schemes().keys()
        return dgl.to_networkx(G, node_attrs, edge_attrs)

    @staticmethod
    def convert_pyg_to_nx(G: Data) -> nx.Graph:
        """Converts PyTorch Geometric ``Data`` object to NetworkX graph (``nx.Graph``).

        :param G: Pytorch Geometric Data.
        :type G: torch_geometric.data.Data
        :returns: NetworkX graph.
        :rtype: nx.Graph
        """
        return torch_geometric.utils.to_networkx(G)

    def convert_nx_to_jraph(self, G: nx.Graph) -> jraph.GraphsTuple:
        """Converts NetworkX graph (``nx.Graph``) to Jraph GraphsTuple graph. Requires ``jax`` and ``Jraph``.

        Node features are converted to ``torch`` tensors where possible and
        kept as plain lists otherwise. Edge features are flat lists and
        graph-level features are single-element lists. Only attributes whose
        name is in ``self.columns`` are kept.

        :param G: Networkx graph to convert.
        :type G: nx.Graph
        :return: Jraph GraphsTuple graph.
        :rtype: jraph.GraphsTuple
        """
        G = nx.convert_node_labels_to_integers(G)

        n_node = len(G)
        n_edge = G.number_of_edges()
        edge_list = list(G.edges())
        # ``zip(*[])`` yields nothing to unpack, so handle edgeless graphs explicitly.
        if edge_list:
            senders, receivers = zip(*edge_list)
        else:
            senders, receivers = (), ()
        senders = jnp.array(senders, dtype=int)
        receivers = jnp.array(receivers, dtype=int)

        # Accumulate plain lists first and convert once per feature: converting
        # inside the loop would make the next ``tensor + [value]`` step fail.
        node_features = {
            name: self._tensor_or_list(values)
            for name, values in self._collect_node_features(G).items()
        }
        edge_features = self._collect_edge_features(G)
        global_context = self._collect_graph_features(G)

        return jraph.GraphsTuple(
            nodes=node_features,
            senders=senders,
            receivers=receivers,
            edges=edge_features,
            n_node=n_node,
            n_edge=n_edge,
            globals=global_context,
        )

    def __call__(self, G: nx.Graph):
        """Convert ``G`` from ``self.src_format`` to ``self.dst_format`` via NetworkX.

        :param G: Graph in the source format.
        :return: Graph in the destination format.
        :raises AttributeError: If no ``convert_<src_format>_to_nx`` or
            ``convert_nx_to_<dst_format>`` method exists on this class (there is
            no ``convert_jraph_to_nx``, so ``"jraph"`` cannot be a source).
        """
        to_nx = getattr(self, f"convert_{self.src_format}_to_nx")
        from_nx = getattr(self, f"convert_nx_to_{self.dst_format}")
        return from_nx(to_nx(G))
|
|
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
def convert_nx_to_pyg_data(G: nx.Graph) -> Data:
    """Converts a NetworkX graph to a PyTorch Geometric ``Data`` object, keeping every attribute.

    The original node labels are stored under ``node_id``. Each node attribute
    becomes a list with one entry per node that carries it. Each edge
    attribute becomes a list with one entry per edge: ``distance`` values are
    stored as-is, other values are converted with ``list(value)`` (or stored
    as-is if they are not iterable). Graph-level attributes are wrapped in a
    single-element list. On a name clash later groups (edge, then graph)
    replace earlier ones. ``edge_index`` is always added.

    :param G: ``nx.Graph`` to convert.
    :type G: nx.Graph
    :return: ``Data`` object containing the graph data.
    :rtype: torch_geometric.data.Data
    """
    # Remember the original labels before they are replaced by integers.
    data = {"node_id": list(G.nodes())}
    G = nx.convert_node_labels_to_integers(G)

    edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

    # Node features: appended per node, so an attribute that is missing on the
    # first node no longer breaks the accumulation.
    node_feats = {}
    for _, feat_dict in G.nodes(data=True):
        for key, value in feat_dict.items():
            node_feats.setdefault(str(key), []).append(value)
    data.update(node_feats)

    # Edge features.
    edge_feats = {}
    for _, _, feat_dict in G.edges(data=True):
        for key, value in feat_dict.items():
            if key == 'distance':
                item = value
            else:
                try:
                    item = list(value)
                except TypeError:
                    # Non-iterable attribute: keep the raw value.
                    item = value
            edge_feats.setdefault(str(key), []).append(item)
    data.update(edge_feats)

    # Graph-level features.
    for feat_name in G.graph:
        data[str(feat_name)] = [G.graph[feat_name]]

    data["edge_index"] = edge_index.view(2, -1)
    data = Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data
|
|