Logo
Fully annotated reference manual - version 1.8.12
Loading...
Searching...
No Matches
forwardevaluation.hpp
Go to the documentation of this file.
1/*
2 Copyright (C) 2021 Quaternion Risk Management Ltd
3 All rights reserved.
4
5 This file is part of ORE, a free-software/open-source library
6 for transparent pricing and risk analysis - http://opensourcerisk.org
7
8 ORE is free software: you can redistribute it and/or modify it
9 under the terms of the Modified BSD License. You should have received a
10 copy of the license along with this program.
11 The license is also available online at <http://opensourcerisk.org>
12
13 This program is distributed on the basis that it will form a useful
14 contribution to risk analytics and model standardisation, but WITHOUT
15 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
16 FITNESS FOR A PARTICULAR PURPOSE. See the license for more details.
17*/
18
/*! \file qle/ad/forwardevaluation.hpp
    \brief forward evaluation of a computation graph, with optional early release of intermediate values
*/
22
23#pragma once
24
26
27#include <ql/shared_ptr.hpp>
28
29namespace QuantExt {
30
31template <class T>
32void forwardEvaluation(const ComputationGraph& g, std::vector<T>& values,
33 const std::vector<std::function<T(const std::vector<const T*>&)>>& ops,
34 std::function<void(T&)> deleter = {}, bool keepValuesForDerivatives = true,
35 const std::vector<std::function<std::pair<std::vector<bool>, bool>(const std::size_t)>>&
36 opRequiresNodesForDerivatives = {},
37 const std::vector<bool>& keepNodes = {}, const std::size_t startNode = 0,
38 const std::size_t endNode = ComputationGraph::nan, const bool redBlockReconstruction = false) {
39
40 std::vector<bool> keepNodesDerivatives;
41 if (deleter && keepValuesForDerivatives)
42 keepNodesDerivatives = std::vector<bool>(g.size(), false);
43
44 // loop over the nodes in the graph in ascending order
45
46 for (std::size_t node = startNode; node < (endNode == ComputationGraph::nan ? g.size() : endNode); ++node) {
47
48 // if a node is computed by an op applied to predecessors ...
49
50 if (!g.predecessors(node).empty()) {
51
52 // evaluate the node
53
54 std::vector<const T*> args(g.predecessors(node).size());
55 for (std::size_t arg = 0; arg < g.predecessors(node).size(); ++arg) {
56 args[arg] = &values[g.predecessors(node)[arg]];
57 }
58 values[node] = ops[g.opId(node)](args);
59
60 QL_REQUIRE(values[node].initialised(), "forwardEvaluation(): value at active node "
61 << node << " is not initialized, opId = " << g.opId(node));
62
63 // then check if we can delete the predecessors
64
65 if (deleter) {
66 for (std::size_t arg = 0; arg < g.predecessors(node).size(); ++arg) {
67 std::size_t p = g.predecessors(node)[arg];
68
69 if (!keepNodesDerivatives.empty()) {
70
71 // is the node required to compute derivatives, then add it to the keep nodes vector
72
73 if (opRequiresNodesForDerivatives[g.opId(p)](args.size()).second ||
74 opRequiresNodesForDerivatives[g.opId(node)](args.size()).first[arg])
75 keepNodesDerivatives[p] = true;
76 }
77
78 // is the node no longer needed for the forward evaluation?
79
80 if (g.maxNodeRequiringArg(p) > node)
81 continue;
82
83 // is the node marked as to be kept ?
84
85 if ((!keepNodes.empty() && keepNodes[p]) ||
86 (!keepNodesDerivatives.empty() && keepNodesDerivatives[p] &&
87 (g.redBlockId(p) == 0 || redBlockReconstruction)))
88 continue;
89
90 // apply the deleter
91
92 deleter(values[p]);
93
94 } // for arg over g.predecessors
95 } // if deleter
96 } // if !g.predecessors empty
97 } // for node
98}
99
100} // namespace QuantExt
const std::vector< std::size_t > & predecessors(const std::size_t node) const
std::size_t maxNodeRequiringArg(const std::size_t node) const
std::size_t redBlockId(const std::size_t node) const
std::size_t opId(const std::size_t node) const
computation graph
void forwardEvaluation(const ComputationGraph &g, std::vector< T > &values, const std::vector< std::function< T(const std::vector< const T * > &)> > &ops, std::function< void(T &)> deleter={}, bool keepValuesForDerivatives=true, const std::vector< std::function< std::pair< std::vector< bool >, bool >(const std::size_t)> > &opRequiresNodesForDerivatives={}, const std::vector< bool > &keepNodes={}, const std::size_t startNode=0, const std::size_t endNode=ComputationGraph::nan, const bool redBlockReconstruction=false)