Logo
Fully annotated reference manual - version 1.8.12
Loading...
Searching...
No Matches
backwardderivatives.hpp
Go to the documentation of this file.
1/*
2 Copyright (C) 2021 Quaternion Risk Management Ltd
3 All rights reserved.
4
5 This file is part of ORE, a free-software/open-source library
6 for transparent pricing and risk analysis - http://opensourcerisk.org
7
8 ORE is free software: you can redistribute it and/or modify it
9 under the terms of the Modified BSD License. You should have received a
10 copy of the license along with this program.
11 The license is also available online at <http://opensourcerisk.org>
12
13 This program is distributed on the basis that it will form a useful
14 contribution to risk analytics and model standardisation, but WITHOUT
15 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
16 FITNESS FOR A PARTICULAR PURPOSE. See the license for more details.
17*/
18
19/*! \file qle/ad/backwardderivatives.hpp
20 \brief backward derivatives computation
21*/
22
23#pragma once
24
26
27#include <ql/errors.hpp>
28
29#include <ql/shared_ptr.hpp>
30
31namespace QuantExt {
32
33template <class T>
34void backwardDerivatives(const ComputationGraph& g, std::vector<T>& values, std::vector<T>& derivatives,
35 const std::vector<std::function<std::vector<T>(const std::vector<const T*>&, const T*)>>& grad,
36 std::function<void(T&)> deleter = {}, const std::vector<bool>& keepNodes = {},
37 const std::vector<std::function<T(const std::vector<const T*>&)>>& fwdOps = {},
38 const std::vector<std::function<std::pair<std::vector<bool>, bool>(const std::size_t)>>&
39 fwdOpRequiresNodesForDerivatives = {},
40 const std::vector<bool>& fwdKeepNodes = {}, const std::size_t conditionalExpectationOpId = 0,
41 const std::function<T(const std::vector<const T*>&)>& conditionalExpectation = {}) {
42
43 if (g.size() == 0)
44 return;
45
46 std::size_t redBlockId = 0;
47
48 // loop over the nodes in the graph in reverse order
49
50 for (std::size_t node = g.size() - 1; node > 0; --node) {
51
52 if (g.redBlockId(node) != redBlockId) {
53
54 // delete the values in the previous red block
55
56 if (deleter && redBlockId > 0) {
57 auto range = g.redBlockRanges()[redBlockId - 1];
58 QL_REQUIRE(range.second != ComputationGraph::nan,
59 "backwardDerivatives(): red block " << redBlockId << " was not closed.");
60 for (std::size_t n = range.first; n < range.second; ++n) {
61 if (g.redBlockId(n) == redBlockId && !fwdKeepNodes[n])
62 deleter(values[n]);
63 }
64 }
65
66 // populate the values in the current red block
67
68 if (g.redBlockId(node) > 0) {
69 auto range = g.redBlockRanges()[g.redBlockId(node) - 1];
70 QL_REQUIRE(range.second != ComputationGraph::nan,
71 "backwardDerivatives(): red block " << g.redBlockId(node) << " was not closed.");
72 forwardEvaluation(g, values, fwdOps, deleter, true, fwdOpRequiresNodesForDerivatives, fwdKeepNodes,
73 range.first, range.second, true);
74 }
75
76 // update the red block id
77
78 redBlockId = g.redBlockId(node);
79 }
80
81 if (!g.predecessors(node).empty() && !isDeterministicAndZero(derivatives[node])) {
82
83 // propagate the derivative at a node to its predecessors
84
85 std::vector<const T*> args(g.predecessors(node).size());
86 for (std::size_t arg = 0; arg < g.predecessors(node).size(); ++arg) {
87 args[arg] = &values[g.predecessors(node)[arg]];
88 }
89
90 QL_REQUIRE(derivatives[node].initialised(),
91 "backwardDerivatives(): derivative at active node " << node << " is not initialized.");
92
93 if (g.opId(node) == conditionalExpectationOpId && conditionalExpectation) {
94
95 // expected stochastic automatic differentiaion, Fries, 2017
96 args[0] = &derivatives[node];
97 derivatives[g.predecessors(node)[0]] += conditionalExpectation(args);
98
99 } else {
100
101 auto gr = grad[g.opId(node)](args, &values[node]);
102
103 for (std::size_t p = 0; p < g.predecessors(node).size(); ++p) {
104 QL_REQUIRE(derivatives[g.predecessors(node)[p]].initialised(),
105 "backwardDerivatives: derivative at node "
106 << g.predecessors(node)[p] << " not initialized, which is an active predecessor of "
107 << node);
108 QL_REQUIRE(gr[p].initialised(),
109 "backwardDerivatives: gradient at node "
110 << node << " (opId " << g.opId(node) << ") not initialized at component " << p
111 << " but required to push to predecessor " << g.predecessors(node)[p]);
112 derivatives[g.predecessors(node)[p]] += derivatives[node] * gr[p];
113 }
114 }
115 }
116
117 // then check if we can delete the node
118
119 if (deleter) {
120
121 // is the node marked as to be kept?
122
123 if (!keepNodes.empty() && keepNodes[node])
124 continue;
125
126 // apply the deleter
127
128 deleter(derivatives[node]);
129 }
130
131 } // for node
132}
133
134} // namespace QuantExt
const std::vector< std::size_t > & predecessors(const std::size_t node) const
std::size_t redBlockId(const std::size_t node) const
const std::vector< std::pair< std::size_t, std::size_t > > & redBlockRanges() const
std::size_t opId(const std::size_t node) const
computation graph
bool isDeterministicAndZero(const ExternalRandomVariable &x)
RandomVariable conditionalExpectation(const std::vector< const RandomVariable * > &regressor, const std::vector< std::function< RandomVariable(const std::vector< const RandomVariable * > &)> > &basisFn, const Array &coefficients)
void backwardDerivatives(const ComputationGraph &g, std::vector< T > &values, std::vector< T > &derivatives, const std::vector< std::function< std::vector< T >(const std::vector< const T * > &, const T *)> > &grad, std::function< void(T &)> deleter={}, const std::vector< bool > &keepNodes={}, const std::vector< std::function< T(const std::vector< const T * > &)> > &fwdOps={}, const std::vector< std::function< std::pair< std::vector< bool >, bool >(const std::size_t)> > &fwdOpRequiresNodesForDerivatives={}, const std::vector< bool > &fwdKeepNodes={}, const std::size_t conditionalExpectationOpId=0, const std::function< T(const std::vector< const T * > &)> &conditionalExpectation={})
void forwardEvaluation(const ComputationGraph &g, std::vector< T > &values, const std::vector< std::function< T(const std::vector< const T * > &)> > &ops, std::function< void(T &)> deleter={}, bool keepValuesForDerivatives=true, const std::vector< std::function< std::pair< std::vector< bool >, bool >(const std::size_t)> > &opRequiresNodesForDerivatives={}, const std::vector< bool > &keepNodes={}, const std::size_t startNode=0, const std::size_t endNode=ComputationGraph::nan, const bool redBlockReconstruction=false)