Commit f2a7242e authored by D.P.C.H Pathirana's avatar D.P.C.H Pathirana 😼

GNN Model

parent ede754db
# -*- coding: utf-8 -*-
"""Poision Edible Classification.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1r40VuKZJbBTeFVASnAhetlId5x-OjmBQ
"""
! pip install torch_geometric
! pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torch_geometric.data import Data
# Define the Graph Convolutional Network (GCN) layer
class GCNConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__()
self.linear = nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
x = self.linear(x)
return torch.matmul(x, edge_index)
# Define the Graph Neural Network model
class GNNModel(nn.Module):
def __init__(self, in_channels, hidden_channels, num_classes):
super(GNNModel, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = x.relu()
x = self.conv2(x, edge_index)
return x.log_softmax(dim=1)
# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
from torchvision.transforms import Resize, ToTensor
from torchvision import transforms
# Define the dataset and data loaders
# transform = nn.Sequential(Resize((224, 224)), ToTensor())
transform=transforms.Compose([Resize((64,3)),ToTensor()])
train_dataset = ImageFolder('/content/drive/MyDrive/Mushroom Classification Project/component poision edible /poisonousmushroomdetection/train', transform=transform)
val_dataset = ImageFolder('/content/drive/MyDrive/Mushroom Classification Project/component poision edible /poisonousmushroomdetection/validation', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)
train_dataset
val_dataset
# Define the model and optimizer
model = GNNModel(in_channels=3, hidden_channels=64, num_classes=len(train_dataset.classes)).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
import torch
tensor = torch.randn(64,2)
tensor = tensor.t()
tensor = tensor.mm(torch.ones( 64,3072))
print(tensor)
edge_index = torch.randn(64,3072)
# Training loop
model.train()
for epoch in range(10):
for data in train_loader:
x = data[0].to(device)
y = data[1].to(device)
edge_index = edge_index
graph_data = Data(x=x, edge_index=edge_index, y=y)
graph_data = graph_data.to(device)
optimizer.zero_grad()
output = model(graph_data)
loss = nn.functional.nll_loss(output, graph_data.y)
loss.backward()
optimizer.step()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment