diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 2f8cd074300a..dd0677f35170 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -7725,7 +7725,7 @@ class AdjointGenerator if (!forwardMode) root = lookup(root, Builder2); - Value *comm = lookup(gutils->getNewFromOriginal(orig_comm), Builder2); + Value *comm = gutils->getNewFromOriginal(orig_comm); if (!forwardMode) comm = lookup(comm, Builder2); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index e1efa55ff2c0..6f03a729e045 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -3822,6 +3822,13 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); } + } else if (auto CI = dyn_cast(C)) { + // MPICH + if (CI->getValue() == 1275070475) { + buf.insert({0}, Type::getDoubleTy(C->getContext())); + } else if (CI->getValue() == 1275069450) { + buf.insert({0}, Type::getFloatTy(C->getContext())); + } } } updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); @@ -3847,6 +3854,13 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); } + } else if (auto CI = dyn_cast(C)) { + // MPICH + if (CI->getValue() == 1275070475) { + buf.insert({0}, Type::getDoubleTy(C->getContext())); + } else if (CI->getValue() == 1275069450) { + buf.insert({0}, Type::getFloatTy(C->getContext())); + } } } updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); @@ -3902,15 +3916,34 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { return; } if (funcName == "MPI_Reduce" || funcName == "PMPI_Reduce") { + TypeTree buf = TypeTree(BaseType::Pointer); + + if (Constant *C = dyn_cast(call.getOperand(3))) { + while (ConstantExpr *CE = dyn_cast(C)) { + C = CE->getOperand(0); + } + if (auto GV = dyn_cast(C)) { + if (GV->getName() == "ompi_mpi_double") { + buf.insert({0}, Type::getDoubleTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_float") { + buf.insert({0}, Type::getFloatTy(C->getContext())); + } + } else if (auto CI = dyn_cast(C)) { + // MPICH + if (CI->getValue() == 1275070475) { + buf.insert({0}, Type::getDoubleTy(C->getContext())); + } else if (CI->getValue() == 1275069450) { + buf.insert({0}, Type::getFloatTy(C->getContext())); + } + } + } // int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, // MPI_Datatype datatype, // MPI_Op op, int root, MPI_Comm comm) // sendbuf - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); + updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); // recvbuf - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); + updateAnalysis(call.getOperand(1), buf.Only(-1, &call), &call); // count updateAnalysis(call.getOperand(2), TypeTree(BaseType::Integer).Only(-1, &call), &call); @@ -3922,14 +3955,33 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { return; } if (funcName == "MPI_Allreduce") { + TypeTree buf = TypeTree(BaseType::Pointer); + + if (Constant *C = dyn_cast(call.getOperand(3))) { + while (ConstantExpr *CE = dyn_cast(C)) { + C = CE->getOperand(0); + } + if (auto GV = dyn_cast(C)) { + if (GV->getName() == "ompi_mpi_double") { + buf.insert({0}, Type::getDoubleTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_float") { + buf.insert({0}, Type::getFloatTy(C->getContext())); + } + } else if (auto CI = dyn_cast(C)) { + // MPICH + if (CI->getValue() == 1275070475) { + buf.insert({0}, Type::getDoubleTy(C->getContext())); + } else if (CI->getValue() == 1275069450) { + buf.insert({0}, Type::getFloatTy(C->getContext())); + } + } + } // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) // sendbuf - updateAnalysis(call.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); + updateAnalysis(call.getOperand(0), buf.Only(-1, &call), &call); // recvbuf - updateAnalysis(call.getOperand(1), - TypeTree(BaseType::Pointer).Only(-1, &call), &call); + updateAnalysis(call.getOperand(1), buf.Only(-1, &call), &call); // count updateAnalysis(call.getOperand(2), TypeTree(BaseType::Integer).Only(-1, &call), &call);