Iterative Modules

Modules that are meant to be called as part of a loop (typically when optimizing a quantity) require special consideration. Consider a module that takes an input \(i_0\) and returns a result \(r_1\). If the input depends on the result, we can use \(r_1\) to compute a new input \(i_1\) and call the module again, generating \(r_2\). This is repeated until convergence. There are two main considerations for an iterative module: intermediate quantities and memoization.
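The loop just described can be sketched without PluginPlay. In this minimal C++ sketch a plain callable stands in for the module, and the hypothetical `iterate_to_convergence` helper plays the role of the caller. For simplicity it assumes the next input is simply the previous result (\(i_k = r_k\)) and declares convergence when two successive results differ by less than a tolerance.

```cpp
#include <cmath>
#include <cstddef>
#include <functional>

// Outcome of driving an iterative "module" to convergence.
struct IterationOutcome {
    double value;         // the last result computed, r_m
    std::size_t n_calls;  // number of module calls made, m
};

// Hypothetical driver (not part of PluginPlay). The callable `module` maps an
// input i_k to a result r_{k+1}; for simplicity we assume the next input is
// the previous result, i.e. i_k = r_k. Iteration stops when two successive
// results differ by less than `tol` or after `max_calls` calls.
IterationOutcome iterate_to_convergence(
  const std::function<double(double)>& module, double i0, double tol,
  std::size_t max_calls) {
    double result     = module(i0);  // r_1
    std::size_t calls = 1;
    while(calls < max_calls) {
        const double next = module(result);  // r_{k+1} computed from i_k = r_k
        ++calls;
        const bool converged = std::fabs(next - result) < tol;
        result               = next;
        if(converged) break;
    }
    return {result, calls};
}
```

For example, passing `[](double x) { return std::cos(x); }` with \(i_0 = 1\) iterates \(r_{k+1} = \cos(r_k)\), which settles near the fixed point of cosine (about 0.739085) after a few dozen calls.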

Many of the more sophisticated algorithms for iterative refinement use, in one way or another, information from previous iterations to accelerate convergence. For these algorithms it is essential that the module be able to access information from previous iterations, not just the current input.

Memoization leads to a similar difficulty. Say our iterative process ran \(m\) times, so we have computed \(r_1\), \(r_2\), and so on up to \(r_m\). When the calculation is restarted we want to jump straight back to \(r_m\), but the input to the module will be \(i_0\), and thus default memoization will provide us with \(r_1\). In theory, if we stored all of the intermediate results \(r_2, r_3, \ldots, r_m\), each subsequent call would also be memoized and we would quickly regain parity with the original computation. In practice we usually delete the intermediates to save memory, meaning we no longer have access to them. While it is possible to solve the memoization problem by recasting iteration as recursion, doing so is unnatural in many circumstances and we wish to avoid it.
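To make the memoization problem concrete, the following sketch (not PluginPlay code; `MemoizedModule`, `delete_intermediates`, and `run_iterations` are hypothetical names) memoizes on the exact input, as default memoization does, and uses `std::cos` as a stand-in for an expensive step. After \(m\) iterations, calling it again with \(i_0\) returns \(r_1\), and once the intermediates are deleted, resuming from \(i_0\) has to recompute \(r_2\) through \(r_m\).

```cpp
#include <cmath>
#include <cstddef>
#include <map>

// Hypothetical module with default-style memoization: results are keyed on
// the exact input the module was called with.
class MemoizedModule {
public:
    double run(double input) {
        auto it = memo_.find(input);
        if(it != memo_.end()) return it->second;  // memoized: no work done
        ++n_computed_;
        const double result = std::cos(input);  // stand-in for the expensive step
        memo_.emplace(input, result);
        return result;
    }

    // Mimic freeing memory: keep only the entry for `keep`, drop the rest.
    void delete_intermediates(double keep) {
        for(auto it = memo_.begin(); it != memo_.end();) {
            if(it->first == keep)
                ++it;
            else
                it = memo_.erase(it);
        }
    }

    std::size_t n_computed() const { return n_computed_; }

private:
    std::map<double, double> memo_;
    std::size_t n_computed_ = 0;
};

// Run m iterations starting from i0, feeding each result back in as the next
// input (i_k = r_k), and return r_m.
double run_iterations(MemoizedModule& mod, double i0, std::size_t m) {
    double input  = i0;
    double result = i0;
    for(std::size_t k = 0; k < m; ++k) {
        result = mod.run(input);
        input  = result;
    }
    return result;
}
```

Running 20 iterations from \(i_0 = 1\), deleting the intermediates, and running again costs 19 additional evaluations of the expensive step, even though the final answer is the same.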

PluginPlay’s solution to these problems is to allow each module to cache intermediate results. Here “intermediate” means that a result is either not what the module will return or not the converged result. Assuming the property type for the iterative module is defined such that it always takes the initial inputs in addition to the current inputs, the module can save its intermediate results in the cache using the initial inputs as the key. When the calculation is restarted, the current inputs equal the initial inputs, so the module can retrieve its newest cached result and resume from there.
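The same toy setting can illustrate this caching strategy. The hypothetical `ResumableModule` below receives both the initial and the current input, overwrites a single cache entry keyed on the initial input after every call, and, when the current input equals the initial input, resumes from the cached result. It is only an illustration of the idea and deliberately does not use PluginPlay's cache API.

```cpp
#include <cmath>
#include <cstddef>
#include <map>

// Hypothetical iterative module (not PluginPlay code). Like a module whose
// property type takes both the initial and the current inputs, run() receives
// i0 and the current input, and it caches its newest result keyed on i0.
class ResumableModule {
public:
    double run(double i0, double current) {
        double input = current;
        // current == i0 signals a (re)start: resume from the cached result
        // for this i0, if there is one.
        if(current == i0) {
            auto it = cache_.find(i0);
            if(it != cache_.end()) input = it->second;
        }
        ++n_computed_;
        const double result = std::cos(input);  // stand-in for the expensive step
        cache_[i0]          = result;  // overwrite: keep only the newest result
        return result;
    }

    std::size_t n_computed() const { return n_computed_; }
    std::size_t cache_size() const { return cache_.size(); }

private:
    std::map<double, double> cache_;
    std::size_t n_computed_ = 0;
};

// Call the module n times starting from i0, feeding each result back in as
// the next current input (i_k = r_k), and return the last result.
double drive(ResumableModule& mod, double i0, std::size_t n) {
    double current = i0;
    double result  = i0;
    for(std::size_t k = 0; k < n; ++k) {
        result  = mod.run(i0, current);
        current = result;
    }
    return result;
}
```

Running 20 iterations and then restarting from \(i_0\) for 5 more yields the same value as a fresh 25-iteration run, with no repeated work and only one cache entry stored.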

As an example, we show how to write a simple module that determines a root of the cosine function given an interval bracketing the root:

/*
 * Copyright 2022 NWChemEx-Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

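#include <cmath>
#include <pluginplay/pluginplay.hpp>
#include <stdexcept> // std::runtime_error
#include <tuple>     // std::tie
#include <utility>   // std::pair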

namespace pluginplay_examples {

using pair_type = std::pair<float, float>;

DECLARE_PROPERTY_TYPE(RootFinder);

PROPERTY_TYPE_INPUTS(RootFinder) {
    auto rv = pluginplay::declare_input()
                .add_field<pair_type>("Initial Interval")
                .template add_field<pair_type>("Current Interval");
    return rv;
}

PROPERTY_TYPE_RESULTS(RootFinder) {
    return pluginplay::declare_result().add_field<pair_type>("New interval");
}

DECLARE_MODULE(CosBisection);

MODULE_CTOR(CosBisection) {
    description("Computes a root of cosine by bisecting an interval.");
    satisfies_property_type<RootFinder>();
}

MODULE_RUN(CosBisection) {
    // Get the initial (0-th) and current (n-th) inputs
    const auto& [i0, in] = RootFinder::unwrap_inputs(inputs);

    // Get the cache
    auto& my_cache = get_cache();

    // If i0 == in we are (re)starting, so resume from the newest cached
    // interval if there is one; otherwise we use in
    auto [a, b] = in;
    if(i0 == in && my_cache.count(i0)) {
        std::tie(a, b) = my_cache.uncache<pair_type>(i0);
    }

    // Compute the value of cos at the interval's endpoints
    auto fa = std::cos(a);
    auto fb = std::cos(b);

    // Make sure the endpoints have opposite signs (i.e., they bracket a root)
    if((fa < 0.0 && fb < 0.0) || (fa > 0.0 && fb > 0.0))
        throw std::runtime_error("Must have opposite signs on endpoints");

    // Compute midpoint of interval and the value of cos at that point
    // Use float arithmetic so fc is evaluated at the same point that is stored
    auto c  = (a + b) / 2.0f;
    auto fc = std::cos(c);

    // The new interval is c and either a or b, whichever has the opposite sign
    // of c
    pair_type new_interval;
    if(fc < 0.0) { // fc is negative, keep old positive endpoint
        new_interval.first  = c;
        new_interval.second = (fa < 0.0) ? b : a;
    } else { // fc is positive (or zero), keep old negative endpoint
        new_interval.first  = (fa < 0.0) ? a : b;
        new_interval.second = c;
    }

    // Store our newest interval and return it
    my_cache.cache(i0, new_interval);

    auto rv = results();
    return RootFinder::wrap_results(rv, new_interval);
}

} // namespace pluginplay_examples
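The bisection logic of `CosBisection` can also be exercised outside of PluginPlay. The sketch below is a stand-in built on stated assumptions rather than the module itself: a `std::map` keyed on the initial interval replaces the module cache, `cos_bisection_step` mirrors the body of `MODULE_RUN(CosBisection)`, and the hypothetical `find_cos_root` driver plays the role of the caller that re-invokes the module until the interval is narrower than a tolerance.

```cpp
#include <cmath>
#include <map>
#include <stdexcept>
#include <utility>

using pair_type = std::pair<float, float>;

// Assumed stand-in for the module's cache: maps the initial interval to the
// newest interval computed from it.
using interval_cache = std::map<pair_type, pair_type>;

// One bisection step, mirroring the logic of CosBisection's run function.
pair_type cos_bisection_step(interval_cache& cache, const pair_type& i0,
                             const pair_type& in) {
    auto [a, b] = in;
    // On a (re)start, resume from the newest cached interval if there is one.
    if(i0 == in) {
        auto it = cache.find(i0);
        if(it != cache.end()) {
            a = it->second.first;
            b = it->second.second;
        }
    }
    const float fa = std::cos(a);
    const float fb = std::cos(b);
    if((fa < 0.0f && fb < 0.0f) || (fa > 0.0f && fb > 0.0f))
        throw std::runtime_error("Must have opposite signs on endpoints");

    const float c  = (a + b) / 2.0f;
    const float fc = std::cos(c);

    // Keep c and whichever old endpoint has the opposite sign of cos(c).
    pair_type new_interval;
    if(fc < 0.0f) {
        new_interval.first  = c;
        new_interval.second = (fa < 0.0f) ? b : a;
    } else {
        new_interval.first  = (fa < 0.0f) ? a : b;
        new_interval.second = c;
    }
    cache[i0] = new_interval;  // overwrite the newest interval for this i0
    return new_interval;
}

// Hypothetical driver: call the step in a loop until the interval is narrower
// than tol or max_iter steps have been taken.
pair_type find_cos_root(interval_cache& cache, const pair_type& i0, float tol,
                        int max_iter) {
    pair_type current = i0;
    for(int k = 0; k < max_iter; ++k) {
        current = cos_bisection_step(cache, i0, current);
        if(std::fabs(current.second - current.first) < tol) break;
    }
    return current;
}
```

Starting from the interval \((1, 2)\), which brackets \(\pi/2\), the midpoint of the returned interval lies within the tolerance of \(\pi/2 \approx 1.5708\). Because each step overwrites the cache entry for the initial interval, calling `cos_bisection_step` with the initial interval as both arguments continues from the newest interval rather than starting over.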