import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
import warnings; warnings.filterwarnings("ignore", category=UserWarning)
from download import download_LaMini_model; download_LaMini_model()

# Module-level model setup: everything below reads `tokenizer` and `base_model`.
# Local directory the tokenizer and model are loaded from
# (presumably populated by download_LaMini_model() above — verify).
checkpoint = "./LaMini/"  # LaMini-Flan-T5-248M
# NOTE(review): `device` is passed as an extra keyword to the tokenizer loader;
# a tokenizer carries no weights, so this likely has no effect — confirm and drop.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, device='cpu')
# Seq2seq model kept on CPU; its sub-modules (shared embedding, encoder/decoder
# blocks, final layer norms, lm_head) are called directly by the functions below.
base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to('cpu')

def embed_and_encode(input_text):
    """Tokenize ``input_text`` and push it through the model's encoder blocks.

    The pad token id is prepended to the token ids before embedding, so index 0
    along the sequence axis of the result belongs to that extra leading token
    rather than to the text itself.

    Returns the output of the encoder's final layer norm for the full
    (pad + text) sequence.
    """
    # Token ids as a PyTorch tensor, with the pad id placed in front.
    text_ids = tokenizer.encode(input_text, return_tensors="pt")
    leading_id = torch.tensor([[base_model.config.pad_token_id]])
    full_ids = torch.cat([leading_id, text_ids], dim=1)

    # Shared embedding table -> every encoder block in order -> final norm.
    states = base_model.shared(full_ids)
    carry = None
    for layer in base_model.encoder.block:
        # Each block returns a pair; the second item is handed to the next
        # block as its second positional argument (None for the first block).
        states, carry = layer(states, carry)

    return base_model.encoder.final_layer_norm(states)

def decode_and_wipeout(hidden, max_new_tokens=512):
    """Greedily decode text from encoder states and strip special tokens.

    Starting from a single pad token, the current output prefix is embedded,
    run through every decoder block (cross-attending to ``hidden``) and the
    LM head, and the arg-max token at the last position is appended. This
    repeats until the EOS token is produced or the token budget is used up.

    Args:
        hidden: Encoder output, e.g. the value returned by
            ``embed_and_encode``. Its first two dimensions size the
            all-ones cross-attention mask.
        max_new_tokens: Upper bound on the number of generated tokens so a
            sequence that never reaches EOS cannot loop forever. Pass
            ``None`` to decode without a limit.

    Returns:
        The decoded string with special tokens removed.
    """
    def decode(x, mask, crossx, crossmask):
        # Each block returns a triple; items two and three are fed back in as
        # the mask arguments of the following block.
        for block in base_model.decoder.block:
            x, mask, crossmask = block(x, mask, None, crossx, crossmask)
        return base_model.decoder.final_layer_norm(x)

    pad = base_model.config.pad_token_id  # 0
    eos = base_model.config.eos_token_id  # 1
    output_tokens = torch.tensor([[pad]])

    # Loop-invariant: the same all-ones mask over the encoder positions is
    # supplied on every decoding step.
    cross_mask = torch.ones(hidden.shape[:2])

    # Only a string is returned, so no autograd graph is needed while decoding.
    with torch.no_grad():
        generated = 0
        while max_new_tokens is None or generated < max_new_tokens:
            embed = base_model.shared(output_tokens)
            # NOTE(review): the mask for the decoder's own tokens is None and
            # the whole prefix is re-decoded each step — confirm that no
            # explicit causal mask is required by the blocks.
            output = decode(embed, None, hidden, cross_mask)
            logits = base_model.lm_head(output)
            next_token = int(torch.argmax(logits[:, -1, :]))  # 0-32127
            output_tokens = torch.cat(
                [output_tokens, torch.tensor([[next_token]])], dim=1
            )
            generated += 1
            if next_token == eos:
                break

    # wipe out: drop pad/EOS and other special tokens from the text
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    
def cosine_similarity(u, v):
    """Return the cosine of the angle between vectors ``u`` and ``v`` as a float.

    Args:
        u: 1-D floating-point tensor.
        v: 1-D floating-point tensor of the same length as ``u``.

    Returns:
        ``(u . v) / (|u| * |v|)`` as a Python float, or ``0.0`` when either
        vector has zero norm (the ratio would otherwise be 0/0 = NaN).
    """
    denom = torch.linalg.vector_norm(u) * torch.linalg.vector_norm(v)
    if denom == 0:
        # Direction is undefined for a zero vector; report "no similarity"
        # instead of propagating NaN to callers.
        return 0.0
    return ((u @ v) / denom).item()

if __name__ == "__main__":
    # chatbot: encode a prompt, then decode the model's answer as text
    print(decode_and_wipeout(embed_and_encode("What is the capital of the USA?")))

    # classifier: compare phrases via the encoder state at sequence index `cls`
    # (index 0 is the pad token that embed_and_encode prepends)
    cls = 0
    # Encode the shared reference phrase once instead of once per comparison.
    ball = embed_and_encode("the ball")[0, cls]
    print(cosine_similarity(ball, embed_and_encode("the spheric object")[0, cls]))
    print(cosine_similarity(ball, embed_and_encode("the philosopher")[0, cls]))