from __future__ import division
from __future__ import print_function

import functools
import logging
import os
import pickle
import re
import time

import nltk
import numpy as np
import pandas as pd
# from gensim.models import KeyedVectors
from gensim.models import Word2Vec
from nltk.translate import bleu_score
from scipy import spatial
from sklearn.metrics.pairwise import cosine_similarity

import warnings

warnings.filterwarnings('ignore')

from .language_features import MORPH_FEATS  # , SYNTAX_LINKS
from .meaningfulwords_v3 import words, first_and_last_occur

# Directory that receives the FeaturesProcessor log file; created at import time.
LOGDIR = 'log/'
# exist_ok=True avoids the check-then-create race when several processes import this module.
os.makedirs(LOGDIR, exist_ok=True)


class FeaturesProcessor:
    """Feature extractor for pairs of discourse units (DUs).

    Calling an instance with a DataFrame of DU pairs (``snippet_x`` / ``snippet_y``
    text columns) and the annotations of their source document returns a new
    DataFrame extended with lexical, morphological, TF-IDF and word2vec features.
    """
    CATEGORY = 'category_id'

    def __init__(self,
                 verbose=False):
        """Load stop words, marker lists, the TF-IDF pipeline and the word2vec model.

        :param verbose: if True, print progress and timings to stdout.
        """
        self.verbose = verbose
        if self.verbose:
            print("Processor initialization...\t", end="", flush=True)

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)
        log_path = os.path.join(LOGDIR, 'features_processor.log')
        # getLogger(__name__) returns the same logger for every instance, so attach the
        # file handler only once; otherwise every new processor duplicates each log line.
        already_attached = any(isinstance(h, logging.FileHandler)
                               and h.baseFilename == os.path.abspath(log_path)
                               for h in self.logger.handlers)
        if not already_attached:
            formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
                                          datefmt='%Y-%m-%d %H:%M:%S')
            handler = logging.FileHandler(log_path, mode='a')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.info('Call FeaturesProcessor.__init__()')

        self.stop_words = nltk.corpus.stopwords.words('russian')
        self.meaningful = words
        self.first_and_last_occur = first_and_last_occur

        # NOTE: unpickling can execute arbitrary code; this file must come from a trusted source.
        with open('utils/tf_idf_pipeline.save', 'rb') as pipeline_file:
            self.vectorizer = pickle.load(pipeline_file)

        # mystem POS tag -> Universal POS tag; used by self._preprocess_text below
        mystem2upos = {
            'A': 'ADJ',
            'ADV': 'ADV',
            'ADVPRO': 'ADV',
            'ANUM': 'ADJ',
            'APRO': 'DET',
            'COM': 'ADJ',
            'CONJ': 'SCONJ',
            'INTJ': 'INTJ',
            'NONLEX': 'X',
            'NUM': 'NUM',
            'PART': 'PART',
            'PR': 'ADP',
            'S': 'NOUN',
            'SPRO': 'PRON',
            'UNKN': 'X',
            'V': 'VERB'
        }

        # preprocessing functions
        self._uppercased = lambda snippet, length: sum(
            [word[0].isupper() if len(word) > 0 else False for word in snippet.split()]) / length
        self._start_with_uppercase = lambda snippet, length: sum(
            [word[0].isupper() if len(word) > 0 else False for word in snippet.split(' ')]) / length

        # NOTE(review): _tag_mystem reads self.mystem_processor, which this class never sets,
        # so _preprocess_text only works if a caller assigns that attribute first.
        self._preprocess_text = lambda snippet: ' '.join(self._tag_mystem(text=snippet, mapping=mystem2upos))

        # word2vec, table of universal postags
        # self.word2vec_model = KeyedVectors.load_word2vec_format('models/ruwikiruscorpora_upos_cbow_300_20_2017.bin', binary=True)
        self.word2vec_model = Word2Vec.load('models_w2v/model2_tokenized')
        # Dimensionality of the embeddings, read directly instead of probing sample words.
        self.word2vec_vector_length = self.word2vec_model.wv.vector_size

        # in linguistics texts > are replaced by '&gt;' but in .rs3 they are not replaced
        self.text_html_map = {
            r'&gt;': r'>',
            r'&lt;': r'<',
            r'&amp;': r'&',
            r'&quot;': r'"',
            r'&ndash;': r'–',
        }

        self.fpos_combinations = [
            'ADJ_NOUN',
            'ADJ_X',
            'ADP_ADJ',
            'ADP_NOUN',
            'ADV_ADV',
            'ADV_NOUN',
            'ADV_VERB',
            'ADV_X',
            'CONJ_ADV',
            'CONJ_NOUN',
            'CONJ_X',
            'NOUN_ADJ',
            'NOUN_ADV',
            'NOUN_NOUN',
            'NOUN_PART',
            'NOUN_VERB',
            'NOUN_X',
            'NUM_NOUN',
            'PART_VERB',
            'PART_X',
            'PRON_NOUN',
            'PRON_PRON',
            'PRON_X',
            'VERB_ADJ',
            'VERB_ADP',
            'VERB_ADV',
            'VERB_NOUN',
            'VERB_PRON',
            'VERB_X',
            'X',
            'X_NOUN',
            'X_X']

        if self.verbose:
            print('[DONE]')

    def __call__(self, df_, annotations):
        """Compute features for every DU pair in ``df_``.

        :param df_: DataFrame with at least ``snippet_x`` and ``snippet_y`` text columns.
            It is copied, never modified.
        :param annotations: document annotations; this method reads the keys
            'text', 'tokens', 'sentences', 'lemma' and 'morph'.
        :return: a new DataFrame (fresh RangeIndex) with one row per pair whose
            ``snippet_y`` was found in the document text, extended with feature columns.
        """
        # Explicit copy: ``df_[:]`` may share data with the caller's frame, so the edits
        # below could leak into it (or, under pandas copy-on-write, be silently dropped).
        df = df_.copy()
        self.annotations = annotations

        if self.verbose:
            t = time.time()
            t1 = t
            print('1.1\t', end="", flush=True)

        df['new_paragraph_x'] = df.snippet_x.str.startswith('#####').astype('int')
        df['new_paragraph_y'] = df.snippet_y.str.startswith('#####').astype('int')
        # Assign the result back: a chained ``replace(..., inplace=True)`` on a column does
        # not reliably update ``df``.
        df['snippet_x'] = df['snippet_x'].replace({r'##### ': r'', r'\\\\': r'\\'}, regex=True)
        df['snippet_y'] = df['snippet_y'].replace({r'##### ': r'', r'\\\\': r'\\'}, regex=True)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('1.2\t', end="", flush=True)

        # map discourse units to annotations
        text = self.annotations['text']
        df['loc_x'] = df.snippet_x.map(text.find)
        # snippet_y is searched for only after the start of snippet_x
        df['loc_y'] = [text.find(snippet, loc + 1) for snippet, loc in zip(df.snippet_y, df.loc_x)]
        df['token_begin_x'] = df.loc_x.map(self.locate_token)
        df['token_begin_y'] = df.loc_y.map(self.locate_token)

        # some snippets cannot be found in the text (there is a bug in ling_20); skip them
        df = df[df['loc_y'] != -1].copy()
        df['token_end_y'] = [self.locate_token(loc + len(snippet)) + 1
                             for loc, snippet in zip(df.loc_y, df.snippet_y)]

        # length of tokens sequence
        df['len_w_x'] = df['token_begin_y'] - df['token_begin_x']
        df['len_w_y'] = df['token_end_y'] - df['token_begin_y']  # +1

        df['snippet_x_locs'] = [self._token_range_locs(begin, end)
                                for begin, end in zip(df.token_begin_x, df.token_begin_y)]
        df['snippet_y_locs'] = [self._token_range_locs(begin, end)
                                for begin, end in zip(df.token_begin_y, df.token_end_y)]
        df.drop(columns=['loc_x', 'loc_y'], inplace=True)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('4\t', end="", flush=True)

        # get tokens
        df['tokens_x'] = [self.get_tokens(begin, end) for begin, end in zip(df.token_begin_x, df.token_begin_y)]
        df['tokens_y'] = [self.get_tokens(begin, end) for begin, end in zip(df.token_begin_y, df.token_end_y)]

        # average word length
        df['len_av_x'] = df.tokens_x.map(lambda row: sum(len(word) for word in row)) / (df.len_w_x + 1e-8)
        df['len_av_y'] = df.tokens_y.map(lambda row: sum(len(word) for word in row)) / (df.len_w_y + 1e-8)

        # get lemmas
        df['lemmas_x'] = df.snippet_x_locs.map(self.get_lemma)
        df['lemmas_y'] = df.snippet_y_locs.map(self.get_lemma)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('5\t', end="", flush=True)

        # ratio of uppercased words (0 for a DU without tokens)
        df['upper_x'] = df.tokens_x.map(self._upper_ratio)
        df['upper_y'] = df.tokens_y.map(self._upper_ratio)

        # ratio of the words starting with upper case (0 for a DU without tokens)
        df['st_up_x'] = df.tokens_x.map(self._starts_upper_ratio)
        df['st_up_y'] = df.tokens_y.map(self._starts_upper_ratio)

        # whether DU starts with upper case; [:1] keeps empty tokens/DUs from raising IndexError
        df['du_st_up_x'] = df.tokens_x.map(lambda row: bool(row) and row[0][:1].isupper()).astype(int)
        df['du_st_up_y'] = df.tokens_y.map(lambda row: bool(row) and row[0][:1].isupper()).astype(int)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('6.1\t', end="", flush=True)

        # get morphology
        df['morph_x'] = df.snippet_x_locs.map(self.get_morph)
        df['morph_y'] = df.snippet_y_locs.map(self.get_morph)

        # count presense and/or quantity of various language features in the whole DUs and at the beginning/end of them
        df = df.apply(lambda row: self._linguistic_features(row, tags=MORPH_FEATS), axis=1)
        df = df.apply(self._first_and_last_pair, axis=1)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('6.2\t', end="", flush=True)

        # count various vectors similarity metrics for morhology
        linknames_for_snippet_x = df[[name + '_x' for name in MORPH_FEATS]]
        linknames_for_snippet_y = df[[name + '_y' for name in MORPH_FEATS]]
        df.reset_index(inplace=True)
        df['morph_vec_x'] = pd.Series(self.columns_to_vectors_(linknames_for_snippet_x))
        df['morph_vec_y'] = pd.Series(self.columns_to_vectors_(linknames_for_snippet_y))
        df['morph_correlation'] = df[['morph_vec_x', 'morph_vec_y']].apply(
            lambda row: spatial.distance.correlation(*row), axis=1)
        df['morph_canberra'] = df[['morph_vec_x', 'morph_vec_y']].apply(
            lambda row: spatial.distance.canberra(*row), axis=1)
        df['morph_hamming'] = df[['morph_vec_x', 'morph_vec_y']].apply(
            lambda row: spatial.distance.hamming(*row), axis=1)
        df['morph_matching'] = df[['morph_vec_x', 'morph_vec_y']].apply(
            lambda row: self.get_match_between_vectors_(*row), axis=1)
        df.set_index('index', drop=True, inplace=True)
        df = df.drop(columns=['morph_vec_x', 'morph_vec_y'])

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('7\t', end="", flush=True)

        # detect discourse markers (each entry of self.meaningful is used as a regex)
        for word in self.meaningful:
            df[word + '_count' + '_x'] = df.snippet_x.map(lambda text, marker=word: self.count_marker_(marker, text))
            df[word + '_count' + '_y'] = df.snippet_y.map(lambda text, marker=word: self.count_marker_(marker, text))

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('8\t', end="", flush=True)

        # count stop words in the texts
        df['stopwords_x'] = df.lemmas_x.map(self._count_stop_words)
        df['stopwords_y'] = df.lemmas_y.map(self._count_stop_words)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('9\t', end="", flush=True)

        # vectorize
        df.reset_index(drop=True, inplace=True)
        tf_idf_x = self.vectorizer.transform(df['snippet_x'])
        tf_idf_y = self.vectorizer.transform(df['snippet_y'])
        df['cos_simil'] = self.get_cosine_sim(tf_idf_x, tf_idf_y)

        tf_idf_x = pd.DataFrame(tf_idf_x).add_prefix('tf_idf_x_')
        tf_idf_y = pd.DataFrame(tf_idf_y).add_prefix('tf_idf_y_')

        df = pd.concat([df, tf_idf_x, tf_idf_y], axis=1)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('10\t', end="", flush=True)

        # count various lexical similarity metrics
        df['jac_simil'] = df.apply(lambda row: self.get_jaccard_sim(row.lemmas_x, row.lemmas_y), axis=1)
        df['bleu'] = df.apply(lambda row: self.get_bleu_score(row.lemmas_x, row.lemmas_y), axis=1)

        if self.verbose:
            print(time.time() - t)
            t = time.time()
            print('11\t', end="", flush=True)

        # Get average vector for each text
        df = self._get_vectors(df)

        # helper columns that must not reach the output
        df = df.drop(columns=[
            'lemmas_x', 'lemmas_y',
            'snippet_x_locs', 'snippet_y_locs',
            'morph_x', 'morph_y',
            'tokens_x', 'tokens_y',
        ])
        # Syntax-derived columns are not produced by this method; drop them only when the
        # input carried them instead of failing with KeyError.
        df = df.drop(columns=['common_root_fpos', 'common_root_att', 'common_root'], errors='ignore')

        if self.verbose:
            print(time.time() - t)
            print('[DONE]')
            print('estimated time:', time.time() - t1)

        return df

    def _token_range_locs(self, begin, end):
        """Return ``(sentence, [word])`` locations of tokens ``begin..end-1`` that lie in a sentence."""
        locations = (self.token_to_sent_word(token) for token in range(begin, end))
        return [pair for pair in locations if pair]

    @staticmethod
    def _upper_ratio(tokens):
        """Share of fully upper-cased tokens; 0 for an empty token list."""
        return sum(token.isupper() for token in tokens) / len(tokens) if tokens else 0.

    @staticmethod
    def _starts_upper_ratio(tokens):
        """Share of tokens whose first character is upper case; 0 for an empty token list."""
        return sum(token[:1].isupper() for token in tokens) / len(tokens) if tokens else 0.

    def locate_token(self, start):
        """Map a character offset to a token index.

        Returns the index of the token beginning exactly at ``start``, otherwise the
        index of the token just before the first token that begins after ``start``.
        Offsets past the last token map to the last token; offsets before the first
        token, or an empty token list, give -1. Assumes tokens are ordered by ``begin``.
        """
        tokens = self.annotations['tokens']
        for i, token in enumerate(tokens):
            if token.begin > start:
                return i - 1
            if token.begin == start:
                return i
        return len(tokens) - 1

    def map_to_token(self, pair):
        """Convert a ``(sentence, word)`` pair (word may be a one-element list) to a global token index.

        -1 is passed through unchanged.
        """
        if pair == -1:
            return -1

        sentence, word = pair
        if isinstance(word, list) and len(word) == 1:
            word = word[0]

        return self.annotations['sentences'][sentence].begin + word

    def token_to_sent_word(self, token):
        """Return ``(sentence_index, [word_index])`` for a global token index, or ``()`` if no sentence holds it."""
        for i, sentence in enumerate(self.annotations['sentences']):
            if sentence.begin <= token < sentence.end:
                return i, [token - sentence.begin]
        return ()

    def locate_root(self, row):
        """Return ``(sentence, [i])`` of the dependency root of the pair's sentence, or -1.

        Only looked up when ``row.same_sentence`` is truthy; uses ``row.sentence_begin_x``.
        """
        if row.same_sentence:
            for i, wordsynt in enumerate(self.annotations['syntax_dep_tree'][row.sentence_begin_x]):
                if wordsynt.parent == -1:
                    return row.sentence_begin_x, [i]
        return -1

    def get_roots(self, locations):
        """Keep only the locations whose dependency parent is -1 (sentence roots)."""
        res = []
        for word in locations:
            parent = self.annotations['syntax_dep_tree'][word[0]][word[1][0]].parent
            if parent == -1:
                res.append(word)
        return res

    def locate_attached(self, row):
        """Return in-sentence indices of root words lying between ``token_begin_x`` and ``token_end_y``."""
        res = []
        sent_begin = self.annotations['sentences'][row.sentence_begin_x].begin
        for i, wordsynt in enumerate(self.annotations['syntax_dep_tree'][row.sentence_begin_x]):
            if row.token_begin_x - sent_begin <= i < row.token_end_y - sent_begin:
                if wordsynt.parent == -1:
                    res.append(i)
        return res

    def get_tokens(self, begin, end):
        """Return the texts of tokens ``begin..end-1``."""
        return [self.annotations['tokens'][i].text for i in range(begin, end)]

    def get_lemma(self, positions):
        """Return the lemma for each ``(sentence, [word])`` position."""
        return [self.annotations['lemma'][position[0]][position[1][0]] for position in positions]

    def get_postag(self, positions):
        """Return the POS tag for each position; ``['']`` for no positions or a leading -1."""
        if not positions or positions[0] == -1:
            return ['']
        return [self.annotations['postag'][position[0]][position[1][0]] for position in positions]

    def get_morph(self, positions):
        """Return the morphology record for each ``(sentence, [word])`` position."""
        return [self.annotations['morph'][position[0]][position[1][0]] for position in positions]

    def columns_to_vectors_(self, columns):
        """Turn each DataFrame row into a numpy vector shifted by 1e-05 (keeps distances defined for all-zero rows)."""
        return [row + 1e-05 for row in np.array(columns.values.tolist())]

    def get_match_between_vectors_(self, vector1, vector2):
        """Hamming distance between the presence masks (value > 0.01) of two vectors."""
        return spatial.distance.hamming([k > 0.01 for k in vector1], [k > 0.01 for k in vector2])

    def _get_fpos_vectors(self, row):
        """Append one-hot columns for ``row.common_root_fpos`` and ``row.common_root_att``."""
        result = {}
        for header in ['VERB', 'NOUN', '', 'ADV', 'ADJ', 'ADP', 'CONJ', 'PART', 'PRON', 'NUM']:
            result[header + '_common_root'] = int(row.common_root_fpos == header)

        for header in ['VERB', '', 'NOUN', 'ADJ', 'ADV', 'ADP', 'CONJ', 'PRON', 'PART', 'NUM', 'INTJ']:
            result[header + '_common_root_att'] = int(row.common_root_att == header)

        # pd.concat instead of Series.append, which was removed in pandas 2.0
        return pd.concat([row, pd.Series(list(result.values()), index=list(result.keys()))])

    # Static so the cache is shared and does not keep processor instances (and their models) alive.
    @staticmethod
    @functools.lru_cache(maxsize=2048)
    def count_marker_(word, row):
        """Count non-overlapping matches of regex ``word`` in ``row``."""
        return sum(1 for _ in re.finditer(word, row))

    @staticmethod
    @functools.lru_cache(maxsize=2048)
    def locate_marker_(word, row):
        """Position of the first match of regex ``word`` in ``row`` as a percentage of its length, or -1."""
        for m in re.finditer(word, row):
            return (m.start() + 1.) / len(row) * 100.
        return -1.

    def _linguistic_features(self, row, tags):
        """Append counts of each morphological feature in ``tags`` for both snippets.

        :param row: pair row with ``morph_x`` / ``morph_y`` lists of ``{feature: value}`` dicts.
        :param tags: feature names of the form ``'<feature>_<value>'`` to count.
        :return: ``row`` extended with ``<tag>_x`` and ``<tag>_y`` counts; feature values
            not listed in ``tags`` are logged and skipped.
        """
        def count_tags(morph_annot, mark):
            counts = dict.fromkeys(['%s%s' % (tag, mark) for tag in tags], 0)
            for record in morph_annot:
                for key, value in record.items():
                    try:
                        counts['%s_%s%s' % (key, value, mark)] += 1
                    except KeyError as e:
                        self.logger.warning('::: Did not find such key in MORPH_FEATS: %s :::' % e)
            return counts

        features = dict(count_tags(row.morph_x, '_x'), **count_tags(row.morph_y, '_y'))
        return pd.concat([row, pd.Series(list(features.values()), index=list(features.keys()))])

    def _count_stop_words(self, lemmatized_text, threshold=0):
        """Count lemmas that are stop words and have at least ``threshold`` characters."""
        return len([1 for token in lemmatized_text if len(token) >= threshold and token in self.stop_words])

    def _pair_features(self, tokens, morph, mark):
        """Features of the first and last two tokens of one snippet.

        A snippet of at most two tokens gets ``' '`` / ``'X'`` as its last pair.
        """
        def fpos_pair(records):
            return '_'.join([record.get('fPOS') if record.get('fPOS') else 'X' for record in records])

        first_text = ' '.join(tokens[:2])
        first_morph = fpos_pair(morph[:2])
        if len(tokens) > 2:
            last_text = ' '.join(tokens[-2:])
            last_morph = fpos_pair(morph[-2:])
        else:
            last_text, last_morph = ' ', 'X'

        result = {}
        for combination in self.fpos_combinations:
            result['first_' + combination + mark] = int(combination == first_morph)
            result['last_' + combination + mark] = int(combination == last_morph)

        for word in self.first_and_last_occur:
            result['first_pair_' + word + mark] = int(re.search(word, first_text) is not None)
            result['last_pair_' + word + mark] = int(re.search(word, last_text) is not None)

        return result

    def _first_and_last_pair(self, row):
        """Append POS-pair and marker features of the snippets' first/last two tokens to ``row``."""
        features = dict(self._pair_features(row.tokens_x, row.morph_x, '_x'),
                        **self._pair_features(row.tokens_y, row.morph_y, '_y'))
        return pd.concat([row, pd.Series(list(features.values()), index=list(features.keys()))])

    def get_jaccard_sim(self, text1, text2):
        """Jaccard similarity of two token sequences (1e-05 in the denominator avoids 0/0)."""
        txt1 = set(text1)
        txt2 = set(text2)
        c = len(txt1.intersection(txt2))
        return float(c) / (len(txt1) + len(txt2) - c + 1e-05)

    def get_bleu_score(self, text1, text2):
        """Unigram sentence BLEU of ``text2`` against the single reference ``text1``."""
        return bleu_score.sentence_bleu([text1], text2, weights=(0.5,))

    def get_cosine_sim(self, mtrx1, mtrx2):
        """Row-wise cosine similarity between two matrices with the same shape.

        Accepts dense array-likes or objects with ``toarray()`` (e.g. scipy sparse).
        Rows with zero norm get similarity 0.

        :return: list with one float per row.
        """
        a = mtrx1.toarray() if hasattr(mtrx1, 'toarray') else mtrx1
        b = mtrx2.toarray() if hasattr(mtrx2, 'toarray') else mtrx2
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        dots = np.einsum('ij,ij->i', a, b)
        norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
        sims = np.divide(dots, norms, out=np.zeros_like(dots), where=norms != 0)
        return sims.tolist()

    def _tag_mystem(self, text, mapping=None, postags=True):
        """Lemmatize ``text`` with mystem into ``lemma_POS`` strings (only lemmas if ``postags`` is False).

        :param mapping: optional dict translating mystem POS tags; unknown tags become 'X'.

        NOTE(review): reads ``self.mystem_processor``, which this class never assigns;
        presumably a pymystem3 ``Mystem`` instance — confirm and set it before use.
        """
        processed = self.mystem_processor.analyze(text)
        tagged = []
        for w in processed:
            try:
                analysis = w["analysis"][0]
                lemma = analysis["lex"].lower().strip()
                pos = analysis["gr"].split(',')[0].split('=')[0].strip()
            except KeyError:
                continue  # omit punctuation (no "analysis" entry)
            except IndexError:
                continue  # empty analysis list
            if mapping:
                pos = mapping.get(pos, 'X')
            tagged.append(lemma + '_' + pos)
        if not postags:
            tagged = [t.split('_')[0] for t in tagged]
        return tagged

    def _get_vectors(self, df):
        """Append the mean word2vec vectors of ``lemmas_x`` and ``lemmas_y``.

        Lemmas missing from the vocabulary are skipped; a snippet with no known lemma
        gets a zero vector. The embedding columns are named ``<dim>_x`` / ``<dim>_y``,
        and the returned frame has a fresh RangeIndex.
        """
        # .wv works with gensim 3 and 4 (Word2Vec.__getitem__ is gone in gensim 4)
        # and matches the .wv access in __init__.
        word_vectors = self.word2vec_model.wv

        def mean_vector(lemmatized_text):
            vectors = [np.zeros(self.word2vec_vector_length)]
            for word in lemmatized_text:
                try:
                    vectors.append(word_vectors[word])
                except KeyError:
                    continue  # out-of-vocabulary lemma
            # the leading zero vector is excluded from the count; 1e-25 avoids 0/0
            return sum(np.array(vectors)) / (len(vectors) - 1 + 1e-25)

        df_embed_x = pd.DataFrame(df.lemmas_x.apply(mean_vector).values.tolist())
        df_embed_y = pd.DataFrame(df.lemmas_y.apply(mean_vector).values.tolist())
        embeddings = df_embed_x.merge(df_embed_y, left_index=True, right_index=True)
        return pd.concat([df.reset_index(drop=True), embeddings.reset_index(drop=True)], axis=1)
