001/* 002 * This file is part of the Jikes RVM project (http://jikesrvm.org). 003 * 004 * This file is licensed to You under the Eclipse Public License (EPL); 005 * You may not use this file except in compliance with the License. You 006 * may obtain a copy of the License at 007 * 008 * http://www.opensource.org/licenses/eclipse-1.0.php 009 * 010 * See the COPYRIGHT.txt file distributed with this work for information 011 * regarding copyright ownership. 012 */ 013package org.jikesrvm.compilers.opt.ir; 014 015import java.util.Enumeration; 016 017import org.jikesrvm.VM; 018import org.jikesrvm.compilers.opt.OptimizingCompilerException; 019import org.jikesrvm.compilers.opt.ir.operand.BranchProfileOperand; 020 021/** 022 * Used to iterate over the branch targets (including the fall through edge) 023 * and associated probabilites of a basic block. 024 * Takes into account the ordering of branch instructions when 025 * computing the edge weights such that the total target weight will always 026 * be equal to 1.0 (flow in == flow out). 027 */ 028public final class WeightedBranchTargets { 029 private BasicBlock[] targets; 030 private float[] weights; 031 private int cur; 032 private int max; 033 034 public void reset() { 035 cur = 0; 036 } 037 038 public boolean hasMoreElements() { 039 return cur < max; 040 } 041 042 public void advance() { 043 cur++; 044 } 045 046 public BasicBlock curBlock() { 047 return targets[cur]; 048 } 049 050 public float curWeight() { 051 return weights[cur]; 052 } 053 054 public WeightedBranchTargets(BasicBlock bb) { 055 targets = new BasicBlock[3]; 056 weights = new float[3]; 057 cur = 0; 058 max = 0; 059 060 float prob = 1f; 061 for (Enumeration<Instruction> ie = bb.enumerateBranchInstructions(); ie.hasMoreElements();) { 062 Instruction s = ie.nextElement(); 063 if (IfCmp.conforms(s)) { 064 BasicBlock target = IfCmp.getTarget(s).target.getBasicBlock(); 065 BranchProfileOperand prof = IfCmp.getBranchProfile(s); 066 float taken = prob * prof.takenProbability; 067 prob = prob * (1f - prof.takenProbability); 068 addEdge(target, taken); 069 } else if (Goto.conforms(s)) { 070 BasicBlock target = Goto.getTarget(s).target.getBasicBlock(); 071 addEdge(target, prob); 072 } else if (InlineGuard.conforms(s)) { 073 BasicBlock target = InlineGuard.getTarget(s).target.getBasicBlock(); 074 BranchProfileOperand prof = InlineGuard.getBranchProfile(s); 075 float taken = prob * prof.takenProbability; 076 prob = prob * (1f - prof.takenProbability); 077 addEdge(target, taken); 078 } else if (IfCmp2.conforms(s)) { 079 BasicBlock target = IfCmp2.getTarget1(s).target.getBasicBlock(); 080 BranchProfileOperand prof = IfCmp2.getBranchProfile1(s); 081 float taken = prob * prof.takenProbability; 082 prob = prob * (1f - prof.takenProbability); 083 addEdge(target, taken); 084 target = IfCmp2.getTarget2(s).target.getBasicBlock(); 085 prof = IfCmp2.getBranchProfile2(s); 086 taken = prob * prof.takenProbability; 087 prob = prob * (1f - prof.takenProbability); 088 addEdge(target, taken); 089 } else if (TableSwitch.conforms(s)) { 090 int lowLimit = TableSwitch.getLow(s).value; 091 int highLimit = TableSwitch.getHigh(s).value; 092 int number = highLimit - lowLimit + 1; 093 float total = 0f; 094 for (int i = 0; i < number; i++) { 095 BasicBlock target = TableSwitch.getTarget(s, i).target.getBasicBlock(); 096 BranchProfileOperand prof = TableSwitch.getBranchProfile(s, i); 097 float taken = prob * prof.takenProbability; 098 total += prof.takenProbability; 099 addEdge(target, taken); 100 } 101 BasicBlock target = TableSwitch.getDefault(s).target.getBasicBlock(); 102 BranchProfileOperand prof = TableSwitch.getDefaultBranchProfile(s); 103 float taken = prob * prof.takenProbability; 104 total += prof.takenProbability; 105 if (VM.VerifyAssertions && !epsilon(total, 1f)) { 106 VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s); 107 } 108 addEdge(target, taken); 109 } else if (LowTableSwitch.conforms(s)) { 110 int number = LowTableSwitch.getNumberOfTargets(s); 111 float total = 0f; 112 for (int i = 0; i < number; i++) { 113 BasicBlock target = LowTableSwitch.getTarget(s, i).target.getBasicBlock(); 114 BranchProfileOperand prof = LowTableSwitch.getBranchProfile(s, i); 115 float taken = prob * prof.takenProbability; 116 total += prof.takenProbability; 117 addEdge(target, taken); 118 } 119 if (VM.VerifyAssertions && !epsilon(total, 1f)) { 120 VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s); 121 } 122 } else if (LookupSwitch.conforms(s)) { 123 int number = LookupSwitch.getNumberOfTargets(s); 124 float total = 0f; 125 for (int i = 0; i < number; i++) { 126 BasicBlock target = LookupSwitch.getTarget(s, i).target.getBasicBlock(); 127 BranchProfileOperand prof = LookupSwitch.getBranchProfile(s, i); 128 float taken = prob * prof.takenProbability; 129 total += prof.takenProbability; 130 addEdge(target, taken); 131 } 132 BasicBlock target = LookupSwitch.getDefault(s).target.getBasicBlock(); 133 BranchProfileOperand prof = LookupSwitch.getDefaultBranchProfile(s); 134 float taken = prob * prof.takenProbability; 135 total += prof.takenProbability; 136 if (VM.VerifyAssertions && !epsilon(total, 1f)) { 137 VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s); 138 } 139 addEdge(target, taken); 140 } else { 141 throw new OptimizingCompilerException("TODO " + s + "\n"); 142 } 143 } 144 BasicBlock ft = bb.getFallThroughBlock(); 145 if (ft != null) addEdge(ft, prob); 146 } 147 148 private void addEdge(BasicBlock target, float weight) { 149 if (max == targets.length) { 150 BasicBlock[] tmp = new BasicBlock[targets.length << 1]; 151 for (int i = 0; i < targets.length; i++) { 152 tmp[i] = targets[i]; 153 } 154 targets = tmp; 155 float[] tmp2 = new float[weights.length << 1]; 156 for (int i = 0; i < weights.length; i++) { 157 tmp2[i] = weights[i]; 158 } 159 weights = tmp2; 160 } 161 targets[max] = target; 162 weights[max] = weight; 163 max++; 164 } 165 166 private boolean epsilon(float a, float b) { 167 return Math.abs(a - b) < 0.1; 168 } 169}