001/* Copyright (C) 2013 TU Dortmund
002 * This file is part of AutomataLib, http://www.automatalib.net/.
003 * 
004 * AutomataLib is free software; you can redistribute it and/or
005 * modify it under the terms of the GNU Lesser General Public
006 * License version 3.0 as published by the Free Software Foundation.
007 * 
008 * AutomataLib is distributed in the hope that it will be useful,
009 * but WITHOUT ANY WARRANTY; without even the implied warranty of
010 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
011 * Lesser General Public License for more details.
012 * 
013 * You should have received a copy of the GNU Lesser General Public
014 * License along with AutomataLib; if not, see
015 * http://www.gnu.de/documents/lgpl.en.html.
016 */
017package net.automatalib.algorithms.graph.apsp;
018
019import java.util.ArrayList;
020import java.util.Collection;
021import java.util.List;
022
023import net.automatalib.algorithms.graph.GraphAlgorithms;
024import net.automatalib.graphs.Graph;
025import net.automatalib.graphs.concepts.EdgeWeights;
026import net.automatalib.graphs.concepts.NodeIDs;
027
028/**
029 * Implementation of the Floyd-Warshall dynamic programming algorithm for the
030 * all pairs shortest paths problem.
031 * 
032 * @author Malte Isberner <malte.isberner@gmail.com>
033 *
034 * @param <N> node class
035 * @param <E> edge class
036 */
037public class FloydWarshallAPSP<N,E> implements APSPResult<N,E> {
038        
039        private static final class APSPRecord<N,E> {
040                public final E edge;
041                public float distance;
042                public int middle;
043                public int numEdges;
044                
045                public APSPRecord(E edge, float distance) {
046                        this.edge = edge;
047                        this.distance = distance;
048                        this.middle = -1;
049                        this.numEdges = 1;
050                }
051                
052                public APSPRecord(float distance, int middle, int numEdges) {
053                        this.edge = null;
054                        this.distance = distance;
055                        this.middle = middle;
056                }
057        }
058        
059        
060        public static <N,E> APSPResult<N,E> findAPSP(Graph<N,E> graph, EdgeWeights<E> edgeWeights) {
061                FloydWarshallAPSP<N, E> fw = new FloydWarshallAPSP<>(graph, edgeWeights);
062                fw.findAPSP();
063                return fw;
064        }
065        
066        private final int size;
067        private final NodeIDs<N> ids;
068        private final APSPRecord<N,E>[][] table;
069        
070        @SuppressWarnings("unchecked")
071        public FloydWarshallAPSP(Graph<N,E> graph, EdgeWeights<E> ew) {
072                this.size = graph.size();
073                this.ids = graph.nodeIDs();
074                this.table = new APSPRecord[size][size];
075                
076                initialize(graph, ew);
077        }
078        
079        private void initialize(Graph<N,E> graph, EdgeWeights<E> ew) {
080                for(int i = 0; i < size; i++) {
081                        N src = ids.getNode(i);
082                        
083                        Collection<E> edges = graph.getOutgoingEdges(src);
084                        
085                        for(E edge : edges) {
086                                N tgt = graph.getTarget(edge);
087                                if(tgt.equals(src))
088                                        continue;
089                                
090                                int j = ids.getNodeId(tgt);
091                                float w = ew.getEdgeWeight(edge);
092                                APSPRecord<N, E> prev = table[i][j];
093                                if(prev == null || prev.distance > w)
094                                        table[i][j] = new APSPRecord<>(edge, w);
095                        }
096                }
097        }
098
099        public void findAPSP() {
100                for(int i = 0; i < size; i++) {
101                        for(int j = 0; j < size; j++) {
102                                if(j == i)
103                                        continue;
104                                
105                                APSPRecord<N,E> currRec = table[i][j];
106                                
107                                for(int k = 0; k < size; k++) {
108                                        if(k == i || k == j)
109                                                continue;
110                                        
111                                        APSPRecord<N,E> part1 = table[i][k], part2 = table[k][j];
112                                        
113                                        if(part1 == null || part2 == null)
114                                                continue;
115                                        
116                                        float dist1 = part1.distance, dist2 = part2.distance;
117                                        float total = dist1 + dist2;
118                                        
119                                        if(currRec == null) {
120                                                currRec = new APSPRecord<>(total, k, part1.numEdges + part2.numEdges);
121                                                table[i][j] = currRec;
122                                        }
123                                        else if(currRec.distance > total) {
124                                                currRec.distance = total;
125                                                currRec.middle = k;
126                                                currRec.numEdges = part1.numEdges + part2.numEdges;
127                                        }
128                                }
129                        }
130                }
131        }
132        
133        @Override
134        public float getShortestPathDistance(N src, N tgt) {
135                int srcId = ids.getNodeId(src), tgtId = ids.getNodeId(tgt);
136                
137                APSPRecord<N, E> rec = table[srcId][tgtId];
138                if(rec == null)
139                        return GraphAlgorithms.INVALID_DISTANCE;
140                
141                return rec.distance;
142        }
143
144        @Override
145        public List<E> getShortestPath(N src, N tgt) {
146                int srcId = ids.getNodeId(src), tgtId = ids.getNodeId(tgt);
147                
148                APSPRecord<N,E> rec = table[srcId][tgtId];
149                
150                if(rec == null)
151                        return null;
152                
153                List<E> result = new ArrayList<>(rec.numEdges);
154                
155                buildPath(result, srcId, tgtId, rec);
156                
157                return result;
158        }
159        
160        private void buildPath(List<E> path, int srcId, int tgtId, APSPRecord<N,E> rec) {
161                if(rec.middle == -1) {
162                        path.add(rec.edge);
163                        return;
164                }
165                
166                int middle = rec.middle;
167                buildPath(path, srcId, middle, table[srcId][middle]);
168                buildPath(path, middle, tgtId, table[middle][tgtId]);
169        }
170}