Source code for opt_einsum.paths

"""
Contains the path technology behind opt_einsum in addition to several path helpers
"""
import itertools

from . import helpers


def optimal(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorially with respect to the elements in the list
    ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.
        If at some depth every pair contraction exceeds ``memory_limit``,
        the cheapest path found so far is returned with one final
        contraction over all remaining tensors appended.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> optimal(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_inputs = len(input_sets)

    # Each candidate is (total_cost, positions_so_far, remaining_input_sets)
    full_results = [(0, [], input_sets)]

    for iteration in range(num_inputs - 1):
        iter_results = []

        # Compute all unique pairs among the tensors left at this depth
        comb_iter = tuple(itertools.combinations(range(num_inputs - iteration), 2))

        for cost, positions, remaining in full_results:
            for con in comb_iter:

                # Find the contraction
                contract = helpers.find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = contract

                # Sieve the results based on memory_limit
                new_size = helpers.compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + helpers.flop_count(idx_contract, idx_removed, len(con), idx_dict)
                iter_results.append((total_cost, positions + [con], new_input_sets))

        # If no pair fits within memory_limit, return the best path so far
        # followed by a single contraction of everything that remains.
        if not iter_results:
            best_path = min(full_results, key=lambda x: x[0])[1]
            # Build a new list rather than extending the one stored in the candidate
            return best_path + [tuple(range(num_inputs - iteration))]

        # Update list to iterate over
        full_results = iter_results

    # full_results is never empty here: it starts with one entry and is only
    # ever replaced by a non-empty iter_results.
    return min(full_results, key=lambda x: x[0])[1]
def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """Evaluate the contraction of the tensors at ``positions`` and build
    the entry used to rank it against other candidate contractions.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    result : list or None
        ``None`` if the resulting tensor is larger than ``memory_limit`` or
        if ``path_cost`` plus the flop cost of this contraction exceeds
        ``naive_cost``. Otherwise ``[sort, positions, new_input_sets]``
        where ``sort`` is the tuple ``(-removed_size, flop_cost)``,
        ``positions`` is the argument passed in, and ``new_input_sets`` is
        the list of indices that results if this contraction is performed.
    """
    idx_result, new_input_sets, idx_removed, idx_contract = helpers.find_contraction(
        positions, input_sets, output_set)

    # Reject contractions whose result would not fit in memory
    result_size = helpers.compute_size_by_dict(idx_result, idx_dict)
    if result_size > memory_limit:
        return None

    # Total size of the operands minus the size of the result.
    # NB: this used to be just the size of any removed indices i.e.:
    #     helpers.compute_size_by_dict(idx_removed, idx_dict)
    operand_size = 0
    for pos in positions:
        operand_size += helpers.compute_size_by_dict(input_sets[pos], idx_dict)
    size_reduction = operand_size - result_size

    flops = helpers.flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Reject contractions that push the running total past the naive cost
    if (path_cost + flops) > naive_cost:
        return None

    return [(-size_reduction, flops), positions, new_input_sets]


def _update_other_results(results, best):
    """Update the positions and provisional input_sets of ``results`` based on
    performing the contraction result ``best``. Remove any involving the
    tensors contracted.

    The ``input_sets`` list held by each surviving entry is modified in place.

    Parameters
    ----------
    results :
        List of contraction results produced by ``_parse_possible_contraction``.
    best :
        The best contraction of ``results`` i.e. the one that will be performed.

    Returns
    -------
    mod_results :
        The list of modified results, updated with outcome of ``best``
        contraction. Each entry is a ``(cost, positions, input_sets)`` tuple.
    """
    contracted = best[1]
    first, second = contracted

    def _shifted(index, a, b):
        # Lower ``index`` by one for each of ``a`` and ``b`` that precedes it
        return index - int(index > a) - int(index > b)

    updated = []
    for entry in results:
        cost, (x, y), con_sets = entry

        # Ignore results involving tensors just contracted
        if any(pos in contracted for pos in (x, y)):
            continue

        # Drop the two contracted tensors from this entry's provisional
        # input_sets; the later position goes first, exactly as before.
        del con_sets[_shifted(second, x, y)]
        del con_sets[_shifted(first, x, y)]

        # Place the tensor produced by ``best`` just before the last element
        con_sets.insert(-1, best[2][-1])

        # Re-index x and y to account for the two removed tensors
        new_positions = (_shifted(x, first, second), _shifted(y, first, second))
        updated.append((cost, new_positions, con_sets))

    return updated
def greedy(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-removed_size, cost)``. What this amounts to is prioritizing inner
    product operations, matrix multiplication, then Hadamard like operations,
    and finally outer operations. Outer products are limited by
    ``memory_limit`` and are ignored until no other operations are available.
    This algorithm scales quadratically with respect to the number of
    elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> greedy(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_tensors = len(input_sets)

    # Flop count of contracting every tensor in a single step; passed to
    # ``_parse_possible_contraction`` as the cost ceiling for candidates.
    _, _, idx_removed, idx_contract = helpers.find_contraction(
        range(num_tensors), input_sets, output_set)
    naive_cost = helpers.flop_count(idx_contract, idx_removed, num_tensors, idx_dict)

    path = []
    path_cost = 0
    candidates = []

    # All pairs on the first step; afterwards only pairs with the newest tensor
    pair_iter = itertools.combinations(range(num_tensors), 2)

    for _ in range(num_tensors - 1):

        for pair in pair_iter:
            left, right = pair

            # Only consider pairs that share an index (no outer products yet)
            if not input_sets[left].isdisjoint(input_sets[right]):
                parsed = _parse_possible_contraction(pair, input_sets, output_set, idx_dict,
                                                     memory_limit, path_cost, naive_cost)
                if parsed is not None:
                    candidates.append(parsed)

        if not candidates:

            # Nothing usable: rescan every pair, outer products included
            for pair in itertools.combinations(range(len(input_sets)), 2):
                parsed = _parse_possible_contraction(pair, input_sets, output_set, idx_dict,
                                                     memory_limit, path_cost, naive_cost)
                if parsed is not None:
                    candidates.append(parsed)

            # Still nothing: contract all remaining tensors at once, like einsum
            if not candidates:
                path.append(tuple(range(len(input_sets))))
                break

        # Pick the candidate with the smallest sort tuple
        best = min(candidates, key=lambda entry: entry[0])
        best_sort, best_pair, best_sets = best

        # Carry over the candidates that do not touch the contracted tensors
        candidates = _update_other_results(candidates, best)

        # Next step only needs pairs that involve the newly created tensor,
        # which sits at the end of the updated input list
        input_sets = best_sets
        newest = len(input_sets) - 1
        pair_iter = ((i, newest) for i in range(newest))

        # Record the contraction and accumulate its flop cost
        path.append(best_pair)
        path_cost += best_sort[1]

    return path