#!/usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright © 2009-2010, 2012-2013 marmuta <marmvta@gmail.com>
#
# This file is part of Onboard.
#
# Onboard is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Onboard is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
import sys
import fnmatch
import subprocess
from optparse import OptionParser
from collections import Counter

import sys
from os.path import dirname, abspath

from pypredict import *



def main():
    """
    Build a language model from corpus text and save it to a file.

    Positional command line arguments:
        model           file name the model is saved to (required)
        corpus          a single corpus file, or, when corpus-pattern is
                        given too, a comma separated list of directories
        corpus-pattern  comma separated list of file name patterns; all
                        matching files below the corpus directories are
                        learned

    In directory mode the tokens are optionally spell-checked (--language),
    the first token of each "<s>"-separated section is skipped and the
    model is saved every 3000 files. In single file mode all tokens are
    learned as they are. With --max-unigrams the model is pruned before the
    final save. Statistics of the resulting model are printed at the end.
    """
    global model # for debugging

    parser = OptionParser(usage="Usage: %prog [options] model corpus corpus-pattern")
    parser.add_option("-o", "--order", type="int", dest="order", default=3,
              help="order of the language model, defaults to 3")
    parser.add_option("-l", "--language", type="str", dest="lang_id", default="",
              help="language id for the spell checker, e.g. en_US")
    parser.add_option("-q", "--quiet",
              action="store_true", dest="quiet", default=False,
              help="only show the final summary")
    parser.add_option("-v", "--vocabulary", type="str", dest="vocabulary_file",
              help="list of words to consider during model creation")
    parser.add_option("-u", "--max-unigrams", type="int",
              dest="max_unigrams", default=0,
              help="prune n-grams with counts below or equal the one of the "
                   "least frequent of the top max_unigrams unigram;"
                   "default 0, disabled")
    options, args = parser.parse_args()

    # The model file name is mandatory; report its absence as a usage
    # error instead of failing with an IndexError further down.
    if not args:
        parser.error("missing model filename")

    order = options.order
    out = None if options.quiet else sys.stdout
    vocabulary = read_vocabulary(options.vocabulary_file) \
                 if options.vocabulary_file else None

    if order == 1:
        model = UnigramModel()
    else:
        model = DynamicModel()
        model.order = order
    max_unigrams = options.max_unigrams
    model_filename = args[0]
    lang_id = options.lang_id

    spell_checker = None
    if lang_id:
        spell_checker = SpellChecker()
        spell_checker.set_backend(0)
        if not spell_checker.set_dict_ids([lang_id]):
            print("No spell checker dictionary found for '{}'".format(lang_id))
        # Verdicts per token, shared by all files; control tokens
        # always count as correctly spelled.
        spelling_cache = {"<unk>" : True,
                          "<s>" : True,
                          "</s>" : True,
                          "<num>" : True,}

    # Token frequencies of the whole corpus, only needed for pruning.
    # Collected across all files, so the prune threshold doesn't depend on
    # the last file alone and exists even if no file was read at all.
    unigram_counts = Counter()

    if len(args) >= 3:
        # directory mode: learn all files matching the patterns
        filenames = rglob(args[1], args[2])
        n = len(filenames)
        for i, filename in enumerate(filenames):
            count = i + 1
            print("{:6}/{} {:7.2f}%: {}" \
                  .format(count, n, 100.0 * count / n, filename))
            text = read_corpus(filename)
            tokens, spans = tokenize_text(text)

            if vocabulary:
                tokens = filter_tokens(tokens, vocabulary)

            if spell_checker:
                spell_check(spell_checker, spelling_cache, tokens)

            if max_unigrams:
                unigram_counts.update(tokens)

            # Skip over the first word of each sentence? Those are usually
            # capitalized and we can't distinguish them from capitalized nouns.
            skip_sentence_begin = True
            if skip_sentence_begin:
                sections = split_tokens(tokens, "<s>")
                token_sections = [section[1:] for section in sections]
            else:
                token_sections = [tokens]

            for token_section in token_sections:
                model.learn_tokens(token_section)

            # Save intermediate results of long runs now and then.
            if count % 3000 == 0:
                print("saving", repr(model_filename))
                model.save(model_filename)

    elif len(args) >= 2:
        # single file mode
        filename = args[1]
        with timeit("read_corpus", out):
            text = read_corpus(filename)

        with timeit("tokenize_text", out):
            tokens, spans = tokenize_text(text)

        if vocabulary:
            with timeit("filter_tokens", out):
                tokens = filter_tokens(tokens, vocabulary)

        if max_unigrams:
            unigram_counts.update(tokens)

        with timeit("learn_tokens", out):
            model.learn_tokens(tokens)

    if max_unigrams:
        with timeit("prune n-grams", out):
            most_common = unigram_counts.most_common(max_unigrams)
            # Nothing was counted if no corpus was read; don't prune then.
            if most_common:
                # The count of the least frequent of the top max_unigrams
                # tokens becomes the prune threshold.
                min_token = most_common[-1]
                prune_count = min_token[1]
                model = model.prune(prune_count)

    with timeit("save", out):
        model.save(model_filename)

    print_stats(model)

def spell_check(spell_checker, spelling_cache, tokens):
    """
    Replace all tokens the spell checker rejects with "<unk>", in-place.

    spell_checker   object with a query(tokens) method that returns the
                    correctly spelled ones of the given tokens
    spelling_cache  dict mapping token -> True if correctly spelled;
                    verdicts for tokens not seen before are added to it
    tokens          list of tokens, modified in-place
    """
    # Distinct tokens without a verdict yet, in order of first appearance.
    pending = dict.fromkeys(token for token in tokens
                            if token not in spelling_cache)

    if pending:
        print("spell-checking {} unknowns of {} total tokens" \
              .format(len(pending), len(tokens)))
        candidates = pending.keys()
        accepted = set(spell_checker.query(candidates))
        rejected = set(candidates) - accepted

        # remember the verdicts for the following calls
        for candidate in candidates:
            spelling_cache[candidate] = candidate not in rejected

        print("known tokens {:7}, dropping {:6} of {:6}: " \
              .format(len(spelling_cache),
                      len(rejected),
                      len(candidates)))
        print (rejected)

    # By now every token has a cache entry; blank out the misspelled ones.
    for position, token in enumerate(tokens):
        if not spelling_cache[token]:
            tokens[position] = "<unk>"

def print_stats(model):
    """
    Print the number of n-gram types and occurrences for each order of
    model, followed by the result of model.memory_size().
    """
    counts, totals = model.get_counts()
    print("calculating stats...")
    for order, num_types in enumerate(counts, start=1):
        num_occurrences = totals[order - 1]
        line = "%d-grams: types %10d, occurences %10d\n" % \
               (order, num_types, num_occurrences)
        sys.stdout.write(line)
    print(model.memory_size())

def rglob(dir_str, pattern_str):
    """
    Recursively collect the files below one or more directories whose
    base names match at least one of the given patterns.

    dir_str      comma separated list of directories to walk
    pattern_str  comma separated list of fnmatch-style patterns

    Returns the sorted list of matching file paths.
    """
    patterns = pattern_str.split(",")
    matches = []

    for top in dir_str.split(","):
        for parent, _subdirs, basenames in os.walk(top):
            # any() stops at the first hit, so a file matching
            # multiple patterns is still listed only once.
            matches.extend(
                os.path.join(parent, basename)
                for basename in basenames
                if any(fnmatch.fnmatch(basename, pattern)
                       for pattern in patterns))

    matches.sort()
    return matches


class SpellChecker:
    """
    Minimal spell checker that delegates the actual checking to the
    hunspell command line tool.
    """

    def __init__(self):
        # Dictionary ids handed to hunspell's -d option; if empty,
        # no -d option is passed.
        self.dict_ids = []

    def set_backend(self, backend):
        """ Accepts a backend selection; this implementation ignores it. """
        pass

    def set_dict_ids(self, dict_ids):
        """
        Remember the dictionary ids to use for subsequent queries.

        Always returns True; whether the dictionaries actually exist
        isn't checked here.
        """
        self.dict_ids = dict_ids
        return True

    def query(self, tokens):
        """
        Run the tokens through hunspell and return the list of lines it
        printed, i.e. the words it accepted (hunspell is run with -G).

        tokens  iterable of strings, one hunspell input line per token

        Returns an empty list if hunspell couldn't be started or had
        already exited when its input was about to be sent.
        """
        args = ["hunspell", "-G", "-i", "UTF-8"]
        if self.dict_ids:
            args += ["-d", ",".join(self.dict_ids)]

        try:
            p = subprocess.Popen(args, stdin=subprocess.PIPE,
                                       stdout=subprocess.PIPE,
                                       close_fds=True)
        except OSError as e:
            # Report the failure, but keep going without spell checking.
            print("Failed to execute '{}', {}".format(" ".join(args), e),
                  file=sys.stderr)
            return []

        # The with-block closes the pipes and reaps the process on
        # every path out of here.
        with p:
            # Check if the process is still running, it might have
            # exited on start due to an unknown dictionary name.
            if p.poll() is not None:
                return []

            # communicate() feeds stdin and drains stdout concurrently.
            # Writing all input before reading any output can deadlock
            # once the pipe buffers fill up. It also tolerates the
            # process exiting while input is still being written.
            data = "".join(token + "\n" for token in tokens).encode("UTF-8")
            output, _ = p.communicate(data)

        lines = output.split(b"\n")
        if lines and not lines[-1]:
            lines.pop()     # remainder after the final newline
        return [line.decode("UTF-8").strip() for line in lines]


# Build the model only when run as a script, not when imported as a module.
if __name__ == '__main__':
    main()

