Ahead-of-time compilation of a Tensorflow model for lightweight inclusion in a Rust program
Posted on 2019-03-09

Why?

Tensorflow’s new Accelerated Linear Algebra (XLA) framework comes with a lot of advantages, like Just-In-Time (JIT) compilation of computation graphs, which speeds up training and inference on both GPU and CPU. It also supports Ahead-of-Time (AOT) compilation through the tfcompile tool.

For me, AOT compilation is where it gets interesting. I’ve struggled to find a decent, lightweight way to put Tensorflow models in production on CPU-only systems. Just running everything through Tensorflow’s Python bindings is convenient, but extremely bulky: you’ll need the full Tensorflow runtime compiled on your machine along with all Python deps and your model weights. Deploying on something like AWS Lambda this way gets problematic fast due to the size limits AWS puts on the function bundle.

Tensorflow Serving is cool, but it’s pretty complex and enterprisey, and it’s only meant for exposing your model through an HTTP API (e.g. from a container). Tensorflow Lite is also promising, as it’s aimed at optimizing your model for inference on mobile devices. It doesn’t support all of Tensorflow’s operations yet, though, and it seems focused on iOS/Android for now.

I just want a single compiled blob of code that includes everything I need to execute my trained model, without any extra dependencies or funny business. Just give me an old school object file! It seems tfcompile lets us do just that.

In this post I’ll discuss how to create an object file like this from an example model. I’m wrapping the compiled model with Rust code here, but you might as well use Python or any other language that supports C-style Foreign Function Interfaces (FFIs).

Prepare the Tensorflow graph for compilation

As an example model for this post, let’s use this extremely cool age and gender estimation Keras model by Yu4u. It is a WideResNet architecture trained on the IMDB-WIKI dataset, and it takes (aligned) images of faces as input. The model has two output nodes: one for age and one for gender.

The model weights are available as Keras .hdf5 files. As we’re not covering Keras in this post, we’ll need to convert these weights to a plain Tensorflow .pb graph definition. For this, we can use the brilliant keras_to_tensorflow.py script to take care of the nitty gritty for us. Note that this script also ‘freezes’ the graph by converting all Variables to Constants.

python ./keras_to_tensorflow.py --input_model='./weights.28-3.73.hdf5' --output_model="./frozen_age_gender_model.pb"
Using TensorFlow backend.
2019-02-21 16:01:58.435588: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
I0221 16:02:02.120664 4566390208 keras_to_tensorflow.py:146] Converted output node names are: ['dense_1/Softmax', 'dense_2/Softmax']
INFO:tensorflow:Froze 70 variables.
I0221 16:02:02.258061 4566390208 tf_logging.py:115] Froze 70 variables.
INFO:tensorflow:Converted 70 variables to const ops.
I0221 16:02:02.605583 4566390208 tf_logging.py:115] Converted 70 variables to const ops.
I0221 16:02:03.049557 4566390208 keras_to_tensorflow.py:178] Saved the freezed graph at frozen_age_gender_model.pb

As you can see, it renamed the output nodes to ['dense_1/Softmax', 'dense_2/Softmax'] for us, which we’ll use next.
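By the way, if you’d rather not depend on the script, the core freezing step is only a few lines of TF 1.x Python. Below is a minimal sketch of roughly what happens under the hood; it assumes the .hdf5 checkpoint loads cleanly through tf.keras, whereas the real script handles many more edge cases:

import tensorflow as tf
from tensorflow.python.framework import graph_util

# Load the Keras model (assumes the checkpoint is tf.keras-compatible).
model = tf.keras.models.load_model('./weights.28-3.73.hdf5')
output_names = [out.op.name for out in model.outputs]

# Freeze: convert every Variable in the session's graph into a Const op.
sess = tf.keras.backend.get_session()
frozen = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), output_names)
tf.train.write_graph(frozen, '.', 'frozen_age_gender_model.pb', as_text=False)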

Optimize the graph using the Graph Transform Tool

First, we need to get a hold of all the ‘feeds’ and ‘fetches’: the tensors representing the model graph’s inputs and outputs, respectively. If you’re trying to compile a different model than mine and don’t know the names of the relevant tensors, you can get a list of all tensors in a model using the summarize_graph tool.
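If you don’t feel like building summarize_graph just for this, a quick-and-dirty alternative is to parse the frozen graph in Python and print the node names yourself. A small sketch, assuming a TF 1.x install and the frozen .pb from the previous step:

import tensorflow as tf

# Parse the frozen graph definition and print every node's name and op type.
graph_def = tf.GraphDef()
with open('frozen_age_gender_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    print(node.name, node.op)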

In this example model, the graph node that takes the input image is called input_1 and it accepts tensors of shape (n, 64, 64, 3), i.e. batches of n square RGB images of 64 x 64 pixels. For simplicity, I’ll set the batch size (n) to 1.

The output layers are called dense_1/Softmax and dense_2/Softmax, which return probability vectors for gender (length 2) and age (length 101), respectively.

We’ll now apply some simple optimizations to the inference graph using the Graph Transform Tool (GTT).

First we’ll need to grab the Tensorflow source code and install its Bazel build system.

# Clone the tensorflow repository at the v1.13.1 release
git clone --depth=1 --branch v1.13.1 https://github.com/tensorflow/tensorflow.git

cd tensorflow

# configure the build, just hit go with all the defaults when prompted
./configure

Make sure your working directory is the root of the tensorflow repository you just cloned.

Now we can build the GTT:

bazel build tensorflow/tools/graph_transforms:transform_graph

With the GTT binary built, we can apply some simple optimizations aimed at prepping the model for deployment (as shown in the GTT docs).

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=frozen_age_gender_model.pb \
--out_graph=optimized_age_gender_model.pb \
--inputs='input_1' \
--outputs='dense_1/Softmax,dense_2/Softmax' \
--transforms='
  strip_unused_nodes(type=float, shape="1,64,64,3")
  remove_nodes(op=Identity, op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms'
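Before moving on to compilation, it’s worth sanity-checking that the optimized graph still runs and that the feed and fetch names are correct. A minimal TF 1.x check could look like this (the all-zeros input is just a placeholder; any (1, 64, 64, 3) float array will do):

import numpy as np
import tensorflow as tf

# Load the optimized graph definition.
graph_def = tf.GraphDef()
with open('optimized_age_gender_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

# Run it on a dummy input and check the output shapes.
with tf.Session(graph=graph) as sess:
    gender, age = sess.run(
        ['dense_1/Softmax:0', 'dense_2/Softmax:0'],
        feed_dict={'input_1:0': np.zeros((1, 64, 64, 3), dtype=np.float32)})
    print(gender.shape, age.shape)  # expect (1, 2) and (1, 101)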

Prepare the config file defining the subgraph to compile

tfcompile requires a .pbtxt definition of the feeds and fetches that define the subgraph we want to compile. We could generate it with the protobuf classes from Tensorflow’s generated tf2xla_pb2 module, but let’s just write it by hand for now as it is exceedingly simple.

Note that XLA’s tfcompile requires static input tensors of a known shape, which we just looked up for our example model: (1, 64, 64, 3).

Let’s save the below config in a file called graph.config.pbtxt.

feed {
  id {
    node_name: "input_1"
  }
  shape {
    dim {
      size: 1
    }
    dim {
      size: 64
    }
    dim {
      size: 64
    }
    dim {
      size: 3
    }
  }
}
fetch {
  id {
    node_name: "dense_1/Softmax"
  }
}
fetch {
  id {
    node_name: "dense_2/Softmax"
  }
}
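If you’d rather generate this config programmatically (handy when the input shape changes often), the protobuf classes generated from tensorflow/compiler/tf2xla/tf2xla.proto can build the same message. A sketch, assuming the generated tf2xla_pb2 module is importable from your Tensorflow source tree or build:

from google.protobuf import text_format
# Import path is an assumption; adjust it to wherever tf2xla_pb2 lives in your setup.
from tensorflow.compiler.tf2xla import tf2xla_pb2

config = tf2xla_pb2.Config()

# One feed: the (1, 64, 64, 3) input image tensor.
feed = config.feed.add()
feed.id.node_name = 'input_1'
for size in (1, 64, 64, 3):
    feed.shape.dim.add().size = size

# Two fetches: the gender and age softmax outputs.
for name in ('dense_1/Softmax', 'dense_2/Softmax'):
    config.fetch.add().id.node_name = name

with open('graph.config.pbtxt', 'w') as f:
    f.write(text_format.MessageToString(config))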

Use tfcompile to compile the subgraph

Next, we’ll compile the subgraph using tfcompile. Instead of touching tfcompile directly, we’ll use a pre-baked Bazel macro.

cd to the root of the tensorflow repository we just cloned and create (or overwrite) a file called BUILD containing the following:

load('@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl', 'tf_library')

tf_library(
    name = 'graph',
    config = 'graph.config.pbtxt',
    cpp_class = 'Graph',
    graph = 'optimized_age_gender_model.pb',
)

cc_binary(
    name = "libmodel.so",
    srcs = ["graph.cc"],
    deps = [":graph", "//third_party/eigen3"],
    linkopts = ["-lpthread"],
    linkshared = 1,
    copts = ["-fPIC"],
)

The tf_library macro will compile the subgraph, its weights, and all required Tensorflow operations into an optimized static library called libgraph.a and associated header files. The cc_binary rule will compile the following little C++ wrapper around this library into a dynamic library called libmodel.so.

Make sure the code for this wrapper is present in a file called graph.cc in the working directory (which should be the tensorflow repository root):

#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include "graph.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

extern "C" int run(
		float *input,
		float *output_gender,
		float *output_age,
		int input_size,
		int output_gender_size,
		int output_age_size
    ){
	// allocate an instance of the Graph, along with all its private buffers
	Eigen::ThreadPool tp(std::thread::hardware_concurrency());
	Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
	Graph graph;
	graph.set_thread_pool(&device);

	// copy over the input buffer
	std::copy(input, input + input_size, graph.arg0_data());

	// execute the inference graph
	auto ok = graph.Run();
	if (not ok) return -1;

	// copy the result into the output buffer
	std::copy(graph.result0_data(), graph.result0_data() + output_gender_size, output_gender);
	std::copy(graph.result1_data(), graph.result1_data() + output_age_size, output_age);
	return 0;
}

Make sure the optimized_age_gender_model.pb file is also in the working directory.

Next, we can kick off the build using:

bazel build --show_progress_rate_limit=60 @org_tensorflow//:libmodel.so

Note that this will take a while as quite a significant part of Tensorflow may need to be compiled, depending on your model. We’re talking 1300+ seconds on a 16-core VM. Probably less if you provision more cores.

When compilation is finished, you’ll find the relevant object files in tensorflow/bazel-bin/external/org_tensorflow.

Invoke the inference graph from Rust code

Let’s write the Rust code that will invoke the compiled graph from the dynamic library libmodel.so we just created.

main.rs

extern crate image;
extern crate rulinalg;
use std::os::raw::c_float;
use std::os::raw::c_int;
use image::DynamicImage;
use std::env;
use rulinalg::utils;

// Here we tell the compiler to link to the (lib)model library using a link attribute.
// Note that this attribute by default assumes that 
// the link target is a dynamic library instead of a static one.
#[link(name = "model")]
extern {
    // Note that C-types are used along with raw pointers to define the external interface.
    fn run(
      input: *const c_float,
      output_gender: *mut c_float,
      output_age: *mut c_float,

      // As only pointers to the arrays are passed, 
      // we need to tell the receiving code about the size of these arrays.
      input_size: c_int,
      output_gender_size: c_int,
      output_age_size: c_int
     ) -> c_int;
}

// Convert ('flatten') the image to a fixed-size array of c_floats
// The size of this array depends on the dimensions of your input tensor.
fn image_to_carray(img: DynamicImage) -> [c_float; 12288]{
  let rgb_img = img.to_rgb();

  // verify image dimensions
  assert_eq!(12288, rgb_img.height() * rgb_img.width() * 3);

  // get the raw buffer (a Vec[u8]) underpinning the RGB image
  let rawbuf = rgb_img.into_raw();
  let mut input: [c_float; 12288] = [1.0; 12288];

  // Copy the raw buffer over to the target array.
  // Note that we don't change the array layout of the 
  // raw image buffer, which already happens to be in the correct order.
  // Note: this order may well be different for your model.
  for i in 0..input.len(){
    input[i] = rawbuf[i] as c_float
  }

  input
}

fn main() {

  // grab the path to the input image from the command line
  let args: Vec<String> = env::args().collect();
  let input_path = &args[1];

  let input = image_to_carray(image::open(input_path).unwrap());

  let mut output_gender: [c_float; 2] = [0.0; 2];
  let mut output_age: [c_float; 101] = [0.0; 101];

  let exit_code = unsafe { 
      run(
        input.as_ptr(),
        output_gender.as_mut_ptr(),
        output_age.as_mut_ptr(),
        input.len() as c_int, 
        output_gender.len() as c_int,
        output_age.len() as c_int,
      )
    };

    if exit_code != 0{
      panic!("nonzero exit code of graph invocation")
    }

  println!("gender probabilities: {:?}", output_gender);

  let (argmax, maxprob) = utils::argmax(&output_age);
  println!("age bin with highest probability: {} (probability: {:?})", argmax, maxprob);
    
}

Let’s also line up the right Cargo boilerplate so the image and rulinalg dependencies are pulled in automatically. Run cargo init in a working directory of choice and add the dependencies to your Cargo.toml file.

Cargo.toml

[package]
name = "rust-ffi-example"
version = "0.1.0"
authors = ["jur"]
edition = "2018"

[dependencies]
image = "0.21.0"
rulinalg = "0.4.2"

[[bin]]
name = "rust-ffi-example"
path = "main.rs"

Next, copy over the libmodel.so file that was compiled in the previous section to your new working directory. For convenience, let’s copy it next to the main.rs code.

cp tensorflow/bazel-bin/external/org_tensorflow/libmodel.so .

Now we can compile our Rust wrapper. To ensure the created binary can find libmodel.so at runtime, we pass some extra linker arguments. We use $ORIGIN in the rpath so the dynamic loader looks for the lib in a directory relative to the executable itself (not the working directory).

cargo rustc -- -L. -C link-args='-Wl,-rpath,$ORIGIN/../../'

After successfully compiling the wrapper, we can now execute it on an arbitrary image of size 64 x 64.
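If your test image isn’t already 64 x 64, you’ll need to crop or resize it first; remember that the Rust wrapper asserts those exact dimensions. A one-off resize with Pillow, for example (Pillow is just an assumption here, any image tool will do):

from PIL import Image

# Resize the face image to the 64 x 64 RGB input the model expects.
img = Image.open('prison_mike.png').convert('RGB')
img.resize((64, 64), Image.BILINEAR).save('prison_mike.png')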

Let’s go with this image of Prison Mike, cropped to absolute perfection.

Let’s execute the Rust binary we just built and time it:

time ./target/debug/rust-ffi-example ./prison_mike.png
#gender probabilities: [0.011825411, 0.98817456]
#age bin with highest probability: 45 (probability: 0.13215715)

#real	0m0.663s
#user	0m0.647s
#sys	0m0.016s

That’s it! Prison Mike looks like a male (the second gender probability is considerably higher than the first) and the predicted age bin is about what I’d guess myself.

After all this, we’re left with libmodel.so (97MB) and our wrapper binary ./target/debug/rust-ffi-example (17MB, including debug symbols). That’s just 114MB to run a fully fledged computer vision network, without any further dependencies! IMO that’s pretty impressive, and I hope the XLA tooling will get a lot more love from the Tensorflow team in the future.

Limitations

  • AOT-compilation is CPU-only.
  • Verify that the model you’re trying to compile contains only operations supported by XLA. See the list here.
  • The C++ wrapper I wrote here (graph.cc) is based on the one shown in the official tutorial. It is pretty basic, and allocates a new instance of the computation Graph with each call, which is wasteful.
  • Instead of linking the Rust wrapper to a dynamic library object file (libmodel.so) we may be able to directly link it to the static library libgraph.a instead, producing only a single executable binary. This way we don’t have to keep libmodel.so around. For this we would need to fiddle with header files, however, and generate more complex FFI bindings. This may be a nice job for rust-bindgen.
  • I have only tested the above on Ubuntu 18.04