Commit 5894542d authored by Lelkada L L P S M

unique word generation

parent b868bcab
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 18,
"id": "a979a3a5-aa3e-459b-a5b0-553e12d291db",
"metadata": {},
"outputs": [],
@@ -64,7 +64,8 @@
"import pymongo\n",
"from sklearn.cluster import KMeans\n",
"from tqdm import tqdm\n",
"import re"
"import re\n",
"from collections import Counter\n"
]
},
{
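The hunk above adds only the Counter import; how it is used is not visible in the shown diff. Below is a hedged sketch of one plausible use, consistent with the commit title "unique word generation": counting word occurrences across sentences to derive a unique-word list. The function name and the tokenization regex are illustrative assumptions, not the committed code.

# Hypothetical sketch: only the Counter import is visible in the diff,
# so this usage is an assumption, not the committed code.
import re
from collections import Counter

def unique_word_counts(sentences):
    # Lowercase each sentence, split on word characters, count occurrences.
    counts = Counter()
    for sentence in sentences:
        counts.update(re.findall(r"\w+", sentence.lower()))
    return counts

counts = unique_word_counts(["the cat sat", "the dog sat"])
print(list(counts))           # unique words: ['the', 'cat', 'sat', 'dog']
print(counts.most_common(2))  # [('the', 2), ('sat', 2)]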
@@ -118,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 24,
"id": "23eaa3b8-5c1c-4eae-9fe5-224cc4e8f2a0",
"metadata": {},
"outputs": [],
@@ -136,9 +137,8 @@
" subword_tokens = []\n",
" subword_to_word_indices = []\n",
" for idx, word in enumerate(words):\n",
" tokenized_word = tokenizer.tokenize(word)\n",
" subword_tokens.extend(tokenized_word)\n",
" subword_to_word_indices.extend([idx] * len(tokenized_word))\n",
" subword_tokens.extend(tokenizer.tokenize(word))\n",
" subword_to_word_indices.extend([idx] * len(tokenizer.tokenize(word)))\n",
"\n",
" # Encode the subword tokens and obtain the subword embeddings\n",
" inputs = tokenizer(subword_tokens, return_tensors=\"pt\", padding=True, truncation=True)\n",
@@ -146,21 +146,22 @@
" outputs = model(**inputs)\n",
" subword_embeddings = outputs.last_hidden_state.squeeze(0).numpy()\n",
"\n",
" # Average the subword embeddings for each word (only considering the first subword token)\n",
" # Average the subword embeddings for each word\n",
" word_embeddings = []\n",
" for idx in range(len(words)):\n",
" word_subword_embeddings = subword_embeddings[np.where(np.array(subword_to_word_indices) == idx)]\n",
" word_embedding = word_subword_embeddings[0]\n",
" word_embedding = np.mean(word_subword_embeddings, axis=0)\n",
" word_embeddings.append({\"word\": words[idx], \"embedding\": word_embedding})\n",
"\n",
" all_word_embeddings.append(word_embeddings)\n",
"\n",
" return all_word_embeddings\n"
" return all_word_embeddings\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 21,
"id": "7db45f3c-7c3a-4798-8629-894814e7878d",
"metadata": {},
"outputs": [],
@@ -174,28 +175,21 @@
" words.append(word_embedding[\"word\"])\n",
" embeddings.append(word_embedding[\"embedding\"])\n",
"\n",
" # Average embeddings for the same word across different sentences\n",
" word_to_embeddings = {}\n",
" for word, embedding in zip(words, embeddings):\n",
" if word not in word_to_embeddings:\n",
" word_to_embeddings[word] = {\"embedding\": embedding, \"count\": 1}\n",
" else:\n",
" word_to_embeddings[word][\"embedding\"] = np.add(word_to_embeddings[word][\"embedding\"], embedding)\n",
" word_to_embeddings[word][\"count\"] += 1\n",
" # Determine the average shape\n",
" avg_shape = tuple(map(int, np.mean([emb.shape for emb in embeddings], axis=0)))\n",
"\n",
" # Resize inconsistent embeddings to match the average shape\n",
" embeddings = [np.resize(emb, avg_shape) if emb.shape != avg_shape else emb for emb in embeddings]\n",
"\n",
" averaged_embeddings = []\n",
" unique_words = []\n",
" for word, embedding_data in word_to_embeddings.items():\n",
" averaged_embedding = embedding_data[\"embedding\"] / embedding_data[\"count\"]\n",
" unique_words.append(word)\n",
" averaged_embeddings.append(averaged_embedding)\n",
" # Reshape the embeddings array to be 2D\n",
" embeddings = np.vstack(embeddings)\n",
"\n",
" # Perform k-means clustering\n",
" kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(averaged_embeddings)\n",
" kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(embeddings)\n",
"\n",
" # Group words and embeddings by cluster\n",
" clusters = {i: {\"words\": [], \"embeddings\": []} for i in range(n_clusters)}\n",
" for word, embedding, label in zip(unique_words, averaged_embeddings, kmeans.labels_):\n",
" for word, embedding, label in zip(words, embeddings, kmeans.labels_):\n",
" clusters[label][\"words\"].append(word)\n",
" clusters[label][\"embeddings\"].append(embedding)\n",
"\n",
@@ -207,7 +201,8 @@
" # Store centroids for each cluster\n",
" centroids = [{\"cluster_no\": i, \"centroid\": centroid} for i, centroid in enumerate(kmeans.cluster_centers_)]\n",
"\n",
" return clusters, centroids\n"
" return clusters, centroids\n",
"\n"
]
},
{
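The hunks above replace the per-unique-word averaging with a shape-normalization step: embeddings whose shape deviates from the average shape are passed through np.resize, everything is stacked into a 2-D array with np.vstack, and k-means runs on the stacked rows. Note that np.resize tiles or truncates values when shapes differ, which silently distorts the affected vectors; once the upstream function mean-pools every word into a single fixed-length vector, all shapes already agree and np.vstack alone suffices. A minimal sketch under that assumption (function and variable names are illustrative):

# Sketch of the post-change clustering step, assuming every embedding is
# already a fixed-length vector, so no np.resize step is needed.
import numpy as np
from sklearn.cluster import KMeans

def cluster_word_embeddings(words, embeddings, n_clusters=3):
    X = np.vstack(embeddings)  # (n_words, dim) 2-D array, as KMeans expects
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init=10).fit(X)
    # Group words and embeddings by cluster label, as in the notebook.
    clusters = {i: {"words": [], "embeddings": []} for i in range(n_clusters)}
    for word, emb, label in zip(words, X, kmeans.labels_):
        clusters[label]["words"].append(word)
        clusters[label]["embeddings"].append(emb)
    centroids = [{"cluster_no": i, "centroid": c}
                 for i, c in enumerate(kmeans.cluster_centers_)]
    return clusters, centroids

rng = np.random.default_rng(0)
clusters, centroids = cluster_word_embeddings(
    ["cat", "dog", "car"], [rng.normal(size=768) for _ in range(3)])
print(sorted(len(c["words"]) for c in clusters.values()))  # [1, 1, 1]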
@@ -253,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 25,
"id": "b6e51e97-bc00-44fa-aeb2-88bb2faa83f3",
"metadata": {},
"outputs": [
@@ -264,7 +259,19 @@
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Processing sentences: 44%|████▍ | 407/928 [00:37<00:37, 13.94sentence/s]"
"Processing sentences: 0%| | 3/928 [00:00<01:35, 9.70sentence/s]\n"
]
},
+{
+"ename": "ValueError",
+"evalue": "operands could not be broadcast together with shapes (3,768) (5,768) (3,768) ",
+"output_type": "error",
+"traceback": [
+"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
+"\u001b[0;32m/tmp/ipykernel_1/1737146920.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Example usage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mword_embeddings_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgenerate_word_embeddings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+"\u001b[0;32m/tmp/ipykernel_1/2837270854.py\u001b[0m in \u001b[0;36mgenerate_word_embeddings\u001b[0;34m(sentences)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mword_to_embeddings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mwords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"embedding\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mword_embedding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"count\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0mword_to_embeddings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mwords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"embedding\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mword_embedding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0mword_to_embeddings\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mwords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"count\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+"\u001b[0;31mValueError\u001b[0m: operands could not be broadcast together with shapes (3,768) (5,768) (3,768) "
+]
+}
],
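The traceback comes from the pre-change accumulation loop: each stored embedding is a (num_subwords, 768) matrix, so two occurrences of a word that split into different numbers of subword tokens (3 versus 5 here) cannot be added together. A minimal reproduction, with the shapes taken from the error message:

# Reproduces the broadcast failure: summing per-word subword matrices
# whose first dimension (subword count) differs between occurrences.
import numpy as np

a = np.zeros((3, 768))  # a word split into 3 subword tokens
b = np.zeros((5, 768))  # another occurrence split into 5 subword tokens
try:
    a += b              # in-place add: shapes (3,768) (5,768) (3,768)
except ValueError as e:
    print(e)

# Mean-pooling each matrix to one 768-dim vector first avoids the error:
print((a.mean(axis=0) + b.mean(axis=0)).shape)  # (768,)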
@@ -275,7 +282,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"id": "6425ba1e-5156-4d8a-ad2a-f1758d39e30a",
"metadata": {},
"outputs": [],
@@ -287,7 +294,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 23,
"id": "80ea8a95-371c-48d8-a50b-ff4871b81eed",
"metadata": {},
"outputs": [],
This diff is collapsed.