/*--------------------------------------------------------------------------+
$Id: MaxWeightMatching.java 26283 2010-02-18 11:18:57Z juergens $
|                                                                          |
| Copyright 2005-2010 Technische Universitaet Muenchen                     |
|                                                                          |
| Licensed under the Apache License, Version 2.0 (the "License");          |
| you may not use this file except in compliance with the License.         |
| You may obtain a copy of the License at                                  |
|                                                                          |
|    http://www.apache.org/licenses/LICENSE-2.0                            |
|                                                                          |
| Unless required by applicable law or agreed to in writing, software      |
| distributed under the License is distributed on an "AS IS" BASIS,        |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and      |
| limitations under the License.                                           |
+--------------------------------------------------------------------------*/
package edu.tum.cs.commons.algo;

import java.util.Arrays;
import java.util.List;

import edu.tum.cs.commons.collections.PairList;

/**
 * A class for calculating maximum weighted matching using an augmenting path
 * algorithm running in O(n^3*m), where n is the size of the smaller node set
 * and m the size of the larger one. In practice the running time is much less.
 * <p>
 * The matching is built by augmenting once from every node of the smaller node
 * set, so every node of the smaller set is always part of the result. Thus, if
 * the weight provider returns negative weights, pairs with negative weight may
 * be contained in the matching.
 * <p>
 * This class is not thread-safe! An instance may be reused for several
 * calculations, but must not be used concurrently.
 * 
 * @author hummelb
 * @author $Author: juergens $
 * @version $Rev: 26283 $
 * @levd.rating GREEN Hash: 2069DC784F078E4503328061B520BBB1
 * 
 * @param <N1>
 *            The first node type
 * @param <N2>
 *            The second node type
 */
public class MaxWeightMatching<N1, N2> {

	/**
	 * Relative tolerance used when comparing path weights in
	 * {@link #bellmanFord()}. An improvement is only accepted if it exceeds
	 * this fraction of the magnitudes of the operands involved, i.e. if it is
	 * larger than the rounding error to be expected for them. A fixed absolute
	 * tolerance (the former 1e-15) is lost to rounding for operands of
	 * magnitude 16 or more. Rounding noise on ties could then be taken for an
	 * improving cycle, which makes the {@link #from} chain cyclic and
	 * {@link #augmentAlongPath(int, int)} loop forever.
	 */
	private static final double RELATIVE_TOLERANCE = 1e-12;

	/**
	 * Flag indicating whether we are running in swapped mode. Swapped mode is
	 * needed as our algorithm requires the second set of nodes not to be
	 * smaller than the first set. If this is not the case, we just swap these
	 * sets, but we need this flag to adjust some parts of the code.
	 */
	private boolean swapped;

	/** Size of the first (or second if {@link #swapped}) node set. */
	private int size1;

	/** Size of the second (or first if {@link #swapped}) node set. */
	private int size2;

	/**
	 * The first node set as passed by the caller (never swapped). Only set
	 * while a calculation is running.
	 */
	private List<N1> nodes1;

	/**
	 * The second node set as passed by the caller (never swapped). Only set
	 * while a calculation is running.
	 */
	private List<N2> nodes2;

	/**
	 * The provider for the weights (i.e. weight matrix). Only set while a
	 * calculation is running.
	 */
	private IWeightProvider<N1, N2> weightProvider;

	/**
	 * This array stores for each node of the second set the index of the node
	 * from the first set, it is matched to (or -1 if is not in matching). If
	 * {@link #swapped}, first and second set change meaning.
	 */
	private int[] mate = new int[16];

	/**
	 * This is used while searching the best augmenting path and stores the
	 * node index we came from (or -1 if the node is reached directly from the
	 * node we augment from).
	 */
	private int[] from = new int[16];

	/**
	 * This is used while searching the best augmenting path and stores the
	 * distance (i.e. weight gain) of the best path found to this node.
	 */
	private double[] dist = new double[16];

	/**
	 * Calculate the weighted bipartite matching.
	 * 
	 * @param nodes1
	 *            the first node set.
	 * @param nodes2
	 *            the second node set.
	 * @param weightProvider
	 *            provides the weight for each pair of a node from
	 *            <code>nodes1</code> and a node from <code>nodes2</code>.
	 * @param matching
	 *            if this is non <code>null</code>, it is cleared and the
	 *            matching (i.e. the pairs of nodes matched onto each other,
	 *            first element from <code>nodes1</code>, second from
	 *            <code>nodes2</code>) will be put into it.
	 * 
	 * @return the weight of the matching, i.e. the sum of the weights of all
	 *         matched pairs; 0 if one of the node sets is empty.
	 */
	public double calculateMatching(List<N1> nodes1, List<N2> nodes2,
			IWeightProvider<N1, N2> weightProvider, PairList<N1, N2> matching) {

		if (matching != null) {
			matching.clear();
		}

		if (nodes1.isEmpty() || nodes2.isEmpty()) {
			return 0;
		}

		init(nodes1, nodes2, weightProvider);
		try {
			prepareInternalArrays();

			for (int i = 0; i < size1; ++i) {
				augmentFrom(i);
			}

			double res = 0;
			for (int i = 0; i < size2; ++i) {
				if (mate[i] < 0) {
					continue;
				}
				if (matching != null) {
					if (swapped) {
						matching.add(nodes1.get(i), nodes2.get(mate[i]));
					} else {
						matching.add(nodes1.get(mate[i]), nodes2.get(i));
					}
				}
				res += getWeight(mate[i], i);
			}
			return res;
		} finally {
			// also on exceptions from the weight provider
			releaseReferences();
		}
	}

	/**
	 * Initializes the data structures from the parameters to the
	 * {@link #calculateMatching(List, List, edu.tum.cs.commons.algo.MaxWeightMatching.IWeightProvider, PairList)}
	 * method. Determines whether we have to run in {@link #swapped} mode so
	 * that {@link #size1} is never larger than {@link #size2}.
	 */
	private void init(List<N1> nodes1, List<N2> nodes2,
			IWeightProvider<N1, N2> weightProvider) {
		if (nodes1.size() <= nodes2.size()) {
			size1 = nodes1.size();
			size2 = nodes2.size();
			swapped = false;
		} else {
			size1 = nodes2.size();
			size2 = nodes1.size();
			swapped = true;
		}
		this.nodes1 = nodes1;
		this.nodes2 = nodes2;
		this.weightProvider = weightProvider;
	}

	/**
	 * Drops the references to the node lists and the weight provider of the
	 * last calculation, so a reused instance does not keep them reachable. The
	 * internal arrays are kept for reuse.
	 */
	private void releaseReferences() {
		nodes1 = null;
		nodes2 = null;
		weightProvider = null;
	}

	/**
	 * Make sure all internal arrays are large enough (growing them by
	 * doubling) and mark all nodes of the larger set as unmatched.
	 */
	private void prepareInternalArrays() {
		if (size2 > mate.length) {
			int newSize = mate.length;
			while (newSize < size2) {
				newSize *= 2;
			}
			mate = new int[newSize];
			from = new int[newSize];
			dist = new double[newSize];
		}

		Arrays.fill(mate, 0, size2, -1);
	}

	/**
	 * Calculate the best augmenting path starting from the given node (index
	 * into the smaller node set) and augment along it.
	 */
	private void augmentFrom(int u) {
		for (int i = 0; i < size2; ++i) {
			from[i] = -1;
			dist[i] = getWeight(u, i);
		}
		bellmanFord();
		int target = findBestUnmatchedTarget();
		augmentAlongPath(u, target);
	}

	/**
	 * Computes the maximal weight gain of the paths to each node of the larger
	 * set using Bellman-Ford style relaxation. A step from node i to node j
	 * means moving the mate of i over to j, which changes the weight by
	 * <code>weight(mate[i], j) - weight(mate[i], i)</code>. Relaxation is
	 * repeated until no distance improves by more than the tolerance given by
	 * {@link #RELATIVE_TOLERANCE}.
	 */
	private void bellmanFord() {
		boolean changed = true;
		while (changed) {
			changed = false;
			for (int i = 0; i < size2; ++i) {
				if (mate[i] < 0) {
					continue;
				}
				double w = getWeight(mate[i], i);
				for (int j = 0; j < size2; ++j) {
					if (i == j) {
						continue;
					}
					double alternative = getWeight(mate[i], j);
					double newDist = dist[i] - w + alternative;
					// the rounding error of newDist scales with its operands,
					// so the required improvement does as well
					double tolerance = RELATIVE_TOLERANCE
							* (Math.abs(dist[i]) + Math.abs(w) + Math
									.abs(alternative));
					if (newDist > dist[j] + tolerance) {
						dist[j] = newDist;
						from[j] = i;
						changed = true;
					}
				}
			}
		}
	}

	/**
	 * Find the best target which is not yet in the matching, i.e. the
	 * unmatched node with the largest {@link #dist} (the lowest index wins on
	 * ties). Returns -1 if all nodes are matched.
	 */
	private int findBestUnmatchedTarget() {
		int target = -1;
		for (int i = 0; i < size2; ++i) {
			if (mate[i] < 0) {
				if (target < 0 || dist[i] > dist[target]) {
					target = i;
				}
			}
		}
		return target;
	}

	/**
	 * Augment along the path to the given target by adjusting the mate array:
	 * each node on the path takes over the mate of its predecessor (see
	 * {@link #from}) and the first node of the path is matched to u.
	 */
	private void augmentAlongPath(int u, int target) {
		while (from[target] >= 0) {
			mate[target] = mate[from[target]];
			target = from[target];
		}
		mate[target] = u;
	}

	/**
	 * Returns the weight between two nodes (=indices) handling swapping
	 * transparently. The first index refers to the smaller node set, the
	 * second one to the larger node set.
	 */
	private double getWeight(int i1, int i2) {
		if (swapped) {
			return weightProvider.getConnectionWeight(nodes1.get(i2), nodes2
					.get(i1));
		}
		return weightProvider.getConnectionWeight(nodes1.get(i1), nodes2
				.get(i2));
	}

	/**
	 * A class providing the weight for a connection between two nodes.
	 * 
	 * @param <N1>
	 *            The first node type
	 * @param <N2>
	 *            The second node type
	 */
	public interface IWeightProvider<N1, N2> {

		/** Returns the weight of the connection between both nodes. */
		double getConnectionWeight(N1 node1, N2 node2);
	}
}