Commit dfd004de authored by Lelkada L L P S M's avatar Lelkada L L P S M

word generation integration with backend

parent 45aea284
......@@ -2,6 +2,7 @@ import numpy as np
from flask import Flask, request, jsonify, request
import pickle
from word_card_game import wordGameData
from word_generation import get_similar_words
app = Flask(__name__)
......@@ -12,6 +13,7 @@ app = Flask(__name__)
# send a json {'exp':1.8,} as a post request to make a prediction
def predict():
data = request.get_json(force=True)
......@@ -24,7 +26,28 @@ def predict():
def default_get():
return "<p>HereMe Backend !</p>"
@app.route('/api/word-game', methods=['GET'])
def word_game_api():
w1 = request.args.get('w1')
w2 = request.args.get('w2')
w3 = request.args.get('w3')
if not all([w1, w2, w3]):
return jsonify({'error': 'All three words must be provided'}), 400
data = wordGameData(w1, w2, w3)
return jsonify(data)
@app.route('/api/similar-words', methods=['GET'])
def similar_words_api():
word = request.args.get('word')
if not word:
return jsonify({'error': 'A word must be provided'}), 400
similar_words = get_similar_words(word)
return jsonify({'similar_words': similar_words})
if __name__ == '__main__':, debug=True)
\ No newline at end of file
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM
# Load the pretrained RoBERTa model and tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')
def get_similar_words(input_word, top_k=3):
# Create a masked sentence with the input word
masked_sentence = f"The {input_word} is related to the {tokenizer.mask_token}."
# Tokenize the masked sentence
inputs = tokenizer(masked_sentence, return_tensors='pt')
# Get the index of the mask token
mask_token_index = torch.where(inputs['input_ids'][0] == tokenizer.mask_token_id)[0].item()
# Predict words for the mask token
with torch.no_grad():
output = model(**inputs)
predictions = output.logits[0, mask_token_index]
# Get the top k predicted words
top_k_indices = torch.topk(predictions, top_k).indices.tolist()
related_words = [tokenizer.decode(idx).strip() for idx in top_k_indices]
return related_words
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment