Commit 7e26c6ba authored by Lelkada L L P S M

word generation

parent 452104ae
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "7a96eefc-6ba6-4c4e-8fa9-9c1c2ea9b44b",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics.pairwise import euclidean_distances\n",
"import numpy as np\n",
"import pickle\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cdd793ae-fdc5-4f37-a515-f78f9a2def8c",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "676018a113d94fb2b3e793d3330d5a84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/466k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"model_name = 'bert-base-uncased'\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModel.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ed9f5d2d-4fc6-461e-a2ad-825ddf1d7cdf",
"metadata": {},
"outputs": [],
"source": [
"# Load the saved embeddings from the file\n",
"context_embeddings = np.load('./embeddings/vocabulary_embeddings.npy')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6d2a7cc1-51db-437a-8a23-84478e2b502c",
"metadata": {},
"outputs": [],
"source": [
"# Define the input word you want to find similar words for\n",
"word = 'example'\n",
"\n",
"# Tokenize the word and get its token ID\n",
"tokens = tokenizer.encode(word, add_special_tokens=False)\n",
"token_id = tokens[0]\n",
"\n",
"# Convert the token ID to a tensor\n",
"input_ids = torch.tensor([token_id]).unsqueeze(0)\n",
"\n",
"# Extract the embedding for the input word from the pre-trained BERT model\n",
"with torch.no_grad():\n",
" embedding = model(input_ids)[0][0][1:-1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a19a512-7564-4d69-91f2-bd4961995fb2",
"metadata": {},
"outputs": [],
"source": [
"distances = euclidean_distances(embedding.reshape(1,-1), context_embeddings)\n",
"\n",
"# Sort the distances in ascending order and get the indices of the most similar words\n",
"similar_indices = distances.argsort()[0]\n",
"\n",
"# Print the top 5 most similar words\n",
"for i in range(5):\n",
" similar_word = tokenizer.decode(context_tokens[similar_indices[i]])\n",
" print(similar_word)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ba8bec7-03b7-4e09-852e-99a95f6c376f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "default:Python",
"language": "python",
"name": "conda-env-default-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
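The notebook above loads ./embeddings/vocabulary_embeddings.npy (and, in the lookup cell, a matching token list) without showing how they were produced. A minimal sketch of how such files could be generated with the same bert-base-uncased model; the loop over the vocabulary, the '##' sub-word filter, and the vocabulary_tokens.pkl path are assumptions, not part of this commit:

# Sketch (assumed, not part of this commit): generating the vocabulary
# embeddings loaded by the notebook above, one forward pass per token.
import numpy as np
import pickle
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')

# Keep whole-word tokens only; '##' marks WordPiece continuation pieces
token_ids = sorted(i for t, i in tokenizer.get_vocab().items() if not t.startswith('##'))

embeddings = []
with torch.no_grad():
    for token_id in token_ids:
        # [CLS] token [SEP], matching the [1:-1] slice used in the notebook
        input_ids = torch.tensor([[tokenizer.cls_token_id, token_id, tokenizer.sep_token_id]])
        embeddings.append(model(input_ids)[0][0][1:-1].mean(dim=0).numpy())

np.save('./embeddings/vocabulary_embeddings.npy', np.stack(embeddings))
with open('./embeddings/vocabulary_tokens.pkl', 'wb') as f:
    pickle.dump(token_ids, f)  # hypothetical path, mirrored in the lookup cell

Batching several tokens per forward pass would speed this up considerably; one token at a time keeps the sketch aligned with the single-word cell in the notebook.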
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "2ab73bc9-076e-488c-9048-9c457afb42aa",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "default:Python",
"language": "python",
"name": "conda-env-default-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}