Logo
Fully annotated reference manual - version 1.8.12
Loading...
Searching...
No Matches
computeenvironment.cpp
Go to the documentation of this file.
1/*
2 Copyright (C) 2023 Quaternion Risk Management Ltd
3 All rights reserved.
4
5 This file is part of ORE, a free-software/open-source library
6 for transparent pricing and risk analysis - http://opensourcerisk.org
7
8 ORE is free software: you can redistribute it and/or modify it
9 under the terms of the Modified BSD License. You should have received a
10 copy of the license along with this program.
11 The license is also available online at <http://opensourcerisk.org>
12
13 This program is distributed on the basis that it will form a useful
14 contribution to risk analytics and model standardisation, but WITHOUT
15 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
16 FITNESS FOR A PARTICULAR PURPOSE. See the license for more details.
17*/
18
20
#include <algorithm>
#include <iterator>
#include <utility>

21#include <boost/algorithm/string/join.hpp>
22
23#include <ql/errors.hpp>
24
25namespace QuantExt {
26
28
30
// Release all compute frameworks: every pointer held in frameworks_ is deleted (the
// environment treats them as owning pointers) and the container is then emptied.
// NOTE(review): the function header (listing line 31) is missing from this copy of the
// listing; these comments describe only the visible statements.
32 for (auto& f : frameworks_)
33 delete f;
34 frameworks_.clear();
35}
36
// Drop the currently selected context, then append one newly created framework to
// frameworks_ for each creator registered in ComputeFrameworkRegistry.
// NOTE(review): the function header and listing lines 39-40 are missing from this copy;
// whatever they do (e.g. clearing state before re-populating) is not visible here.
38 currentContext_ = nullptr;
41 for (auto& c : ComputeFrameworkRegistry::instance().getAll())
42 frameworks_.push_back(c());
43}
44
45std::set<std::string> ComputeEnvironment::getAvailableDevices() const {
46 std::set<std::string> result;
47 for (auto const& f : frameworks_) {
48 auto tmp = f->getAvailableDevices();
49 result.insert(tmp.begin(), tmp.end());
50 }
51 return result;
52}
53
54bool ComputeEnvironment::hasContext() const { return currentContext_ != nullptr; }
55
//! Make the compute context of device \p deviceName the current context.
/*! Returns immediately if \p deviceName is already the selected device. Otherwise the
    frameworks are searched in order, and the first one whose device set contains
    \p deviceName supplies the context via getContext(); the device name is then recorded.
    Throws (QL_FAIL) with the list of all available devices if no framework offers it.
    NOTE(review): the listing numbering jumps from 61 to 63 below, so one statement of this
    function is missing from this copy — check the original source before relying on it. */
56void ComputeEnvironment::selectContext(const std::string& deviceName) {
57 if (currentContextDeviceName_ == deviceName)
58 return;
59 for (auto& f : frameworks_) {
// C++17 if-with-initializer: tmp is this framework's device set, scoped to the if.
60 if (auto tmp = f->getAvailableDevices(); tmp.find(deviceName) != tmp.end()) {
61 currentContext_ = f->getContext(deviceName);
63 currentContextDeviceName_ = deviceName;
64 return;
65 }
66 }
67 QL_FAIL("ComputeEnvironment::selectContext(): device '"
68 << deviceName << "' not found. Available devices: " << boost::join(getAvailableDevices(), ","));
69}
70
72
73void ComputeContext::finalizeCalculation(std::vector<std::vector<double>>& output) {
74 std::vector<double*> outputPtr(output.size());
75 std::transform(output.begin(), output.end(), outputPtr.begin(),
76 [](std::vector<double>& v) -> double* { return &v[0]; });
77 finalizeCalculation(outputPtr);
78}
79
80void ComputeFrameworkRegistry::add(const std::string& name, std::function<ComputeFramework*(void)> creator,
81 const bool allowOverwrite) {
82 boost::unique_lock<boost::shared_mutex> lock(mutex_);
83 QL_REQUIRE(allowOverwrite || std::find(names_.begin(), names_.end(), name) == names_.end(),
84 "FrameworkRegistry::add(): creator for '"
85 << name << "' already exists and allowOverwrite is false, can't add it.");
86 names_.push_back(name);
87 creators_.push_back(creator);
88}
89
90const std::vector<std::function<ComputeFramework*(void)>>& ComputeFrameworkRegistry::getAll() const {
91 boost::shared_lock<boost::shared_mutex> lock(mutex_);
92 return creators_;
93}
94
95}; // namespace QuantExt
virtual void finalizeCalculation(std::vector< double * > &output)=0
virtual void init()=0
std::set< std::string > getAvailableDevices() const
std::vector< ComputeFramework * > frameworks_
void selectContext(const std::string &deviceName)
const std::vector< std::function< ComputeFramework *(void)> > & getAll() const
void add(const std::string &name, std::function< ComputeFramework *(void)> creator, const bool allowOverwrite=false)
std::vector< std::function< ComputeFramework *(void)> > creators_
Interface to compute environments