ericzma
ericzma

Reputation: 763

Why is this OCaml code so slow?

I am an OCaml newbie. I wrote a simple program in OCaml to generate the fair and square numbers (a fair and square number is a palindrome that is also the square of a palindrome; more details here: https://code.google.com/codejam/contest/2270488/dashboard#s=p2) as follows:

Update 1: Optimized algorithm (takes around 20 secs on my laptop):

open Printf;;

(* [_is_palindrome s i j] tells whether the characters of [s] from index [i]
   up to index [j] (both inclusive) mirror each other.  A range with
   [i >= j] (empty or a single character) is a palindrome by definition. *)
let rec _is_palindrome s i j =
  if i >= j then true
  else if s.[i] <> s.[j] then false
  else _is_palindrome s (succ i) (pred j)
;;

(* A string is accepted when it is non-empty and reads the same in both
   directions; the empty string is rejected. *)
let is_palindrome s =
  match String.length s with
  | 0 -> false
  | len -> _is_palindrome s 0 (len - 1)
;;

(* [del_zeros s] returns [s] without its leading '0' characters.  A string
   made only of zeros, like the empty string, yields "".
   The first non-zero position is found with a plain scan and the remaining
   suffix is copied once, rather than copying the string again for every
   zero that is stripped. *)
let del_zeros s =
  let len = String.length s in
  let rec first_nonzero i =
    if i < len && s.[i] = '0' then first_nonzero (i + 1) else i
  in
  match first_nonzero 0 with
  | 0 -> s (* nothing to strip: hand the input back unchanged *)
  | start -> String.sub s start (len - start)
;;

(* Numeric value of a decimal digit character, e.g. '7' -> 7. *)
let c2i c =
  let zero = int_of_char '0' in
  int_of_char c - zero
;;

(* Character for a digit value, e.g. 7 -> '7'.  [Char.chr] rejects results
   outside the 0..255 character range. *)
let i2c i =
  let zero = Char.code '0' in
  Char.chr (zero + i)
;;

(* only for finding fair and square numbers *)
(* [square s] squares the decimal digit string [s] column by column, without
   any carry propagation.
   - Returns "" when [s] is empty.
   - Raises [Invalid_argument "carry"] as soon as one column of the
     schoolbook product sums to more than 9.
   - Otherwise returns the product digits with leading zeros removed by
     [del_zeros].
   The digits are assembled in a [Bytes.t] buffer: [string] values are
   immutable in current OCaml, so writing them in place with
   [res.[i] <- c] is no longer accepted by the compiler. *)
let square s =
  let slen = String.length s in
  if slen < 1 then ""
  else begin
    (* Sum of the digit products s.[j] * s.[k - j] that fall into column [k]
       (column 0 is the leftmost one); raises when the sum needs a carry. *)
    let column k =
      let acc = ref 0 in
      for j = max 0 (k - slen + 1) to min k (slen - 1) do
        acc := !acc + c2i s.[j] * c2i s.[k - j]
      done;
      if !acc > 9 then invalid_arg "carry";
      !acc
    in
    (* fast check: try the middle column before allocating the result *)
    ignore (column (slen - 1));
    let reslen = 2 * slen in
    let res = Bytes.make reslen '0' in
    (* column [i - 1] is written at index [i]; index 0 keeps its '0' *)
    for i = reslen - 1 downto 1 do
      Bytes.set res i (i2c (column (i - 1)))
    done;
    del_zeros (Bytes.to_string res)
  end
;;

(* [check_fs fsns p] squares the candidate root [p].  When the square is a
   palindrome it is pushed on the front of [fsns]; when it is not, or when
   [square] signals a carry, [fsns] is returned untouched. *)
let check_fs fsns p =
  let squared =
    try Some (square p) with Invalid_argument "carry" -> None
  in
  match squared with
  | Some sq -> if is_palindrome sq then sq :: fsns else fsns
  | None -> fsns
;;

(* build the fair and square number list *)
(* dfs: [p] is wrapped in a matching pair of digits 0..3 and each wrapped
   string is explored before [p] itself is tested with [check_fs].
   A branch is abandoned once [p] is longer than [max_num_digs] or once
   [sum] exceeds 9; [sum] grows by the square of the wrapping digit
   (0, 1, 4 or 9) at every level.
   NOTE(review): the digit is added on both sides of [p] but its square is
   added to [sum] only once -- confirm this is the intended bound. *)
let rec create_fair_square_nums fsns p sum max_num_digs =
  if String.length p > max_num_digs || sum > 9 then fsns
  else
    let grow found (digit, cost) =
      create_fair_square_nums found (digit ^ p ^ digit) (sum + cost)
        max_num_digs
    in
    let found =
      List.fold_left grow fsns [ ("0", 0); ("1", 1); ("2", 4); ("3", 9) ]
    in
    check_fs found p
;;

(* Writes every entry of [fsns] followed by one space, then a newline. *)
let print_fsns fsns =
  let emit entry = printf "%s " entry in
  List.iter emit fsns;
  printf "\n"
;;

(* Ordering used to sort the digit strings: a shorter string comes first
   (the result is the length difference); strings of equal length are
   ordered by [String.compare]. *)
let num_str_cmp s1 s2 =
  let diff = String.length s1 - String.length s2 in
  if diff <> 0 then diff else String.compare s1 s2
;;

(* works *)

let max_dig = 51;;

(* Run the search once per seed, threading the accumulated list through:
   first the empty seed, then each single-digit seed paired with the
   starting [sum] the search expects for it. *)
let fsns =
  List.fold_left
    (fun found (seed, sum) -> create_fair_square_nums found seed sum max_dig)
    []
    [ ("", 0); ("0", 0); ("1", 1); ("2", 4); ("3", 9) ]
;;

(* Shortest strings first, then [String.compare] order within a length. *)
let fsns = List.sort num_str_cmp fsns;;

let () = print_fsns fsns;;

My original code (naive solution, too slow):

open Printf;;

(* [_is_palindrome s i j] holds when the characters of [s] between indices
   [i] and [j] (inclusive) mirror each other; once the two indices meet or
   cross there is nothing left to compare and the answer is [true]. *)
let rec _is_palindrome s i j =
  i >= j || (s.[i] = s.[j] && _is_palindrome s (i + 1) (j - 1))
;;

(* True for a non-empty string that reads the same in both directions;
   the empty string is not considered a palindrome. *)
let is_palindrome s =
  let len = String.length s in
  len >= 1 && _is_palindrome s 0 (len - 1)
;;

(* [del_zeros s] returns [s] without its leading '0' characters; a string
   of zeros only, like the empty string, yields "".
   The leading run of zeros is skipped with an index and the remainder is
   copied once, rather than copying the string again for every zero that
   is dropped. *)
let del_zeros s =
  let len = String.length s in
  let start = ref 0 in
  while !start < len && s.[!start] = '0' do
    incr start
  done;
  (* no leading zero: return the input itself, without copying *)
  if !start = 0 then s else String.sub s !start (len - !start)
;;

(* Numeric value of a decimal digit character, e.g. '7' -> 7. *)
let c2i c =
  let zero = int_of_char '0' in
  int_of_char c - zero
;;

(* Character for a digit value, e.g. 7 -> '7'.  [Char.chr] rejects results
   outside the 0..255 character range. *)
let i2c i =
  let zero = Char.code '0' in
  Char.chr (zero + i)
;;

(* only positive number *)
(* [square s] returns the square of the decimal digit string [s] as a digit
   string with leading zeros removed by [del_zeros]; the empty string maps
   to "".
   Schoolbook multiplication, rightmost column first: every column adds the
   digit products that fall into it to the carry of the previous column,
   keeps the last digit and passes the rest on.
   The digits are assembled in a [Bytes.t] buffer because [string] values
   are immutable in current OCaml.  The carry is kept in an integer rather
   than parked as a character code in the result, so a large carry can no
   longer push [Char.chr] out of the 0..255 range. *)
let square s =
  let slen = String.length s in
  if slen = 0 then ""
  else begin
    let reslen = 2 * slen in
    let res = Bytes.make reslen '0' in
    let carry = ref 0 in
    for i = reslen - 1 downto 1 do
      (* index [i] of the result receives product column [k = i - 1] *)
      let k = i - 1 in
      let total = ref !carry in
      for j = max 0 (k - slen + 1) to min k (slen - 1) do
        total := !total + c2i s.[j] * c2i s.[k - j]
      done;
      Bytes.set res i (i2c (!total mod 10));
      carry := !total / 10
    done;
    (* the square of an [slen]-digit number has at most [2 * slen] digits,
       so what is left of the carry is a single digit *)
    Bytes.set res 0 (i2c !carry);
    del_zeros (Bytes.to_string res)
  end
;;

(* Puts the four strings obtained by wrapping [n] in a matching pair of
   '0', '1', '2' or '3' (in that order) in front of [new_ps]. *)
let extend_palindrome new_ps n =
  List.fold_right
    (fun digit rest -> (digit ^ n ^ digit) :: rest)
    [ "0"; "1"; "2"; "3" ]
    new_ps
;;

(* Applies [extend_palindrome] to every element of [ps] from left to right,
   accumulating into [new_ps]; the strings built from later elements end up
   nearer the front of the result. *)
let extend_palindromes new_ps ps =
  List.fold_left extend_palindrome new_ps ps
;;

(* Squares every candidate in [ps], left to right, and pushes each square
   that is itself a palindrome on the front of [fsns]. *)
let check_fs fsns ps =
  let keep_if_fair found candidate =
    let sq = square candidate in
    if is_palindrome sq then sq :: found else found
  in
  List.fold_left keep_if_fair fsns ps
;;

(* build the fair and square number list *)
(* Generation by generation: every string of [ps] is wrapped in the digit
   pairs 0..3, the wrapped strings are tested with [check_fs], and they
   become the next [ps].  The walk stops, returning [fsns], when the first
   string of the current generation is longer than [max_num_digs].
   Only the head of [ps] is measured, and the length test happens before
   the generation is extended, so the last generation tested is two
   characters longer than the last one that passed the test.
   Raises [Invalid_argument] when [ps] is empty. *)
let rec create_fair_square_nums fsns ps max_num_digs =
  match ps with
  | [] -> invalid_arg "fsn should not be []"
  | first :: _ ->
      if String.length first > max_num_digs then fsns
      else
        let next = extend_palindromes [] ps in
        create_fair_square_nums (check_fs fsns next) next max_num_digs
;;

(* Writes every entry of [fsns] followed by one space, then a newline. *)
let print_fsns fsns =
  let emit entry = printf "%s " entry in
  List.iter emit fsns;
  printf "\n"
;;

(* Ordering used to sort the digit strings: when the lengths differ the
   result is the length difference (shorter first); otherwise it is
   whatever [String.compare] says. *)
let num_str_cmp s1 s2 =
  if String.length s1 = String.length s2 then String.compare s1 s2
  else String.length s1 - String.length s2
;;

(* works *)

let max_dig = 50;;

(* Two independent walks: one grown from the empty string and one grown
   from the single digits 0..3.  The walks only test the strings they build
   by wrapping, never the seeds themselves, so the squares of the
   one-digit seeds are listed by hand in front. *)
let fsns =
  let from_empty = create_fair_square_nums [] [""] max_dig in
  let from_digits = create_fair_square_nums [] ["0"; "1"; "2"; "3"] max_dig in
  List.concat [ ["0"; "1"; "4"; "9"]; from_empty; from_digits ]
;;

(* Shortest strings first, then [String.compare] order within a length. *)
let fsns = List.sort num_str_cmp fsns;;

let () = print_fsns fsns;;

This code generates the fair and square numbers that are below 10^100.

This code probably has some (or many) performance problems. It ran for more than 30 minutes before I killed it. When max_dig = 14 it finishes quickly (< 1 min).

Any suggestions for improving this code, or criticism of it, are welcome.

Upvotes: 3

Views: 886

Answers (2)

Thomash
Thomash

Reputation: 6379

When dealing with big integers, you can use the Big_int module which will be faster than your custom implementation of square (and it will also save you a lot of time writing the code).

Also, it is bad style to write if a then b else false where you can just simply write a && b.

Upvotes: 3

gasche
gasche

Reputation: 31469

This is probably one issue among many others (and possibly negligible in your use case; you should profile to find that out), but this code snippet already has algorithmic issues:

(* Strips the leading '0' characters of [s]: as long as the first character
   is '0', recurse on a copy of [s] without it.  The empty string, and any
   string that does not start with '0', is returned as is.
   Every stripped zero costs one [String.sub] copy of the rest of the
   string. *)
let rec del_zeros s =
  let sl = String.length s in
    if (sl < 1) then
      s
    else
      (if s.[0] = '0' then
         del_zeros (String.sub s 1 (sl - 1))
       else
         s
      )
;;

String.sub is linear in the length argument (in memory, therefore time), so the whole function is quadratic: del_zeros (String.make 50_000 '0') is probably going to be slow.

To write this code efficiently you should collect the kept characters in a list accumulator, accumulating the total length as you go, and finally create a string of the right size and write those characters in it.

As an approximation, the natural code using Buffer would already be reasonably efficient (that's what I would recommend writing for usual applications, but maybe not in an algorithmic contest if that's part of the critical path):

(* [del_zeros s] returns [s] without its leading '0' characters, built in a
   single pass with a [Buffer].  Only the leading run of zeros is skipped:
   once a non-zero character has been seen, every later character --
   including '0' -- is kept, so "00100" gives "100".  A string of zeros
   only, like the empty string, gives "". *)
let del_zeros s =
  let buf = Buffer.create (String.length s) in
  let in_leading_zeros = ref true in
  for i = 0 to String.length s - 1 do
    if not (!in_leading_zeros && s.[i] = '0') then begin
      in_leading_zeros := false;
      Buffer.add_char buf s.[i]
    end
  done;
  Buffer.contents buf
;;

Upvotes: 4

Related Questions