"""
Contains the path technology behind opt_einsum in addition to several path helpers
"""
import itertools
from . import helpers
def optimal(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.
        If at some step every pair contraction exceeds ``memory_limit``, the
        cheapest path found so far is returned with one final entry that
        contracts all remaining tensors at once.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> optimal(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    num_inputs = len(input_sets)

    # Every candidate is a tuple (total_cost, positions_so_far, remaining_sets).
    # The list is never empty: it starts with one entry and is only ever
    # replaced by a non-empty list of extended candidates.
    full_results = [(0, [], input_sets)]

    for iteration in range(num_inputs - 1):
        # Each pair contraction shrinks the operand list by one
        num_remaining = num_inputs - iteration

        # Compute all unique pairs
        comb_iter = tuple(itertools.combinations(range(num_remaining), 2))

        iter_results = []
        for cost, positions, remaining in full_results:
            for con in comb_iter:
                # Find the contraction
                contract = helpers.find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = contract

                # Sieve the results based on memory_limit
                new_size = helpers.compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + helpers.flop_count(idx_contract, idx_removed, len(con), idx_dict)
                iter_results.append((total_cost, positions + [con], new_input_sets))

        # If no pair fits in memory, return the best path found so far plus a
        # single contraction over everything that remains.
        if not iter_results:
            best_path = min(full_results, key=lambda x: x[0])[1]
            return best_path + [tuple(range(num_remaining))]

        # Update list to iterate over
        full_results = iter_results

    return min(full_results, key=lambda x: x[0])[1]
def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """Evaluate the contraction specified by ``positions`` and build its
    sort key (removed size + flops) together with the resulting indices.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    result : list or None
        ``None`` when the intermediate would be larger than ``memory_limit``
        or when ``path_cost`` plus this contraction's flop cost exceeds
        ``naive_cost``. Otherwise the list
        ``[sort, positions, new_input_sets]`` where

        * ``sort`` is the tuple ``(-removed_size, cost)``: the negated net
          size removed by the contraction and its flop cost,
        * ``positions`` is the ``positions`` argument, unchanged,
        * ``new_input_sets`` is the list of index sets that results if this
          proposed contraction is performed.
    """
    # Find the contraction
    idx_result, new_input_sets, idx_removed, idx_contract = helpers.find_contraction(
        positions, input_sets, output_set)

    # Reject intermediates that do not fit within memory_limit
    new_size = helpers.compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Net size removed: combined size of the contracted operands minus the
    # size of the tensor that replaces them.
    # NB: removed_size used to be just the size of any removed indices i.e.:
    #     helpers.compute_size_by_dict(idx_removed, idx_dict)
    operand_sizes = [helpers.compute_size_by_dict(input_sets[p], idx_dict) for p in positions]
    removed_size = sum(operand_sizes) - new_size

    cost = helpers.flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Reject contractions that would push the path past the naive cost
    if (path_cost + cost) > naive_cost:
        return None

    # Add contraction to possible choices
    return [(-removed_size, cost), positions, new_input_sets]
def _update_other_results(results, best):
"""Update the positions and provisional input_sets of ``results`` based on
performing the contraction result ``best``. Remove any involving the tensors
contracted.
Parameters
----------
results :
List of contraction results produced by ``_parse_possible_contraction``.
best :
The best contraction of ``results`` i.e. the one that will be performed.
Returns
-------
mod_results :
The list of modifed results, updated with outcome of ``best`` contraction.
"""
best_con = best[1]
bx, by = best_con
mod_results = []
for cost, (x, y), con_sets in results:
# Ignore results involving tensors just contracted
if x in best_con or y in best_con:
continue
# Update the input_sets
del con_sets[by - int(by > x) - int(by > y)]
del con_sets[bx - int(bx > x) - int(bx > y)]
con_sets.insert(-1, best[2][-1])
# Update the position indices
mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
mod_results.append((cost, mod_con, con_sets))
return mod_results
def greedy(input_sets, output_set, idx_dict, memory_limit):
    """
    Builds a path by repeatedly contracting the best available pair until
    the input list is exhausted. The best pair is the one that minimizes the
    tuple ``(-removed_size, cost)``. What this amounts to is prioritizing
    inner product operations, matrix multiplication, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit`` and are ignored until no other operations are
    available. This algorithm scales quadratically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('')
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> greedy(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Cost of contracting every operand in one step; candidates whose running
    # cost would exceed it are rejected by ``_parse_possible_contraction``.
    everything = helpers.find_contraction(range(len(input_sets)), input_sets, output_set)
    _, _, naive_removed, naive_contract = everything
    naive_cost = helpers.flop_count(naive_contract, naive_removed, len(input_sets), idx_dict)

    # First step looks at every pair; later steps only at pairs with the new tensor
    pairs_to_scan = itertools.combinations(range(len(input_sets)), 2)
    candidates = []
    running_cost = 0
    path = []

    for _ in range(len(input_sets) - 1):
        for pair in pairs_to_scan:
            lhs, rhs = pair
            # Always initially ignore outer products
            if input_sets[lhs].isdisjoint(input_sets[rhs]):
                continue

            parsed = _parse_possible_contraction(pair, input_sets, output_set, idx_dict, memory_limit,
                                                 running_cost, naive_cost)
            if parsed is not None:
                candidates.append(parsed)

        # If we do not have a inner contraction, rescan pairs including outer products
        if not candidates:
            for pair in itertools.combinations(range(len(input_sets)), 2):
                parsed = _parse_possible_contraction(pair, input_sets, output_set, idx_dict, memory_limit,
                                                     running_cost, naive_cost)
                if parsed is not None:
                    candidates.append(parsed)

            # If we still did not find any remaining contractions, default back to einsum like behavior
            if not candidates:
                path.append(tuple(range(len(input_sets))))
                break

        # Pick the candidate with the smallest sort tuple
        best = min(candidates, key=lambda entry: entry[0])

        # Now propagate as many unused contractions as possible to next iteration
        candidates = _update_other_results(candidates, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        pairs_to_scan = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        running_cost += best[0][1]

    return path