001    /*--------------------------------------------------------------------------+
002    $Id: MaxWeightMatching.java 26283 2010-02-18 11:18:57Z juergens $
003    |                                                                          |
004    | Copyright 2005-2010 Technische Universitaet Muenchen                     |
005    |                                                                          |
006    | Licensed under the Apache License, Version 2.0 (the "License");          |
007    | you may not use this file except in compliance with the License.         |
008    | You may obtain a copy of the License at                                  |
009    |                                                                          |
010    |    http://www.apache.org/licenses/LICENSE-2.0                            |
011    |                                                                          |
012    | Unless required by applicable law or agreed to in writing, software      |
013    | distributed under the License is distributed on an "AS IS" BASIS,        |
014    | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
015    | See the License for the specific language governing permissions and      |
016    | limitations under the License.                                           |
017    +--------------------------------------------------------------------------*/
018    package edu.tum.cs.commons.algo;
019    
020    import java.util.Arrays;
021    import java.util.List;
022    
023    import edu.tum.cs.commons.collections.PairList;
024    
025    /**
026     * A class for calculating maximum weighted matching using an augmenting path
027     * algorithm running in O(n^3*m), where n is the size of the smaller node set
028     * and m the size of the larger one. In practice the running time is much less.
029     * <p>
030     * This class is not thread save!
031     * 
032     * @author hummelb
033     * @author $Author: juergens $
034     * @version $Rev: 26283 $
035     * @levd.rating GREEN Hash: 2069DC784F078E4503328061B520BBB1
036     * 
037     * @param <N1>
038     *            The first node type
039     * @param <N2>
040     *            The second node type
041     */
042    public class MaxWeightMatching<N1, N2> {
043    
044            /**
045             * Flag indicating whether we are running in swapped mode. Swapped mode is
046             * needed as our algorithm requires the second set of nodes not to be
047             * smaller than the first set. If this is not the case, we just swap these
048             * sets, but we need this flag to adjust some parts of the code.
049             */
050            private boolean swapped;
051    
052            /** Size of the first (or second if {@link #swapped}) node set. */
053            private int size1;
054    
055            /** Size of the second (or first if {@link #swapped}) node set. */
056            private int size2;
057    
058            /** The first node set. */
059            private List<N1> nodes1;
060    
061            /** The second node set. */
062            private List<N2> nodes2;
063    
064            /** The provider for the weights (i.e. weight matrix). */
065            private IWeightProvider<N1, N2> weightProvider;
066    
067            /**
068             * This array stores for each node of the second set the index of the node
069             * from the first set, it is matched to (or -1 if is not in matching). If
070             * {@link #swapped}, first and second set change meaning.
071             */
072            private int[] mate = new int[16];
073    
074            /**
075             * This is used while searching shortest path and stores the node index we
076             * came from.
077             */
078            private int[] from = new int[16];
079    
080            /**
081             * This is used while searching shortest path and stores the distance (i.e.
082             * weight sum) to this node.
083             */
084            private double[] dist = new double[16];
085    
086            /**
087             * Calculate the weighted bipartite matching.
088             * 
089             * @param matching
090             *            if this is non <code>null</code>, the matching (i.e. the pairs of nodes
091             *            matched onto each other) will be put into it.
092             * 
093             * @return the weight of the matching.
094             */
095            public double calculateMatching(List<N1> nodes1, List<N2> nodes2,
096                            IWeightProvider<N1, N2> weightProvider, PairList<N1, N2> matching) {
097    
098                    if (matching != null) {
099                            matching.clear();
100                    }
101    
102                    if (nodes1.isEmpty() || nodes2.isEmpty()) {
103                            return 0;
104                    }
105    
106                    init(nodes1, nodes2, weightProvider);
107                    prepareInternalArrays();
108    
109                    for (int i = 0; i < size1; ++i) {
110                            augmentFrom(i);
111                    }
112    
113                    double res = 0;
114                    for (int i = 0; i < size2; ++i) {
115                            if (mate[i] >= 0) {
116                                    if (matching != null) {
117                                            if (swapped) {
118                                                    matching.add(nodes1.get(i), nodes2.get(mate[i]));
119                                            } else {
120                                                    matching.add(nodes1.get(mate[i]), nodes2.get(i));
121                                            }
122                                    }
123                                    res += getWeight(mate[i], i);
124                            }
125                    }
126                    return res;
127            }
128    
129            /**
130             * Initializes the data structures from the parameters to the
131             * {@link #calculateMatching(List, List, edu.tum.cs.commons.algo.MaxWeightMatching.IWeightProvider, PairList)}
132             * method.
133             */
134            private void init(List<N1> nodes1, List<N2> nodes2,
135                            IWeightProvider<N1, N2> weightProvider) {
136                    if (nodes1.size() <= nodes2.size()) {
137                            size1 = nodes1.size();
138                            size2 = nodes2.size();
139                            swapped = false;
140                    } else {
141                            size1 = nodes2.size();
142                            size2 = nodes1.size();
143                            swapped = true;
144                    }
145                    this.nodes1 = nodes1;
146                    this.nodes2 = nodes2;
147                    this.weightProvider = weightProvider;
148            }
149    
150            /** Make sure all internal arrays are large enough. */
151            private void prepareInternalArrays() {
152                    if (size2 > mate.length) {
153                            int newSize = mate.length;
154                            while (newSize < size2) {
155                                    newSize *= 2;
156                            }
157                            mate = new int[newSize];
158                            from = new int[newSize];
159                            dist = new double[newSize];
160                    }
161    
162                    Arrays.fill(mate, 0, size2, -1);
163            }
164    
165            /**
166             * Calculate shortest augmenting path and augment along it starting from the
167             * given node (index).
168             */
169            private void augmentFrom(int u) {
170                    for (int i = 0; i < size2; ++i) {
171                            from[i] = -1;
172                            dist[i] = getWeight(u, i);
173                    }
174                    bellmanFord();
175                    int target = findBestUnmatchedTarget();
176                    augmentAlongPath(u, target);
177            }
178    
179            /** Calculate the shortest path using Bellman-Ford algorithm. */
180            private void bellmanFord() {
181                    boolean changed = true;
182                    while (changed) {
183                            changed = false;
184                            for (int i = 0; i < size2; ++i) {
185                                    if (mate[i] < 0) {
186                                            continue;
187                                    }
188                                    double w = getWeight(mate[i], i);
189                                    for (int j = 0; j < size2; ++j) {
190                                            if (i == j) {
191                                                    continue;
192                                            }
193                                            double newDist = dist[i] - w + getWeight(mate[i], j);
194                                            if (newDist - 1e-15 > dist[j]) {
195                                                    dist[j] = newDist;
196                                                    from[j] = i;
197                                                    changed = true;
198                                            }
199                                    }
200                            }
201                    }
202            }
203    
204            /** Find the best target which is not yet in the matching. */
205            private int findBestUnmatchedTarget() {
206                    int target = -1;
207                    for (int i = 0; i < size2; ++i) {
208                            if (mate[i] < 0) {
209                                    if (target < 0 || dist[i] > dist[target]) {
210                                            target = i;
211                                    }
212                            }
213                    }
214                    return target;
215            }
216    
217            /** Augment along the given path to the target by adjusting the mate array. */
218            private void augmentAlongPath(int u, int target) {
219                    while (from[target] >= 0) {
220                            mate[target] = mate[from[target]];
221                            target = from[target];
222                    }
223                    mate[target] = u;
224            }
225    
226            /**
227             * Returns the weight between two nodes (=indices) handling swapping
228             * transparently.
229             */
230            private double getWeight(int i1, int i2) {
231                    if (swapped) {
232                            return weightProvider.getConnectionWeight(nodes1.get(i2), nodes2
233                                            .get(i1));
234                    }
235                    return weightProvider.getConnectionWeight(nodes1.get(i1), nodes2
236                                    .get(i2));
237            }
238    
239            /** A class providing the weight for a connection between two nodes. */
240            public interface IWeightProvider<N1, N2> {
241    
242                    /** Returns the weight of the connection between both nodes. */
243                    double getConnectionWeight(N1 node1, N2 node2);
244            }
245    }