persistent_tasmanian

Required: Tasmanian, packaging, scikit-build

Example usage: batched, async

Note that Tasmanian can be installed with pip, but currently this must be done either inside a virtual environment (venv) or as a `--user` install.

E.g.: `pip install scikit-build packaging Tasmanian --user`

A persistent generator using the uncertainty quantification capabilities in Tasmanian.

persistent_tasmanian.sparse_grid_batched(H, persis_info, gen_specs, libE_info)

Implements batched construction for a Tasmanian sparse grid, using the loop described in Tasmanian Example 09: sparse grid example

persistent_tasmanian.sparse_grid_async(H, persis_info, gen_specs, libE_info)

Implements asynchronous construction for a Tasmanian sparse grid, using the logic in the dynamic Tasmanian model construction function: sparse grid dynamic example

persistent_tasmanian.py
  1"""
  2A persistent generator using the uncertainty quantification capabilities in
  3`Tasmanian <https://github.com/ORNL/Tasmanian>`_.
  4"""
  5
  6import numpy as np
  7
  8from libensemble.alloc_funcs.start_only_persistent import only_persistent_gens as allocf
  9from libensemble.message_numbers import EVAL_GEN_TAG, FINISHED_PERSISTENT_GEN_TAG, PERSIS_STOP, STOP_TAG
 10from libensemble.tools import parse_args
 11from libensemble.tools.persistent_support import PersistentSupport
 12
 13__all__ = [
 14    "sparse_grid_batched",
 15    "sparse_grid_async",
 16]
 17
 18
 19def lex_le(x, y, tol=1e-12):
 20    """
 21    Returns True if x <= y lexicographically up to some tolerance.
 22    """
 23    cmp = np.fabs(x - y) > tol
 24    ind = np.argmax(cmp)
 25    if not cmp[ind]:
 26        return True
 27    return x[ind] <= y[ind]
 28
 29
 30def get_2D_insert_indices(x, y, x_ord=np.empty(0, dtype="int"), y_ord=np.empty(0, dtype="int"), tol=1e-12):
 31    """
 32    Finds the row indices in a 2D numpy array `x` for which the sorted values of `y` can be inserted
 33    into. If `x_ord` (resp. `y_ord`) is empty, then `x` (resp. `y`) must be lexicographically
 34    sorted. Otherwise, `x[x_ord]` (resp. `y[y_ord]`) must be lexicographically sorted. Complexity is
 35    O(x.shape[0] + y.shape[0]).
 36    """
 37    assert len(x.shape) == 2
 38    assert len(y.shape) == 2
 39    if x.size == 0:
 40        return np.zeros(y.shape[0], dtype="int")
 41    else:
 42        if x_ord.size == 0:
 43            x_ord = np.arange(x.shape[0], dtype="int")
 44        if y_ord.size == 0:
 45            y_ord = np.arange(y.shape[0], dtype="int")
 46        x_ptr = 0
 47        y_ptr = 0
 48        out_ord = np.empty(0, dtype="int")
 49        while y_ptr < y.shape[0]:
 50            # The case where y[k] <= max of x[k:end, :]
 51            xk = x[x_ord[x_ptr], :]
 52            yk = y[y_ord[y_ptr], :]
 53            if lex_le(yk, xk, tol=tol):
 54                out_ord = np.append(out_ord, x_ord[x_ptr])
 55                y_ptr += 1
 56            else:
 57                x_ptr += 1
 58                # The edge case where y[k] is the largest of all elements of x.
 59                if x_ptr >= x_ord.shape[0]:
 60                    for i in range(y_ptr, y_ord.shape[0], 1):
 61                        out_ord = np.append(out_ord, x_ord.shape[0])
 62                        y_ptr += 1
 63                    break
 64        return out_ord
 65
 66
 67def get_2D_duplicate_indices(x, y, x_ord=np.empty(0, dtype="int"), y_ord=np.empty(0, dtype="int"), tol=1e-12):
 68    """
 69    Finds the row indices of a 2D numpy array `x` which overlap with `y`. If `x_ord` (resp. `y_ord`)
 70    is empty, then `x` (resp. `y`) must be lexicographically sorted. Otherwise, `x[x_ord]` (resp.
 71    `y[y_ord]`) must be lexicographically sorted.Complexity is O(x.shape[0] + y.shape[0]).
 72    """
 73    assert len(x.shape) == 2
 74    assert len(y.shape) == 2
 75    if x.size == 0:
 76        return np.empty(0, dtype="int")
 77    else:
 78        if x_ord.size == 0:
 79            x_ord = np.arange(x.shape[0], dtype="int")
 80        if y_ord.size == 0:
 81            y_ord = np.arange(y.shape[0], dtype="int")
 82        x_ptr = 0
 83        y_ptr = 0
 84        out_ord = np.empty(0, dtype="int")
 85        while y_ptr < y.shape[0] and x_ptr < x.shape[0]:
 86            # The case where y[k] <= max of x[k:end, :]
 87            xk = x[x_ord[x_ptr], :]
 88            yk = y[y_ord[y_ptr], :]
 89            if all(np.fabs(yk - xk) <= tol):
 90                out_ord = np.append(out_ord, x_ord[x_ptr])
 91                x_ptr += 1
 92            elif lex_le(xk, yk, tol=tol):
 93                x_ptr += 1
 94            else:
 95                y_ptr += 1
 96        return out_ord
 97
 98
 99def get_state(queued_pts, queued_ids, id_offset, new_points=np.array([]), completed_points=np.array([]), tol=1e-12):
100    """
101    Creates the data to be sent and updates the state arrays and scalars if new information
102    (new_points or completed_points) arrives. Ensures that the output state arrays remain sorted if
103    the input state arrays are already sorted.
104    """
105    if new_points.size > 0:
106        new_points_ord = np.lexsort(np.rot90(new_points))
107        new_points_ids = id_offset + np.arange(new_points.shape[0])
108        id_offset += new_points.shape[0]
109        insert_idx = get_2D_insert_indices(queued_pts, new_points, y_ord=new_points_ord, tol=tol)
110        queued_pts = np.insert(queued_pts, insert_idx, new_points[new_points_ord], axis=0)
111        queued_ids = np.insert(queued_ids, insert_idx, new_points_ids[new_points_ord], axis=0)
112
113    if completed_points.size > 0:
114        completed_ord = np.lexsort(np.rot90(completed_points))
115        delete_ind = get_2D_duplicate_indices(queued_pts, completed_points, y_ord=completed_ord, tol=tol)
116        queued_pts = np.delete(queued_pts, delete_ind, axis=0)
117        queued_ids = np.delete(queued_ids, delete_ind, axis=0)
118
119    return queued_pts, queued_ids, id_offset
120
121
def get_H0(gen_specs, refined_pts, refined_ord, queued_pts, queued_ids, tol=1e-12):
    """
    For runs following the first one, build the history array H0 sent to the manager, using
    the ordering in ``refined_pts`` to assign priorities.

    ``queued_pts`` must be lexicographically sorted, and each of its rows must match
    (entrywise within ``tol``) some row of ``refined_pts``. ``refined_pts[refined_ord]`` must
    be lexicographically sorted. A queued point matching row ``i`` of ``refined_pts`` gets
    priority ``refined_pts.shape[0] - 1 - i``, so earlier refined points get higher priority.

    Returns a structured array of dtype ``gen_specs["out"]`` with the ``x``, ``sim_id`` and
    ``priority`` fields filled for every queued point.

    Raises ValueError if a queued point has no match in ``refined_pts``.
    """

    def approx_eq(a, b):
        # Every coordinate must agree within tol. (Comparing np.argmax of the differences,
        # i.e. an index, against tol was a bug.)
        return np.all(np.fabs(a - b) <= tol)

    num_ids = queued_ids.shape[0]
    num_refined = refined_pts.shape[0]
    H0 = np.zeros(num_ids, dtype=gen_specs["out"])
    refined_priority = np.flip(np.arange(num_refined, dtype="int"))
    rptr = 0
    for qptr in range(num_ids):
        # Both sequences are sorted, so the match for this queued point is at or after rptr.
        while rptr < num_refined and not approx_eq(refined_pts[refined_ord[rptr]], queued_pts[qptr]):
            rptr += 1
        if rptr >= num_refined:
            raise ValueError(f"Queued point {queued_pts[qptr]} was not found among the refined points.")
        H0["x"][qptr] = queued_pts[qptr]
        H0["sim_id"][qptr] = queued_ids[qptr]
        H0["priority"][qptr] = refined_priority[refined_ord[rptr]]
    return H0
142
143
# ========================
# Main generator functions
# ========================


def sparse_grid_batched(H, persis_info, gen_specs, libE_info):
    """
    Implements batched construction for a Tasmanian sparse grid,
    using the loop described in Tasmanian Example 09:
    `sparse grid example <https://github.com/ORNL/TASMANIAN/blob/master/InterfacePython/example_sparse_grids_09.py>`_

    Each iteration does the following:

    1. Sends all points the grid currently needs to the manager.
    2. Waits for their values and loads them into the grid.
    3. Optionally writes a checkpoint to ``gen_specs["user"]["tasmanian_checkpoint_file"]``.
    4. Applies the refinement selected by ``gen_specs["user"]["refinement"]``.

    The loop ends when the grid needs no more points or the manager sends a stop tag.
    """
    U = gen_specs["user"]
    ps = PersistentSupport(libE_info, EVAL_GEN_TAG)
    grid = U["tasmanian_init"]()  # initialize the grid
    allowed_refinements = [
        "setAnisotropicRefinement",
        "getAnisotropicRefinement",
        "setSurplusRefinement",
        "getSurplusRefinement",
        "none",
    ]
    assert (
        "refinement" in U and U["refinement"] in allowed_refinements
    ), f"Must provide a gen_specs['user']['refinement'] in: {allowed_refinements}"

    # Validate the refinement parameters before the first batch is sent, so a
    # misconfiguration is reported before any simulations are spent.
    anisotropic = U["refinement"] in ["setAnisotropicRefinement", "getAnisotropicRefinement"]
    surplus = U["refinement"] in ["setSurplusRefinement", "getSurplusRefinement"]
    if anisotropic:
        assert "sType" in U
        assert "iMinGrowth" in U
        assert "iOutput" in U
    elif surplus:
        assert "fTolerance" in U
        assert "iOutput" in U
        assert "sCriteria" in U

    while grid.getNumNeeded() > 0:
        aPoints = grid.getNeededPoints()

        H0 = np.zeros(len(aPoints), dtype=gen_specs["out"])
        H0["x"] = aPoints

        # Receive values from manager
        tag, Work, calc_in = ps.send_recv(H0)
        if tag in [STOP_TAG, PERSIS_STOP]:
            break
        aModelValues = calc_in["f"]

        # Update surrogate on grid. Tasmanian expects one row per needed point and one column
        # per output. (Flattening into a single column only worked for one output.) The field
        # view taken from calc_in may not be contiguous, so make a contiguous array.
        values = np.ascontiguousarray(aModelValues.reshape((aModelValues.shape[0], grid.getNumOutputs())))
        grid.loadNeededPoints(values)

        if "tasmanian_checkpoint_file" in U:
            grid.write(U["tasmanian_checkpoint_file"])

        # Set refinement, using U["refinement"] to pick the strategy; "none" requests none.
        if anisotropic:
            grid.setAnisotropicRefinement(U["sType"], U["iMinGrowth"], U["iOutput"])
        elif surplus:
            grid.setSurplusRefinement(U["fTolerance"], U["iOutput"], U["sCriteria"])

    return None, persis_info, FINISHED_PERSISTENT_GEN_TAG
204
205
def sparse_grid_async(H, persis_info, gen_specs, libE_info):
    """
    Implements asynchronous construction for a Tasmanian sparse grid,
    using the logic in the dynamic Tasmanian model construction function:
    `sparse grid dynamic example <https://github.com/ORNL/TASMANIAN/blob/master/Addons/tsgConstructSurrogate.hpp>`_

    The initial construction candidates are sent to the manager with decreasing priority.
    Whenever results come back, they are removed from the queue of outstanding points and
    loaded into the grid. New candidates are then computed in either of two cases:

    * fewer than 1000 points are loaded;
    * more than 20% of the loaded count has completed since new points were last sent.

    Candidates that are not already queued are sent with priorities following Tasmanian's
    candidate order (see ``get_H0``). The generator stops when the manager sends a stop tag,
    or when Tasmanian returns no more candidates.
    """
    U = gen_specs["user"]
    ps = PersistentSupport(libE_info, EVAL_GEN_TAG)
    grid = U["tasmanian_init"]()  # initialize the grid
    allowed_refinements = ["getCandidateConstructionPoints", "getCandidateConstructionPointsSurplus"]
    assert (
        "refinement" in U and U["refinement"] in allowed_refinements
    ), f"Must provide a gen_specs['user']['refinement'] in: {allowed_refinements}"
    tol = U.get("_match_tolerance", 1.0e-12)

    # Check that the parameters of the chosen refinement function are present.
    if U["refinement"] == "getCandidateConstructionPoints":
        assert "sType" in U
        assert "liAnisotropicWeightsOrOutput" in U
    if U["refinement"] == "getCandidateConstructionPointsSurplus":
        assert "fTolerance" in U
        assert "sRefinementType" in U

    def get_refined_points(g, U):
        # Candidate points for the next construction step, using the configured strategy.
        if U["refinement"] == "getCandidateConstructionPoints":
            return g.getCandidateConstructionPoints(U["sType"], U["liAnisotropicWeightsOrOutput"])
        assert U["refinement"] == "getCandidateConstructionPointsSurplus"
        return g.getCandidateConstructionPointsSurplus(U["fTolerance"], U["sRefinementType"])

    # Asynchronous helper and state variables.
    num_dims = grid.getNumDimensions()
    num_completed = 0  # points received since new points were last sent
    offset = 0  # next sim_id to assign
    queued_pts = np.empty((0, num_dims), dtype="float")  # sent but not yet returned (kept sorted)
    queued_ids = np.empty(0, dtype="int")

    # First run.
    grid.beginConstruction()
    init_pts = get_refined_points(grid, U)
    queued_pts, queued_ids, offset = get_state(queued_pts, queued_ids, offset, new_points=init_pts, tol=tol)
    H0 = np.zeros(init_pts.shape[0], dtype=gen_specs["out"])
    H0["x"] = init_pts
    H0["sim_id"] = np.arange(init_pts.shape[0], dtype="int")
    H0["priority"] = np.flip(H0["sim_id"])  # earlier candidates get higher priority
    tag, Work, calc_in = ps.send_recv(H0)

    # Subsequent runs.
    while tag not in [STOP_TAG, PERSIS_STOP]:
        # Parse the points returned by the allocator and drop them from the queue.
        num_completed += calc_in["x"].shape[0]
        queued_pts, queued_ids, offset = get_state(
            queued_pts, queued_ids, offset, completed_points=calc_in["x"], tol=tol
        )

        # Always load the returned values. calc_in is overwritten by the next receive, and the
        # points were just removed from the queue, so results not loaded now would be lost.
        # A copy is needed because the data in the calc_in arrays are not contiguous.
        grid.loadConstructedPoint(np.copy(calc_in["x"]), np.copy(calc_in["f"]))
        if "tasmanian_checkpoint_file" in U:
            grid.write(U["tasmanian_checkpoint_file"])

        # Compute the next batch of points (if they exist).
        new_pts = np.empty((0, num_dims), dtype="float")
        refined_pts = np.empty((0, num_dims), dtype="float")
        refined_ord = np.empty(0, dtype="int")
        if grid.getNumLoaded() < 1000 or num_completed > 0.2 * grid.getNumLoaded():
            refined_pts = get_refined_points(grid, U)
            # If the refined points are empty, then there is a stopping condition internal to the
            # Tasmanian sparse grid that is being triggered by the loaded points.
            if refined_pts.size == 0:
                break
            refined_ord = np.lexsort(np.rot90(refined_pts))
            # Only candidates that are not already queued need to be sent.
            delete_ind = get_2D_duplicate_indices(refined_pts, queued_pts, x_ord=refined_ord, tol=tol)
            new_pts = np.delete(refined_pts, delete_ind, axis=0)

        if new_pts.shape[0] > 0:
            # Update the state variables with the refined points and update the queue in the allocator.
            num_completed = 0
            queued_pts, queued_ids, offset = get_state(queued_pts, queued_ids, offset, new_points=new_pts, tol=tol)
            H0 = get_H0(gen_specs, refined_pts, refined_ord, queued_pts, queued_ids, tol=tol)
            tag, Work, calc_in = ps.send_recv(H0)
        else:
            tag, Work, calc_in = ps.recv()

    return None, persis_info, FINISHED_PERSISTENT_GEN_TAG
292
293
def get_sparse_grid_specs(user_specs, sim_f, num_dims, num_outputs=1, mode="batched"):
    """
    Build the simulator, generator and allocator specs plus the persis_info dictionary for the
    sparse grid generators in this script, so that they are compatible with each other. The
    four outputs are meant to be passed directly to the main libE() call.

    INPUTS:
        user_specs  (dict)   : user specs stored as gen_specs["user"]; must contain the key
                               "tasmanian_init", a 0-argument callable that returns an
                               initialized Tasmanian sparse grid object.

        sim_f       (func)   : simulator function; it receives the generator output "x" and
                               returns "f" (plus "x" in async mode).

        num_dims    (int)    : number of model inputs (length of each "x").

        num_outputs (int)    : number of model outputs (length of each "f").

        mode        (string) : "batched" (uses sparse_grid_batched) or "async" (uses
                               sparse_grid_async with an active-receive, async-return allocator).

    OUTPUTS:
        sim_specs   (dict) : simulation specs, one of the inputs of libE().

        gen_specs   (dict) : generator specs, one of the inputs of libE().

        alloc_specs (dict) : allocation specs, one of the inputs of libE().

        persis_info (dict) : one {"worker_num": i} entry for each i in 0..nworkers, where
                             nworkers comes from libensemble's parse_args(); also an input
                             of libE().
    """

    assert "tasmanian_init" in user_specs
    assert mode in ["batched", "async"]

    gen_out = [("x", float, (num_dims,)), ("sim_id", int), ("priority", int)]
    sim_specs = {"sim_f": sim_f, "in": ["x"]}
    gen_specs = {
        # The generator receives back all of its own output fields plus the simulator's "f".
        "persis_in": [name for name, *_ in gen_out] + ["f"],
        "out": gen_out,
        "user": user_specs,
    }
    alloc_specs = {"alloc_f": allocf, "user": {}}

    if mode == "batched":
        gen_specs["gen_f"] = sparse_grid_batched
        sim_specs["out"] = [("f", float, (num_outputs,))]
    elif mode == "async":
        gen_specs["gen_f"] = sparse_grid_async
        sim_specs["out"] = [("x", float, (num_dims,)), ("f", float, (num_outputs,))]
        alloc_specs["user"].update(active_recv_gen=True, async_return=True)

    nworkers, _, _, _ = parse_args()
    persis_info = {worker: {"worker_num": worker} for worker in range(nworkers + 1)}

    return sim_specs, gen_specs, alloc_specs, persis_info