26class Expression::Term :
public SingleThreadedReferenceCountedObject
32 virtual Type getType() const noexcept = 0;
33 virtual Term* clone() const = 0;
34 virtual ReferenceCountedObjectPtr<Term> resolve (const Scope&,
int recursionDepth) = 0;
35 virtual String toString() const = 0;
36 virtual
double toDouble()
const {
return 0; }
37 virtual int getInputIndexFor (
const Term*)
const {
return -1; }
38 virtual int getOperatorPrecedence()
const {
return 0; }
39 virtual int getNumInputs()
const {
return 0; }
40 virtual Term* getInput (
int)
const {
return nullptr; }
41 virtual ReferenceCountedObjectPtr<Term> negated();
43 virtual ReferenceCountedObjectPtr<Term> createTermToEvaluateInput (
const Scope&,
const Term* ,
44 double , Term* )
const
47 return ReferenceCountedObjectPtr<Term>();
50 virtual String getName()
const
56 virtual void renameSymbol (
const Symbol& oldSymbol,
const String& newName,
const Scope& scope,
int recursionDepth)
58 for (
int i = getNumInputs(); --i >= 0;)
59 getInput (i)->renameSymbol (oldSymbol, newName, scope, recursionDepth);
65 virtual ~SymbolVisitor() {}
66 virtual void useSymbol (
const Symbol&) = 0;
69 virtual void visitAllSymbols (SymbolVisitor& visitor,
const Scope& scope,
int recursionDepth)
71 for (
int i = getNumInputs(); --i >= 0;)
72 getInput(i)->visitAllSymbols (visitor, scope, recursionDepth);
76 JUCE_DECLARE_NON_COPYABLE (Term)
81struct Expression::Helpers
83 using TermPtr = ReferenceCountedObjectPtr<Term>;
85 static void checkRecursionDepth (
int depth)
88 throw EvaluationError (
"Recursive symbol references");
91 friend class Expression::Term;
100 DBG (
"Expression::EvaluationError: " + description);
107 class Constant :
public Term
114 Term* clone()
const {
return new Constant (value, isResolutionTarget); }
115 TermPtr resolve (
const Scope&,
int) {
return *
this; }
116 double toDouble()
const {
return value; }
117 TermPtr negated() {
return *
new Constant (-value, isResolutionTarget); }
119 String toString()
const
122 if (isResolutionTarget)
129 bool isResolutionTarget;
133 class BinaryTerm :
public Term
136 BinaryTerm (TermPtr l, TermPtr r) : left (std::move (l)), right (std::move (r))
138 jassert (left !=
nullptr && right !=
nullptr);
141 int getInputIndexFor (
const Term* possibleInput)
const
143 return possibleInput == left ? 0 : (possibleInput == right ? 1 : -1);
146 Type getType() const noexcept {
return operatorType; }
147 int getNumInputs()
const {
return 2; }
148 Term* getInput (
int index)
const {
return index == 0 ? left.get() : (index == 1 ? right.get() :
nullptr); }
150 virtual double performFunction (
double left,
double right)
const = 0;
151 virtual void writeOperator (String& dest)
const = 0;
153 TermPtr resolve (
const Scope& scope,
int recursionDepth)
155 return *
new Constant (performFunction (left ->resolve (scope, recursionDepth)->toDouble(),
156 right->resolve (scope, recursionDepth)->toDouble()),
false);
159 String toString()
const
162 auto ourPrecendence = getOperatorPrecedence();
164 if (left->getOperatorPrecedence() > ourPrecendence)
165 s <<
'(' << left->toString() <<
')';
167 s = left->toString();
171 if (right->getOperatorPrecedence() >= ourPrecendence)
172 s <<
'(' << right->toString() <<
')';
174 s << right->toString();
180 const TermPtr left, right;
182 TermPtr createDestinationTerm (
const Scope& scope,
const Term* input,
double overallTarget, Term* topLevelTerm)
const
184 jassert (input == left || input == right);
185 if (input != left && input != right)
188 if (
auto dest = findDestinationFor (topLevelTerm,
this))
189 return dest->createTermToEvaluateInput (scope,
this, overallTarget, topLevelTerm);
191 return *
new Constant (overallTarget,
false);
196 class SymbolTerm :
public Term
199 explicit SymbolTerm (
const String& sym) : symbol (sym) {}
201 TermPtr resolve (
const Scope& scope,
int recursionDepth)
203 checkRecursionDepth (recursionDepth);
204 return scope.getSymbolValue (symbol).term->resolve (scope, recursionDepth + 1);
207 Type getType() const noexcept {
return symbolType; }
208 Term* clone()
const {
return new SymbolTerm (symbol); }
209 String toString()
const {
return symbol; }
210 String getName()
const {
return symbol; }
212 void visitAllSymbols (SymbolVisitor& visitor,
const Scope& scope,
int recursionDepth)
214 checkRecursionDepth (recursionDepth);
215 visitor.useSymbol (Symbol (scope.getScopeUID(), symbol));
216 scope.getSymbolValue (symbol).term->visitAllSymbols (visitor, scope, recursionDepth + 1);
219 void renameSymbol (
const Symbol& oldSymbol,
const String& newName,
const Scope& scope,
int )
221 if (oldSymbol.symbolName == symbol && scope.getScopeUID() == oldSymbol.scopeUID)
229 class Function :
public Term
232 explicit Function (
const String& name) : functionName (name) {}
234 Function (
const String& name,
const Array<Expression>& params)
235 : functionName (name), parameters (params)
238 Type getType() const noexcept {
return functionType; }
239 Term* clone()
const {
return new Function (functionName, parameters); }
240 int getNumInputs()
const {
return parameters.size(); }
241 Term* getInput (
int i)
const {
return parameters.getReference(i).term.get(); }
242 String getName()
const {
return functionName; }
244 TermPtr resolve (
const Scope& scope,
int recursionDepth)
246 checkRecursionDepth (recursionDepth);
248 auto numParams = parameters.size();
252 HeapBlock<double> params (numParams);
254 for (
int i = 0; i < numParams; ++i)
255 params[i] = parameters.getReference(i).term->resolve (scope, recursionDepth + 1)->toDouble();
257 result = scope.evaluateFunction (functionName, params, numParams);
261 result = scope.evaluateFunction (functionName,
nullptr, 0);
264 return *
new Constant (result,
false);
267 int getInputIndexFor (
const Term* possibleInput)
const
269 for (
int i = 0; i < parameters.size(); ++i)
270 if (parameters.getReference(i).term == possibleInput)
276 String toString()
const
278 if (parameters.size() == 0)
279 return functionName +
"()";
281 String s (functionName +
" (");
283 for (
int i = 0; i < parameters.size(); ++i)
285 s << parameters.getReference(i).term->toString();
287 if (i < parameters.size() - 1)
295 const String functionName;
296 Array<Expression> parameters;
300 class DotOperator :
public BinaryTerm
303 DotOperator (SymbolTerm* l, TermPtr r) : BinaryTerm (TermPtr (l), r) {}
305 TermPtr resolve (
const Scope& scope,
int recursionDepth)
307 checkRecursionDepth (recursionDepth);
309 EvaluationVisitor visitor (right, recursionDepth + 1);
310 scope.visitRelativeScope (getSymbol()->
symbol, visitor);
311 return visitor.output;
314 Term* clone()
const {
return new DotOperator (getSymbol(), *right); }
315 String getName()
const {
return "."; }
316 int getOperatorPrecedence()
const {
return 1; }
317 void writeOperator (String& dest)
const { dest <<
'.'; }
318 double performFunction (
double,
double)
const {
return 0.0; }
320 void visitAllSymbols (SymbolVisitor& visitor,
const Scope& scope,
int recursionDepth)
322 checkRecursionDepth (recursionDepth);
323 visitor.useSymbol (Symbol (scope.getScopeUID(), getSymbol()->
symbol));
325 SymbolVisitingVisitor v (right, visitor, recursionDepth + 1);
329 scope.visitRelativeScope (getSymbol()->
symbol, v);
334 void renameSymbol (
const Symbol& oldSymbol,
const String& newName,
const Scope& scope,
int recursionDepth)
336 checkRecursionDepth (recursionDepth);
337 getSymbol()->renameSymbol (oldSymbol, newName, scope, recursionDepth);
339 SymbolRenamingVisitor visitor (right, oldSymbol, newName, recursionDepth + 1);
343 scope.visitRelativeScope (getSymbol()->
symbol, visitor);
350 class EvaluationVisitor :
public Scope::Visitor
353 EvaluationVisitor (
const TermPtr& t,
const int recursion)
354 : input (t), output (t), recursionCount (recursion) {}
356 void visit (
const Scope& scope) { output = input->resolve (scope, recursionCount); }
360 const int recursionCount;
363 JUCE_DECLARE_NON_COPYABLE (EvaluationVisitor)
366 class SymbolVisitingVisitor :
public Scope::Visitor
369 SymbolVisitingVisitor (
const TermPtr& t, SymbolVisitor& v,
const int recursion)
370 : input (t), visitor (v), recursionCount (recursion) {}
372 void visit (
const Scope& scope) { input->visitAllSymbols (visitor, scope, recursionCount); }
376 SymbolVisitor& visitor;
377 const int recursionCount;
379 JUCE_DECLARE_NON_COPYABLE (SymbolVisitingVisitor)
382 class SymbolRenamingVisitor :
public Scope::Visitor
385 SymbolRenamingVisitor (
const TermPtr& t,
const Expression::Symbol& symbol_,
const String& newName_,
const int recursionCount_)
386 : input (t), symbol (symbol_), newName (newName_), recursionCount (recursionCount_) {}
388 void visit (
const Scope& scope) { input->renameSymbol (symbol, newName, scope, recursionCount); }
392 const Symbol& symbol;
393 const String newName;
394 const int recursionCount;
396 JUCE_DECLARE_NON_COPYABLE (SymbolRenamingVisitor)
399 SymbolTerm* getSymbol()
const {
return static_cast<SymbolTerm*
> (left.get()); }
401 JUCE_DECLARE_NON_COPYABLE (DotOperator)
405 class Negate :
public Term
408 explicit Negate (
const TermPtr& t) : input (t)
410 jassert (t !=
nullptr);
413 Type getType() const noexcept {
return operatorType; }
414 int getInputIndexFor (
const Term* possibleInput)
const {
return possibleInput == input ? 0 : -1; }
415 int getNumInputs()
const {
return 1; }
416 Term* getInput (
int index)
const {
return index == 0 ? input.get() :
nullptr; }
417 Term* clone()
const {
return new Negate (*input->clone()); }
419 TermPtr resolve (
const Scope& scope,
int recursionDepth)
421 return *
new Constant (-input->resolve (scope, recursionDepth)->toDouble(),
false);
424 String getName()
const {
return "-"; }
425 TermPtr negated() {
return input; }
427 TermPtr createTermToEvaluateInput (
const Scope& scope,
const Term* t,
double overallTarget, Term* topLevelTerm)
const
430 jassert (t == input);
432 const Term*
const dest = findDestinationFor (topLevelTerm,
this);
434 return *
new Negate (dest ==
nullptr ? TermPtr (*
new Constant (overallTarget,
false))
435 : dest->createTermToEvaluateInput (scope, this, overallTarget, topLevelTerm));
438 String toString()
const
440 if (input->getOperatorPrecedence() > 0)
441 return "-(" + input->toString() +
")";
443 return "-" + input->toString();
451 class Add :
public BinaryTerm
454 Add (TermPtr l, TermPtr r) : BinaryTerm (l, r) {}
456 Term* clone()
const {
return new Add (*left->clone(), *right->clone()); }
457 double performFunction (
double lhs,
double rhs)
const {
return lhs + rhs; }
458 int getOperatorPrecedence()
const {
return 3; }
459 String getName()
const {
return "+"; }
460 void writeOperator (String& dest)
const { dest <<
" + "; }
462 TermPtr createTermToEvaluateInput (
const Scope& scope,
const Term* input,
double overallTarget, Term* topLevelTerm)
const
464 if (
auto newDest = createDestinationTerm (scope, input, overallTarget, topLevelTerm))
465 return *
new Subtract (newDest, *(input == left ? right : left)->clone());
471 JUCE_DECLARE_NON_COPYABLE (Add)
475 class Subtract :
public BinaryTerm
478 Subtract (TermPtr l, TermPtr r) : BinaryTerm (l, r) {}
480 Term* clone()
const {
return new Subtract (*left->clone(), *right->clone()); }
481 double performFunction (
double lhs,
double rhs)
const {
return lhs - rhs; }
482 int getOperatorPrecedence()
const {
return 3; }
483 String getName()
const {
return "-"; }
484 void writeOperator (String& dest)
const { dest <<
" - "; }
486 TermPtr createTermToEvaluateInput (
const Scope& scope,
const Term* input,
double overallTarget, Term* topLevelTerm)
const
488 if (
auto newDest = createDestinationTerm (scope, input, overallTarget, topLevelTerm))
491 return *
new Add (*newDest, *right->clone());
493 return *
new Subtract (*left->clone(), *newDest);
500 JUCE_DECLARE_NON_COPYABLE (Subtract)
504 class Multiply :
public BinaryTerm
507 Multiply (TermPtr l, TermPtr r) : BinaryTerm (l, r) {}
509 Term* clone()
const {
return new Multiply (*left->clone(), *right->clone()); }
510 double performFunction (
double lhs,
double rhs)
const {
return lhs * rhs; }
511 String getName()
const {
return "*"; }
512 void writeOperator (String& dest)
const { dest <<
" * "; }
513 int getOperatorPrecedence()
const {
return 2; }
515 TermPtr createTermToEvaluateInput (
const Scope& scope,
const Term* input,
double overallTarget, Term* topLevelTerm)
const
517 if (
auto newDest = createDestinationTerm (scope, input, overallTarget, topLevelTerm))
518 return *
new Divide (newDest, *(input == left ? right : left)->clone());
523 JUCE_DECLARE_NON_COPYABLE (Multiply)
527 class Divide :
public BinaryTerm
530 Divide (TermPtr l, TermPtr r) : BinaryTerm (l, r) {}
532 Term* clone()
const {
return new Divide (*left->clone(), *right->clone()); }
533 double performFunction (
double lhs,
double rhs)
const {
return lhs / rhs; }
534 String getName()
const {
return "/"; }
535 void writeOperator (String& dest)
const { dest <<
" / "; }
536 int getOperatorPrecedence()
const {
return 2; }
538 TermPtr createTermToEvaluateInput (
const Scope& scope,
const Term* input,
double overallTarget, Term* topLevelTerm)
const
540 auto newDest = createDestinationTerm (scope, input, overallTarget, topLevelTerm);
542 if (newDest ==
nullptr)
546 return *
new Multiply (*newDest, *right->clone());
548 return *
new Divide (*left->clone(), *newDest);
551 JUCE_DECLARE_NON_COPYABLE (Divide)
555 static Term* findDestinationFor (Term*
const topLevel,
const Term*
const inputTerm)
557 const int inputIndex = topLevel->getInputIndexFor (inputTerm);
561 for (
int i = topLevel->getNumInputs(); --i >= 0;)
563 Term*
const t = findDestinationFor (topLevel->getInput (i), inputTerm);
572 static Constant* findTermToAdjust (Term*
const term,
const bool mustBeFlagged)
574 jassert (term !=
nullptr);
576 if (term->getType() == constantType)
578 Constant*
const c =
static_cast<Constant*
> (term);
579 if (c->isResolutionTarget || ! mustBeFlagged)
583 if (term->getType() == functionType)
586 const int numIns = term->getNumInputs();
588 for (
int i = 0; i < numIns; ++i)
590 Term*
const input = term->getInput (i);
592 if (input->getType() == constantType)
594 Constant*
const c =
static_cast<Constant*
> (input);
596 if (c->isResolutionTarget || ! mustBeFlagged)
601 for (
int i = 0; i < numIns; ++i)
602 if (
auto c = findTermToAdjust (term->getInput (i), mustBeFlagged))
608 static bool containsAnySymbols (
const Term& t)
610 if (t.getType() == Expression::symbolType)
613 for (
int i = t.getNumInputs(); --i >= 0;)
614 if (containsAnySymbols (*t.getInput (i)))
621 class SymbolCheckVisitor :
public Term::SymbolVisitor
624 SymbolCheckVisitor (
const Symbol& s) : symbol (s) {}
625 void useSymbol (
const Symbol& s) { wasFound = wasFound || s == symbol; }
627 bool wasFound =
false;
630 const Symbol& symbol;
632 JUCE_DECLARE_NON_COPYABLE (SymbolCheckVisitor)
636 class SymbolListVisitor :
public Term::SymbolVisitor
639 SymbolListVisitor (Array<Symbol>& list_) : list (list_) {}
640 void useSymbol (
const Symbol& s) { list.addIfNotAlreadyThere (s); }
645 JUCE_DECLARE_NON_COPYABLE (SymbolListVisitor)
653 Parser (String::CharPointerType& stringToParse) : text (stringToParse)
657 TermPtr readUpToComma()
660 return *
new Constant (0.0,
false);
662 auto e = readExpression();
664 if (e ==
nullptr || ((! readOperator (
",")) && ! text.isEmpty()))
665 return parseError (
"Syntax error: \"" + String (text) +
"\"");
673 String::CharPointerType& text;
675 TermPtr parseError (
const String& message)
684 static inline bool isDecimalDigit (
const juce_wchar c)
noexcept
686 return c >=
'0' && c <=
'9';
689 bool readChar (
const juce_wchar required)
noexcept
691 if (*text == required)
700 bool readOperator (
const char* ops,
char*
const opType =
nullptr) noexcept
702 text = text.findEndOfWhitespace();
706 if (readChar ((juce_wchar) (uint8) *ops))
708 if (opType !=
nullptr)
720 bool readIdentifier (String& identifier)
noexcept
722 text = text.findEndOfWhitespace();
726 if (t.isLetter() || *t ==
'_')
731 while (t.isLetterOrDigit() || *t ==
'_')
740 identifier = String (text, (
size_t) numChars);
748 Term* readNumber() noexcept
750 text = text.findEndOfWhitespace();
752 bool isResolutionTarget = (*t ==
'@');
754 if (isResolutionTarget)
757 t = t.findEndOfWhitespace();
764 t = t.findEndOfWhitespace();
767 if (isDecimalDigit (*t) || (*t ==
'.' && isDecimalDigit (t[1])))
773 TermPtr readExpression()
775 auto lhs = readMultiplyOrDivideExpression();
778 while (lhs !=
nullptr && readOperator (
"+-", &opType))
780 auto rhs = readMultiplyOrDivideExpression();
783 return parseError (
"Expected expression after \"" +
String::charToString ((juce_wchar) (uint8) opType) +
"\"");
786 lhs = *
new Add (lhs, rhs);
788 lhs = *
new Subtract (lhs, rhs);
794 TermPtr readMultiplyOrDivideExpression()
796 auto lhs = readUnaryExpression();
799 while (lhs !=
nullptr && readOperator (
"*/", &opType))
801 TermPtr rhs (readUnaryExpression());
804 return parseError (
"Expected expression after \"" +
String::charToString ((juce_wchar) (uint8) opType) +
"\"");
807 lhs = *
new Multiply (lhs, rhs);
809 lhs = *
new Divide (lhs, rhs);
815 TermPtr readUnaryExpression()
818 if (readOperator (
"+-", &opType))
820 TermPtr e (readUnaryExpression());
823 return parseError (
"Expected expression after \"" +
String::charToString ((juce_wchar) (uint8) opType) +
"\"");
831 return readPrimaryExpression();
834 TermPtr readPrimaryExpression()
836 if (
auto e = readParenthesisedExpression())
839 if (
auto e = readNumber())
842 return readSymbolOrFunction();
845 TermPtr readSymbolOrFunction()
849 if (readIdentifier (identifier))
851 if (readOperator (
"("))
853 auto f =
new Function (identifier);
854 std::unique_ptr<Term> func (f);
856 auto param = readExpression();
858 if (param ==
nullptr)
860 if (readOperator (
")"))
861 return TermPtr (func.release());
863 return parseError (
"Expected parameters after \"" + identifier +
" (\"");
868 while (readOperator (
","))
870 param = readExpression();
872 if (param ==
nullptr)
873 return parseError (
"Expected expression after \",\"");
878 if (readOperator (
")"))
879 return TermPtr (func.release());
881 return parseError (
"Expected \")\"");
884 if (readOperator (
"."))
886 TermPtr rhs (readSymbolOrFunction());
889 return parseError (
"Expected symbol or function after \".\"");
891 if (identifier ==
"this")
894 return *
new DotOperator (
new SymbolTerm (identifier), rhs);
898 jassert (identifier.trim() == identifier);
899 return *
new SymbolTerm (identifier);
905 TermPtr readParenthesisedExpression()
907 if (! readOperator (
"("))
910 auto e = readExpression();
912 if (e ==
nullptr || ! readOperator (
")"))
918 JUCE_DECLARE_NON_COPYABLE (Parser)
934 jassert (term !=
nullptr);
954 : term (std::move (
other.term))
960 term = std::move (
other.term);
967 Helpers::Parser
parser (text);
968 term =
parser.readUpToComma();
969 parseError =
parser.error;
976 parseError =
parser.error;
995 return term->resolve (scope, 0)->toDouble();
1014 return Expression (
new Helpers::Function (functionName, parameters));
1019 std::unique_ptr<Term>
newTerm (term->clone());
1028 newTerm.reset (
new Helpers::Add (*
newTerm.release(), *
new Helpers::Constant (0,
false)));
1057 e.term->renameSymbol (
oldSymbol, newName, scope, 0);
1067 term->visitAllSymbols (visitor, scope, 0);
1072 return visitor.wasFound;
1079 Helpers::SymbolListVisitor visitor (results);
1080 term->visitAllSymbols (visitor, scope, 0);
1096 return *
new Helpers::Negate (*
this);
1100Expression::Symbol::Symbol (
const String& scope,
const String& symbol)
1101 : scopeUID (scope), symbolName (symbol)
1105bool Expression::Symbol::operator== (
const Symbol& other)
const noexcept
1107 return symbolName == other.symbolName && scopeUID == other.scopeUID;
1110bool Expression::Symbol::operator!= (
const Symbol& other)
const noexcept
1112 return ! operator== (other);
1116Expression::Scope::Scope() {}
1117Expression::Scope::~Scope() {}
1131 if (functionName ==
"min")
1133 double v = parameters[0];
1135 v = jmin (v, parameters[i]);
1140 if (functionName ==
"max")
1142 double v = parameters[0];
1144 v = jmax (v, parameters[i]);
1151 if (functionName ==
"sin")
return std::sin (parameters[0]);
1152 if (functionName ==
"cos")
return std::cos (parameters[0]);
1153 if (functionName ==
"tan")
return std::tan (parameters[0]);
1154 if (functionName ==
"abs")
return std::abs (parameters[0]);
static double readDoubleValue(CharPointerType &text) noexcept
virtual Expression getSymbolValue(const String &symbol) const
virtual void visitRelativeScope(const String &scopeName, Visitor &visitor) const
virtual String getScopeUID() const
virtual double evaluateFunction(const String &functionName, const double *parameters, int numParameters) const
Expression operator*(const Expression &) const
Expression adjustedToGiveNewResult(double targetValue, const Scope &scope) const
void findReferencedSymbols(Array< Symbol > &results, const Scope &scope) const
Expression operator+(const Expression &) const
static Expression function(const String &functionName, const Array< Expression > ¶meters)
String getSymbolOrFunction() const
bool usesAnySymbols() const
Expression withRenamedSymbol(const Symbol &oldSymbol, const String &newName, const Scope &scope) const
Expression operator/(const Expression &) const
Type getType() const noexcept
static Expression parse(String::CharPointerType &stringToParse, String &parseError)
Expression getInput(int index) const
bool referencesSymbol(const Symbol &symbol, const Scope &scope) const
static Expression symbol(const String &symbol)
Expression operator-() const
Expression & operator=(const Expression &)
String toLowerCase() const
static String charToString(juce_wchar character)
bool containsOnly(StringRef charactersItMightContain) const noexcept