Commit bd066a99 authored by nazeerxexagen's avatar nazeerxexagen

feature extraction model phase 2

parent 1d64ea33
@@ -630,6 +630,224 @@
" plt.legend()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# **SPATIAL FEATURE EXTRACTION MODEL(PHASE 02)** "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"def set_parameter_requires_grad(model):\n",
" for param in model.parameters():\n",
" param.requires_grad = False\n",
" \n",
"def feature_extractor():\n",
" feature_extraction_layers = []\n",
"\n",
" for layer in model[:-3]:\n",
" feature_extraction_layers.append(layer)\n",
"\n",
" feature_extraction_model = nn.Sequential(*feature_extraction_layers) \n",
"\n",
" set_parameter_requires_grad(feature_extraction_model)\n",
" \n",
" return feature_extraction_model"
]
},
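{
"cell_type": "markdown",
"metadata": {},
"source": [
"`save_model` and `load_model` are defined in Phase 1 of this notebook and are used in the next cell. The sketch below shows the standard `state_dict` pattern they are assumed to follow; the names carry a `_sketch` suffix so they do not shadow the real helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Assumed behaviour of the Phase-1 helpers, shown for reference only;\n",
"# the real save_model/load_model defined earlier take precedence\n",
"def save_model_sketch(model, path):\n",
"    torch.save(model.state_dict(), path)\n",
"\n",
"def load_model_sketch(model, path):\n",
"    model.load_state_dict(torch.load(path, map_location=device))\n",
"    model.eval()\n",
"    return model"
]
},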
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"feature_extraction_model = feature_extractor()\n",
"if not os.path.exists(FEATURE_MODEL_WEIGHT_PATH):\n",
" if not os.path.exists(WEIGHT_PATH):\n",
" train_loop()\n",
" save_model(model, WEIGHT_PATH)\n",
" visualize_performance()\n",
" else:\n",
" load_model(model, WEIGHT_PATH)\n",
" save_model(feature_extraction_model, FEATURE_MODEL_WEIGHT_PATH)\n",
"else:\n",
" load_model(model, WEIGHT_PATH)\n",
" load_model(feature_extraction_model, FEATURE_MODEL_WEIGHT_PATH)"
]
},
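{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the extractor (not part of the original pipeline): it should expose no trainable parameters and emit the 4096-d vectors assumed by the extraction code below. The 224x224 input size is an assumption about `image_transform`'s output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Hedged sanity check: no trainable parameters, 4096-d output.\n",
"# The 224x224 input size is an assumption, not taken from this commit.\n",
"n_trainable = sum(p.numel() for p in feature_extraction_model.parameters() if p.requires_grad)\n",
"print('trainable params:', n_trainable)  # expected: 0\n",
"\n",
"with torch.no_grad():\n",
"    dummy = torch.zeros(1, 3, 224, 224, device=device)\n",
"    print(feature_extraction_model(dummy).shape)  # expected: torch.Size([1, 4096])"
]
},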
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Inference\n",
"\n",
"image_paths = np.array(dataset.Images)\n",
"labels = np.array(dataset.Labels)\n",
"\n",
"def predict_on_single_image(image_path):\n",
" img = Image.open(image_path).convert('RGB')\n",
" img = np.asarray(img)\n",
" img = torch.unsqueeze(image_transform(img), 0).to(device)\n",
" features = feature_extraction_model(img).squeeze()\n",
" return features\n",
"\n",
"def extract_features_all():\n",
" if not os.path.exists(FEATURE_PATH):\n",
" ALL_FEATURES = torch.empty(\n",
" (len(labels),4096), \n",
" dtype=torch.float32, \n",
" device = device\n",
" )\n",
"\n",
" for idx, image_path in enumerate(image_paths):\n",
" features = predict_on_single_image(image_path)\n",
" ALL_FEATURES[idx, :] = features\n",
" \n",
" torch.save(ALL_FEATURES, FEATURE_PATH)\n",
" \n",
" else:\n",
" ALL_FEATURES = torch.load(FEATURE_PATH)\n",
"\n",
" return ALL_FEATURES"
]
},
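{
"cell_type": "markdown",
"metadata": {},
"source": [
"`extract_features_all` embeds one image per forward pass. As an optional speed-up, the same features can be computed in batches; the cell below is a sketch rather than the notebook's method, and it assumes `image_transform` returns fixed-size tensors so the default collate can stack them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Optional batched variant (a sketch, not the notebook's method)\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"class ImagePathDataset(Dataset):\n",
"    def __init__(self, paths):\n",
"        self.paths = paths\n",
"\n",
"    def __len__(self):\n",
"        return len(self.paths)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        img = Image.open(self.paths[idx]).convert('RGB')\n",
"        return image_transform(np.asarray(img))\n",
"\n",
"def extract_features_batched(batch_size=64):\n",
"    # Same result as the per-image loop above, computed batch by batch\n",
"    loader = DataLoader(ImagePathDataset(image_paths), batch_size=batch_size)\n",
"    chunks = []\n",
"    with torch.no_grad():\n",
"        for batch in loader:\n",
"            chunks.append(feature_extraction_model(batch.to(device)))\n",
"    return torch.cat(chunks, dim=0)"
]
},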
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"ALL_FEATURES = extract_features_all()"
]
},
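{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged sanity check on the cached feature matrix: one 4096-d row per dataset image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Expected: torch.Size([len(labels), 4096])\n",
"print(ALL_FEATURES.shape)\n",
"assert ALL_FEATURES.shape == (len(labels), 4096)"
]
},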
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"def pairwise_cosine_similarity(V1, V2):\n",
" return F.cosine_similarity(\n",
" torch.unsqueeze(V1, 0),\n",
" torch.unsqueeze(V2, 0)\n",
" ).squeeze().item()\n",
"\n",
"# def sample_inference_visualization():\n",
"# test_img_path = np.random.choice(image_paths)\n",
"# test_feature = predict_on_single_image(test_img_path)\n",
" \n",
"# V1 = test_feature\n",
"# cosine_similarities = np.array([pairwise_cosine_similarity(V1, V2) for V2 in ALL_FEATURES])\n",
"# similarity_order = np.argsort(cosine_similarities)[::-1]\n",
" \n",
"# similar_images = image_paths[similarity_order][1:9]\n",
"# similar_images = np.insert(similar_images, 0, test_img_path)\n",
"\n",
"# w = 10\n",
"# h = 10\n",
"# fig = plt.figure(figsize=(12, 12))\n",
"# columns = 3\n",
"# rows = 3\n",
"\n",
"# for i in range(1, columns*rows + 1):\n",
"# img = Image.open(similar_images[i-1]).convert('RGB')\n",
"# img = np.asarray(img)\n",
"# fig.add_subplot(rows, columns, i)\n",
"# plt.imshow(img)\n",
"\n",
"# plt.show()"
]
},
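{
"cell_type": "markdown",
"metadata": {},
"source": [
"The ranking step below loops `pairwise_cosine_similarity` over every row of `ALL_FEATURES` in Python. A minimal vectorized equivalent, shown as a sketch, computes all scores in one broadcasted call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Vectorized equivalent of looping pairwise_cosine_similarity over\n",
"# ALL_FEATURES: one broadcasted call returning a length-N score vector\n",
"def all_cosine_similarities(V1, all_features):\n",
"    return F.cosine_similarity(V1.unsqueeze(0), all_features, dim=1)\n",
"\n",
"# Usage sketch:\n",
"# sims = all_cosine_similarities(test_feature, ALL_FEATURES).cpu().numpy()\n",
"# order = np.argsort(sims)[::-1]"
]
},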
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# for _ in range(10):\n",
"# sample_inference_visualization()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"def mate_inference(test_img_path):\n",
" test_feature = predict_on_single_image(test_img_path)\n",
" \n",
" V1 = test_feature\n",
" cosine_similarities = np.array([pairwise_cosine_similarity(V1, V2) for V2 in ALL_FEATURES])\n",
" similarity_order = np.argsort(cosine_similarities)[::-1]\n",
" \n",
" similar_images = image_paths[similarity_order][1:9]\n",
" similar_images = np.insert(similar_images, 0, test_img_path)\n",
"\n",
" w = 10\n",
" h = 10\n",
" fig = plt.figure(figsize=(12, 12))\n",
" columns = 3\n",
" rows = 3\n",
"\n",
" for i in range(1, columns*rows + 1):\n",
" img = Image.open(similar_images[i-1]).convert('RGB')\n",
" img = np.asarray(img)\n",
" fig.add_subplot(rows, columns, i)\n",
" plt.imshow(img)\n",
"\n",
" plt.show()"
]
},
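{
"cell_type": "markdown",
"metadata": {},
"source": [
"A non-plotting companion to `mate_inference` (a sketch, not part of the original commit): it returns the top-k neighbour paths with their scores, which is handier for programmatic use. `k=8` mirrors the eight neighbours shown in the grid."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Sketch: return the top-k neighbour paths and scores instead of plotting.\n",
"# Rank 0 is the query's self-match when the query image is in the dataset.\n",
"def mate_topk(test_img_path, k=8):\n",
"    V1 = predict_on_single_image(test_img_path)\n",
"    sims = np.array([pairwise_cosine_similarity(V1, V2) for V2 in ALL_FEATURES])\n",
"    order = np.argsort(sims)[::-1][:k]\n",
"    return list(zip(image_paths[order], sims[order]))"
]
},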
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"test_img_path = ''\n",
"mate_inference(test_img_path)"
]
}
],
"metadata": {
......