Commit 5d365cbf authored by Lelkada L L P S M's avatar Lelkada L L P S M

retrieve cluster data

parent 7d7dac45
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c3d5322b-ce1b-482d-b224-135da10982c1",
"metadata": {},
"outputs": [],
"source": [
"!pip install transformers\n",
"!pip install pymongo"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5aeb5561-495d-4225-881f-1d3ee830b27c",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"import numpy as np\n",
"import pymongo\n",
"from scipy.spatial import distance"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f0f0fd1f-7f30-4c0e-86ff-47978850a3a4",
"metadata": {},
"outputs": [],
"source": [
def get_centroids(db_name="word_embedding"):
    """Fetch the per-cluster centroid vectors stored in MongoDB.

    Parameters
    ----------
    db_name : str
        Name of the database holding the ``centroids`` collection.
        Bug fix: the parameter was previously ignored in favour of a
        hard-coded ``'word_embedding'``; the default is now set to that
        same database so default calls behave exactly as before.

    Returns
    -------
    dict
        Maps ``cluster_no`` -> ``{"cluster_no": ..., "centroid": np.ndarray}``.
    """
    import os

    # SECURITY: credentials must not live in source control. Prefer the
    # MONGODB_URI environment variable; the hard-coded fallback is kept only
    # for backward compatibility and should be removed once the env var is
    # configured (and the leaked password rotated).
    uri = os.environ.get(
        "MONGODB_URI",
        "mongodb+srv://hearme:hearme678@cluster0.kz66vdr.mongodb.net",
    )
    client = pymongo.MongoClient(uri)
    try:
        db = client[db_name]
        centroids_collection = db["centroids"]

        # Convert each stored centroid to a NumPy array, keyed by cluster id.
        centroids = {}
        for doc in centroids_collection.find({}):
            cluster_no = doc["cluster_no"]
            centroids[cluster_no] = {
                "cluster_no": cluster_no,
                "centroid": np.array(doc["centroid"]),
            }
    finally:
        # Close the connection even if the query raises (the original leaked
        # the client on any error between connect and close).
        client.close()

    return centroids
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "44b8acee-7a5e-40cb-9278-b9bda61dbeef",
"metadata": {},
"outputs": [],
"source": [
def generate_input_embedding(word):
    """Compute a single fixed-size embedding vector for `word` using
    bert-base-uncased: the sub-word token embeddings from the last hidden
    layer are averaged into one vector.

    Returns
    -------
    np.ndarray
        A 1-D vector (hidden size 768 for bert-base-uncased).
    """
    # Cache tokenizer/model on the function object so repeated calls do not
    # reload the pretrained weights from disk every time.
    if not hasattr(generate_input_embedding, "_cache"):
        generate_input_embedding._cache = (
            AutoTokenizer.from_pretrained("bert-base-uncased"),
            AutoModel.from_pretrained("bert-base-uncased"),
        )
    tokenizer, model = generate_input_embedding._cache

    # Bug fix: the original code tokenized the word and then fed the token
    # *list* back into the tokenizer, which treats it as a batch of separate
    # texts. For multi-subword words that yielded a (n_subwords, seq, dim)
    # tensor, squeeze(0) was a no-op, and the "embedding" came out 2-D.
    # Encoding the word once gives the intended (1, seq_len, dim) output.
    inputs = tokenizer(word, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # (1, seq_len, hidden) -> (seq_len, hidden)
    token_embeddings = outputs.last_hidden_state.squeeze(0).numpy()

    # Drop the [CLS]/[SEP] special tokens so only real sub-words are averaged.
    word_embedding = np.mean(token_embeddings[1:-1], axis=0)

    return word_embedding
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "1735effa-2020-48d9-889a-39c9d954aa3a",
"metadata": {},
"outputs": [],
"source": [
def find_closest_cluster(input_embedding, centroids):
    """Return the cluster number whose centroid is nearest (Euclidean
    distance) to `input_embedding`.

    Parameters
    ----------
    input_embedding : array-like
        The query embedding vector.
    centroids : dict
        Maps cluster_no -> {"centroid": array-like, ...} as produced by
        ``get_centroids``.

    Returns
    -------
    The closest ``cluster_no``, or None when `centroids` is empty.
    """
    closest_cluster_no = None
    min_distance = float("inf")

    for cluster_no, centroid_data in centroids.items():
        # Leftover debug print of every distance removed — it flooded the
        # cell output on each call.
        current_distance = distance.euclidean(
            input_embedding, centroid_data["centroid"]
        )
        if current_distance < min_distance:
            min_distance = current_distance
            closest_cluster_no = cluster_no

    return closest_cluster_no
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9edf418b-279e-4c26-b4ac-a8d474f61668",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f54ccf87-702c-4e73-834a-9ec489f63084",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 26,
"id": "cf2e79ca-461b-4368-98cc-26d9e55488fd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"13.285225168586905\n",
"11.51259043024684\n",
"8.962974422826848\n",
"12.52214870605026\n",
"10.757946725715826\n",
"The closest cluster for the word 'ocean' is cluster 0\n"
]
}
],
"source": [
# Example usage: embed a query word, then report which stored cluster of
# word embeddings it falls closest to.
query_word = "ocean"
query_embedding = generate_input_embedding(query_word)
cluster_centroids = get_centroids()
closest_cluster = find_closest_cluster(query_embedding, cluster_centroids)
print(f"The closest cluster for the word '{query_word}' is cluster {closest_cluster}")
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65a19259-dccf-4681-988c-3f292dfd1182",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Pytorch (Local)",
"language": "python",
"name": "local-pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
......@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"id": "a979a3a5-aa3e-459b-a5b0-553e12d291db",
"metadata": {},
"outputs": [],
......@@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": null,
"id": "be38074e-f1a7-406e-bab2-a52487151fab",
"metadata": {},
"outputs": [],
......@@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": null,
"id": "4114422b-cb8c-4428-a755-c46995082027",
"metadata": {},
"outputs": [],
......@@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": null,
"id": "088c2ed1-bd44-423a-9107-b112d965a41d",
"metadata": {},
"outputs": [],
......@@ -82,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": null,
"id": "23eaa3b8-5c1c-4eae-9fe5-224cc4e8f2a0",
"metadata": {},
"outputs": [],
......@@ -123,7 +123,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": null,
"id": "7db45f3c-7c3a-4798-8629-894814e7878d",
"metadata": {},
"outputs": [],
......@@ -162,7 +162,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": null,
"id": "075bfc5a-ffa3-4264-8bef-db407cb29859",
"metadata": {},
"outputs": [],
......@@ -193,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": null,
"id": "f5d27eeb-76ca-40f8-bc76-21c91336fbef",
"metadata": {},
"outputs": [],
......@@ -213,7 +213,7 @@
},
{
"cell_type": "code",
"execution_count": 66,
"execution_count": null,
"id": "b6e51e97-bc00-44fa-aeb2-88bb2faa83f3",
"metadata": {},
"outputs": [
......@@ -235,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": null,
"id": "6425ba1e-5156-4d8a-ad2a-f1758d39e30a",
"metadata": {},
"outputs": [],
......@@ -247,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": null,
"id": "80ea8a95-371c-48d8-a50b-ff4871b81eed",
"metadata": {},
"outputs": [],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment