Logo
Fully annotated reference manual - version 1.8.12
Loading...
Searching...
No Matches
forwardderivatives.hpp
Go to the documentation of this file.
1/*
2 Copyright (C) 2021 Quaternion Risk Management Ltd
3 All rights reserved.
4
5 This file is part of ORE, a free-software/open-source library
6 for transparent pricing and risk analysis - http://opensourcerisk.org
7
8 ORE is free software: you can redistribute it and/or modify it
9 under the terms of the Modified BSD License. You should have received a
10 copy of the license along with this program.
11 The license is also available online at <http://opensourcerisk.org>
12
13 This program is distributed on the basis that it will form a useful
14 contribution to risk analytics and model standardisation, but WITHOUT
15 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
16 FITNESS FOR A PARTICULAR PURPOSE. See the license for more details.
17*/
18
19/*! \file qle/ad/forwardderivatives.hpp
20 \brief forward derivatives computation
21*/
22
23#pragma once
24
26
27#include <ql/shared_ptr.hpp>
28
29namespace QuantExt {
30
31/* Note: This formulation assumes a separate forward run to calculate the values. We could combine the calculation of
32 the values and derivatives and apply the deleter to improve memory consumption */
33
34template <class T>
35void forwardDerivatives(const ComputationGraph& g, const std::vector<T>& values, std::vector<T>& derivatives,
36 const std::vector<std::function<std::vector<T>(const std::vector<const T*>&, const T*)>>& grad,
37 std::function<void(T&)> deleter = {}, const std::vector<bool>& keepNodes = {},
38 const std::size_t conditionalExpectationOpId = 0,
39 const std::function<T(const std::vector<const T*>&)>& conditionalExpectation = {}) {
40
41 if (g.size() == 0)
42 return;
43
44 // loop over the nodes in the graph in forward order
45
46 for (std::size_t node = 0; node < g.size(); ++node) {
47 if (!g.predecessors(node).empty()) {
48
49 // propagate the derivatives from predecessors of a node to the node
50
51 std::vector<const T*> args(g.predecessors(node).size());
52 for (std::size_t arg = 0; arg < g.predecessors(node).size(); ++arg) {
53 args[arg] = &values[g.predecessors(node)[arg]];
54 }
55
56 if (g.opId(node) == conditionalExpectationOpId && conditionalExpectation) {
57
58 args[0] = &derivatives[g.predecessors(node)[0]];
59 derivatives[node] = conditionalExpectation(args);
60
61 } else {
62
63 auto gr = grad[g.opId(node)](args, &values[node]);
64
65 for (std::size_t p = 0; p < g.predecessors(node).size(); ++p) {
66 derivatives[node] += derivatives[g.predecessors(node)[p]] * gr[p];
67 }
68 }
69
70 // the check if we can delete the predecessors
71
72 if (deleter) {
73 for (std::size_t arg = 0; arg < g.predecessors(node).size(); ++arg) {
74 std::size_t p = g.predecessors(node)[arg];
75
76 // is the node no longer needed for other target nodes?
77
78 if (g.maxNodeRequiringArg(p) > node)
79 continue;
80
81 // is the node marked as to be kept ?
82
83 if (!keepNodes.empty() && keepNodes[p])
84 continue;
85
86 // apply the deleter
87
88 deleter(derivatives[p]);
89
90 } // for arg over g.predecessors
91 } // if deleter
92 } // if !g.predecessors empty
93 } // for node
94}
95
96} // namespace QuantExt
const std::vector< std::size_t > & predecessors(const std::size_t node) const
std::size_t maxNodeRequiringArg(const std::size_t node) const
std::size_t opId(const std::size_t node) const
computation graph
void forwardDerivatives(const ComputationGraph &g, const std::vector< T > &values, std::vector< T > &derivatives, const std::vector< std::function< std::vector< T >(const std::vector< const T * > &, const T *)> > &grad, std::function< void(T &)> deleter={}, const std::vector< bool > &keepNodes={}, const std::size_t conditionalExpectationOpId=0, const std::function< T(const std::vector< const T * > &)> &conditionalExpectation={})
RandomVariable conditionalExpectation(const std::vector< const RandomVariable * > &regressor, const std::vector< std::function< RandomVariable(const std::vector< const RandomVariable * > &)> > &basisFn, const Array &coefficients)