Hif
Hif

Reputation: 165

Writing an LLVM pass to detect malloc function calls, number of bytes assigned and the variable name pointing to that memory

I have recently begun working with LLVM. I am trying to write a pass in LLVM that given the following code

string = (char *)malloc(100);
string = NULL;

and the corresponding LLVM IR

%call = call noalias i8* @malloc(i64 100) #3
store i8* %call, i8** %string, align 8
store i8* null, i8** %string, align 8

detects instructions calling malloc, extracts number of bytes assigned (in this case 100), the address returned and the variable name that the address is assigned to.

// Keyed by pointer name; each entry holds a (size_t, int) tuple. Not read or
// written anywhere in the code shown below.
std::map<std::string, std::tuple<size_t, int> > mem_addrs;  // stores pointer name, address and no. of bytes allocated
Count() : ModulePass(ID) {}

// Walks every instruction of every basic block of every function in M. For
// each call instruction it prints the callee's name, then prints every callee
// parameter that casts to ConstantInt.
// NOTE(review): the snippet stops before this function's closing brace and
// return statement.
virtual bool runOnModule(Module &M) {
  for (Function &F: M) { 
    for (BasicBlock &B: F) {
        for (Instruction &I: B) {
            if(CallInst* call_inst = dyn_cast<CallInst>(&I)) {
                // NOTE(review): getCalledFunction() is documented to return
                // null for indirect calls; fn is dereferenced unchecked on the
                // next line - confirm and guard.
                Function* fn = call_inst->getCalledFunction();
                StringRef fn_name = fn->getName();
                errs() << fn_name << " : " << "\n";
                // This loop iterates the callee's formal parameters
                // (fn->arg_begin()..arg_end()), not the operands passed at this
                // call site.
                // NOTE(review): to read the actual argument (e.g. the 100 in
                // malloc(100)) the call-site operands of call_inst are
                // presumably what is wanted - confirm.
                for(auto args = fn->arg_begin(); args != fn->arg_end(); ++args) {
                    ConstantInt* arg = dyn_cast<ConstantInt>(&(*args));
                    if (arg != NULL)
                            errs() << arg->getValue() << "\n";
                }    
            }
        }
     }  
  }

The output is

-VirtualBox:~/program_analysis$ opt -load $LLVMLIB/CSE231.so -analyze -count < $BENCHMARKS/leaktest/leaktest.bc > $OUTPUTLOGS/welcome.static.log
ok
allocaimw
allocaleak
allocamalloc : 0x2f5d9e0
0  opt             0x0000000001315cf2 llvm::sys::PrintStackTrace(_IO_FILE*) + 34
1  opt             0x0000000001315914
2  libpthread.so.0 0x00007f0b53f12330
3  opt             0x00000000012ec78f llvm::APInt::toString(llvm::SmallVectorImpl<char>&, unsigned int, bool, bool) const + 79
4  opt             0x00000000012ed309 llvm::APInt::print(llvm::raw_ostream&, bool) const + 57
5  CSE231.so       0x00007f0b52f16661
6  opt             0x00000000012ad6cd llvm::legacy::PassManagerImpl::run(llvm::Module&) + 797
7  opt             0x000000000058e190 main + 2752
8  libc.so.6       0x00007f0b5313af45 __libc_start_main + 245
9  opt             0x00000000005ab2ca
Stack dump:
0.  Program arguments: opt -load /home/hifza/program_analysis/llvm/build/Release+Asserts/lib/CSE231.so -analyze -count 
1.  Running pass 'Instruction Counts Pass' on module '<stdin>'.
Segmentation fault (core dumped)

I am able to detect malloc instructions, but I am not able to find out the corresponding memory address and the number of bytes assigned. Can anyone guide me on how I can go about doing this? Thanks.

Upvotes: 6

Views: 2085

Answers (2)

shrep
shrep

Reputation: 11

I prefer detecting malloc calls,

  1. by first detecting store insts
  2. then checking whether LHS is a pointer
  3. then find out what is RHS (by using a stack approach to find actual value, since LLVM IR is a load-store architecture and hence we don't find the actual value in RHS, always)
  4. if I end up getting a call inst then
  5. check whether its malloc or not

Once you have detected the malloc call, you can fetch the number of bytes requested from the call instruction's first operand (tip->getOperand(0) in the code below). The variable pointing to the memory is the pointer operand of the store instruction you started with (lhs in the code below).

I am sharing the code snippet, which also works for inter-procedural cases and supports the new operator.

// Traces backwards from itVal through the instructions that produce it and
// reports what it finds on errs():
//   - an alloca or a global variable prints "others";
//   - a load pushes its pointer operand for further tracing;
//   - a direct call to "malloc" or "_Znwm" prints "Dynamic memory
//     allocation!", the call's operand count and its first operand;
//   - a direct call to any other function with a body recurses into the value
//     returned by the terminator of that function's last basic block, then
//     prints "done";
//   - any other instruction pushes all of its non-constant operands.
// Values that are neither instructions nor global variables are ignored.
//
// itVal may be null, in which case nothing is done.
//
// NOTE(review): there is no visited set, so a PHI cycle or a recursive callee
// can make this loop or recurse without bound - confirm inputs or add one.
void findOperand(Value *itVal) {
            if(!itVal)
                return;

            std::stack<Value *> st;
            st.push(itVal);
            while(!st.empty()) {

                Value *ele = st.top();
                st.pop();

                if(isa<GlobalVariable>(ele)) {
                    errs()<<"others\n";
                    continue;
                }

                Instruction *tip = dyn_cast<Instruction>(ele);
                if(!tip)
                    continue;

                if(isa<AllocaInst>(tip)) {
                    errs()<<"others\n";
                    //opdSet.insert(ele);
                } else if(isa<LoadInst>(tip)) {
                    Value *ti = tip->getOperand(0);
                    if(!isa<ConstantData>(ti))
                        st.push(ti);
                } else if(CallInst *call = dyn_cast<CallInst>(tip)) {
                    Function *calledFp = call->getCalledFunction();
                    // Indirect calls have no statically known callee, so
                    // there is nothing to name or to descend into.
                    if(!calledFp)
                        continue;

                    errs()<<calledFp->getName()<<"\n";
                    if(calledFp->getName() == "malloc" || calledFp->getName() == "_Znwm") {
                        errs()<<"Dynamic memory allocation!\n";
                        errs()<<call->getNumOperands()<<"\n";
                        // Dereference so the operand itself is printed
                        // (e.g. "i64 100") instead of the Value's address.
                        errs()<<*call->getOperand(0)<<"\n";
                    } else if(!calledFp->empty()) {
                        // Only the terminator of the last basic block is
                        // inspected, as in the original approach.
                        Instruction *term = calledFp->back().getTerminator();
                        ReturnInst *ret = term ? dyn_cast<ReturnInst>(term) : nullptr;
                        // "ret void" has no operand; skip it instead of
                        // indexing past the end of the operand list.
                        Value *retVal = ret ? ret->getReturnValue() : nullptr;
                        if(retVal) {
                            findOperand(retVal);
                            errs()<<"done\n";
                        }
                    }
                } else {
                    for(unsigned i = 0, numOp = tip->getNumOperands(); i < numOp; ++i) {
                        Value *ti = tip->getOperand(i);
                        if(!isa<ConstantData>(ti))
                            st.push(ti);
                    }
                }
            }

        }//findOperand

        // Handles one store instruction: when the value being stored is a
        // pointer, prints "pointer assignment!" followed by the name of the
        // store's destination, then traces where the stored value came from
        // via findOperand().
        // NOTE(review): presumably called by an llvm::InstVisitor-style
        // dispatcher - confirm against the enclosing class.
        void visitStoreInst(StoreInst &ip) {

            Value *lhs = ip.getPointerOperand();  // destination address
            Value *rhs = ip.getValueOperand();    // value being written
            // Test the stored value's own type rather than the pointee type
            // of the destination: the two agree under typed pointers, and
            // only this form is valid when pointers are opaque.
            if(rhs->getType()->isPointerTy()) {
                //figure out rhs
                errs()<<"pointer assignment!"<<lhs->getName()<<"\n";
                findOperand(rhs);
            }
        }

Upvotes: 0

arrowd
arrowd

Reputation: 34401

You don't check the result of dyn_cast<ConstantInt>(&(*args)). If casted type is not a ConstantInt, it returns nullptr. And in the next line (arg->getValue()) you dereference it.

Upvotes: 1

Related Questions