/************************************************************************
 ************************************************************************
    FAUST compiler
    Copyright (C) 2003-2018 GRAME, Centre National de Creation Musicale
    ---------------------------------------------------------------------
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation; either version 2.1 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 ************************************************************************
 ************************************************************************/

#include <math.h>

#include "Text.hh"
#include "floats.hh"
#include "xtended.hh"

// SqrtPrim: the "sqrt" extended primitive (one input, one output).
// It provides type/interval inference, constant folding of numeric arguments,
// code generation for the various backends, LaTeX rendering, symbolic
// differentiation and direct numeric evaluation.
class SqrtPrim : public xtended {
   public:
    SqrtPrim() : xtended("sqrt") {}

    // sqrt takes exactly one argument.
    virtual unsigned int arity() override { return 1; }

    virtual bool needCache() override { return true; }

    // Result type: the argument type cast to float, with the interval mapped
    // through gAlgebra.Sqrt. Emits a warning (when gMathExceptions is set) if
    // the argument interval is valid and has a negative lower bound.
    virtual ::Type inferSigType(ConstTypes args) override
    {
        faustassert(args.size() == 1);
        Type     argType = args[0];
        interval range   = argType->getInterval();

        // Same short-circuit order as before: validity, then sign, then the flag.
        bool mayBeNegative = range.isValid() && range.lo() < 0;
        if (mayBeNegative && gGlobal->gMathExceptions) {
            std::stringstream warning;
            warning << "WARNING : potential out of domain in sqrt(" << range << ")" << std::endl;
            gWarningMessages.push_back(warning.str());
        }

        return castInterval(floatCast(argType), gAlgebra.Sqrt(range));
    }

    // The output has the same order as the single input.
    virtual int inferSigOrder(const std::vector<int>& args) override { return args.front(); }

    // Constant-folds a numeric argument; otherwise rebuilds the symbolic node.
    // Throws faustexception when the constant argument is negative.
    virtual Tree computeSigOutput(const std::vector<Tree>& args) override
    {
        Tree operand = args[0];
        num  value;

        // Non-constant argument: nothing to simplify.
        if (!isNum(operand, value)) {
            return tree(symbol(), operand);
        }

        double x = double(value);
        if (x < 0) {
            std::stringstream error;
            error << "ERROR : out of domain in sqrt(" << ppsig(operand, MAX_ERROR_SIZE) << ")"
                  << std::endl;
            throw faustexception(error.str());
        }
        return tree(sqrt(x));
    }

    // FIR/instruction backend: a call to the type-suffixed sqrt function.
    virtual ValueInst* generateCode(CodeContainer* container, Values& args, ::Type result,
                                    ConstTypes types) override
    {
        faustassert(args.size() == arity());
        faustassert(types.size() == arity());

        std::string funName = subst("sqrt$0", isuffix());
        return generateFun(container, funName, args, result, types);
    }

    // Textual backend: "sqrt<suffix>(<arg>)".
    virtual std::string generateCode(Klass* klass, const std::vector<std::string>& args,
                                     ConstTypes types) override
    {
        faustassert(args.size() == arity());
        faustassert(types.size() == arity());

        const std::string& operand = args[0];
        return subst("sqrt$1($0)", operand, isuffix());
    }

    // LaTeX rendering: \sqrt{<arg>}.
    virtual std::string generateLateq(Lateq* lateq, const std::vector<std::string>& args,
                                      ConstTypes types) override
    {
        faustassert(args.size() == arity());
        faustassert(types.size() == arity());

        const std::string& operand = args[0];
        return subst("\\sqrt{$0}", operand);
    }

    // Derivative with respect to the argument:
    // d/dx sqrt(x) = d/dx x^(1/2) = 1/2 * x^(-1/2)
    virtual Tree diff(const std::vector<Tree>& args) override
    {
        Tree half     = sigReal(0.5);
        Tree exponent = sigReal(-0.5);
        return sigMul(half, sigPow(args[0], exponent));
    }

    // Direct numeric evaluation of the primitive.
    double compute(const std::vector<Node>& args) override
    {
        double x = args[0].getDouble();
        return sqrt(x);
    }
};
