The Sparta Modeling Framework
Loading...
Searching...
No Matches
WeightedContextCounter.hpp
1
2#pragma once
3
4#include "sparta/statistics/ContextCounter.hpp"
7#include "sparta/statistics/ReadOnlyCounter.hpp"
9
10namespace sparta {
11
21template <class CounterT>
23{
24public:
25 template <class... CounterTArgs>
27 const std::string & name,
28 const std::string & desc,
29 const size_t num_contexts,
30 CounterTArgs && ...args) :
32 name,
33 desc,
34 static_cast<uint32_t>(num_contexts),
35 "Testing",
36 std::forward<CounterTArgs>(args)...),
37 weights_(this->numContexts(), 1)
38 {
39 REGISTER_CONTEXT_COUNTER_AGGREGATE_FCN(
40 WeightedContextCounter<CounterT>, weightedAvg_, calculated_average_);
41
42 REGISTER_CONTEXT_COUNTER_AGGREGATE_FCN(
43 WeightedContextCounter<CounterT>, max_, maximum_);
44 }
45
46 const CounterT & context(const uint32_t idx) const {
48 }
49
50 CounterT & context(const uint32_t idx) {
52 }
53
54 void assignContextWeights(const std::vector<double> & weights) {
55 if (!weights.empty()) {
56 if (weights.size() > 1 && weights.size() != this->numContexts()) {
57 throw SpartaException("Invalid weights passed to WeightedContextCounter. The ")
58 << "weights vector passed in had " << weights.size() << " values in it, "
59 << "but this context counter has " << this->numContexts() << " contexts in it.";
60 }
61 weights_ = weights;
62 }
63 if (weights_.size() == 1) {
64 const double weight = weights_[0];
65 weights_ = std::vector<double>(this->numContexts(), weight);
66 }
67 }
68
69 double calculateWeightedAverage() {
70 weightedAvg_();
71 return calculated_average_;
72 }
73
74private:
75 void weightedAvg_() {
76 calculated_average_ = 0;
77 auto weight_iter = weights_.begin();
78 for (const auto & internal_ctr : *this) {
79 calculated_average_ += internal_ctr.get() * *weight_iter++;
80 }
81 calculated_average_ /= this->numContexts();
82 }
83
84 void max_() {
85 sparta_assert(this->numContexts() > 0);
86 maximum_ = std::numeric_limits<double>::min();
87 for (const auto & internal_ctr : *this) {
88 maximum_ = std::max(maximum_, static_cast<double>(internal_ctr.get()));
89 }
90 }
91
// Per-context multipliers used by weightedAvg_(). Sized to numContexts() in
// the constructor and kept at that size by assignContextWeights().
std::vector<double> weights_;
// Result of the most recent weightedAvg_() call; registered as an aggregate
// target. Zero-initialized so it is never read uninitialized.
double calculated_average_ = 0;
// Result of the most recent max_() call; registered as an aggregate target.
// Zero-initialized so it is never read uninitialized.
double maximum_ = 0;
95};
96
97} // namespace sparta
98
#define sparta_assert(...)
Simple variadic assertion that will throw a sparta_exception if the condition fails.
File that contains the macro used to generate the class callbacks.
Contains a statistic definition (some useful information which can be computed)
File that defines the StatisticSet class.
A container type that allows a modeler to build, store, and charge counts to a specific context.
uint32_t numContexts() const
Return the number of contexts in this ContextCounter.
const counter_type & context(const uint32_t idx) const
Return the internal counter at the given context.
Used to construct and throw a standard C++ exception. Inherits from std::exception.
Set of StatisticDef and CounterBase-derived objects for visibility through a sparta Tree.
This is an example context counter subclass used to show how users may supply their own "aggregated v...
Macros for handling exponential backoff.